diff options
Diffstat (limited to 'config/parse.go')
-rw-r--r-- | config/parse.go | 233 |
1 files changed, 233 insertions, 0 deletions
diff --git a/config/parse.go b/config/parse.go new file mode 100644 index 00000000..d836c760 --- /dev/null +++ b/config/parse.go @@ -0,0 +1,233 @@ +package config + +import ( + "errors" + "fmt" + "reflect" + "regexp" + + "git.sr.ht/~rjarry/aerc/lib/templates" + "github.com/emersion/go-message/mail" + "github.com/go-ini/ini" +) + +func MapToStruct(s *ini.Section, v interface{}, useDefaults bool) error { + typ := reflect.TypeOf(v) + val := reflect.ValueOf(v) + if typ.Kind() == reflect.Ptr { + typ = typ.Elem() + val = val.Elem() + } else { + panic("MapToStruct requires a pointer") + } + if typ.Kind() != reflect.Struct { + panic("MapToStruct requires a pointer to a struct") + } + + for i := 0; i < typ.NumField(); i++ { + fieldVal := val.Field(i) + fieldType := typ.Field(i) + + name := fieldType.Tag.Get("ini") + if name == "" || name == "-" { + continue + } + key, err := s.GetKey(name) + if err != nil { + defValue, found := fieldType.Tag.Lookup("default") + if useDefaults && found { + key, _ = s.NewKey(name, defValue) + } else { + continue + } + } + err = setField(s, key, reflect.ValueOf(v), fieldVal, fieldType) + if err != nil { + return fmt.Errorf("[%s].%s: %w", s.Name(), name, err) + } + } + return nil +} + +func setField( + s *ini.Section, key *ini.Key, struc reflect.Value, + fieldVal reflect.Value, fieldType reflect.StructField, +) error { + var methodValue reflect.Value + method := getParseMethod(s, key, struc, fieldType) + if method.IsValid() { + in := []reflect.Value{reflect.ValueOf(s), reflect.ValueOf(key)} + out := method.Call(in) + err, _ := out[1].Interface().(error) + if err != nil { + return err + } + methodValue = out[0] + } + + ft := fieldType.Type + + switch ft.Kind() { + case reflect.String: + if method.IsValid() { + fieldVal.SetString(methodValue.String()) + } else { + fieldVal.SetString(key.String()) + } + case reflect.Bool: + if method.IsValid() { + fieldVal.SetBool(methodValue.Bool()) + } else { + boolVal, err := key.Bool() + if err != nil { + return err + } + fieldVal.SetBool(boolVal) + } + case reflect.Int32: + // impossible to differentiate rune from int32, they are aliases + // this is an ugly hack but there is no alternative... + if fieldType.Tag.Get("type") == "rune" { + if method.IsValid() { + fieldVal.Set(methodValue) + } else { + runes := []rune(key.String()) + if len(runes) != 1 { + return errors.New("value must be 1 character long") + } + fieldVal.Set(reflect.ValueOf(runes[0])) + } + return nil + } + fallthrough + case reflect.Int64: + // ParseDuration will not return err for `0`, so check the type name + if ft.PkgPath() == "time" && ft.Name() == "Duration" { + durationVal, err := key.Duration() + if err != nil { + return err + } + fieldVal.Set(reflect.ValueOf(durationVal)) + return nil + } + fallthrough + case reflect.Int, reflect.Int8, reflect.Int16: + if method.IsValid() { + fieldVal.SetInt(methodValue.Int()) + } else { + intVal, err := key.Int64() + if err != nil { + return err + } + fieldVal.SetInt(intVal) + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if method.IsValid() { + fieldVal.SetUint(methodValue.Uint()) + } else { + uintVal, err := key.Uint64() + if err != nil { + return err + } + fieldVal.SetUint(uintVal) + } + case reflect.Float32, reflect.Float64: + if method.IsValid() { + fieldVal.SetFloat(methodValue.Float()) + } else { + floatVal, err := key.Float64() + if err != nil { + return err + } + fieldVal.SetFloat(floatVal) + } + case reflect.Slice, reflect.Array: + switch { + case method.IsValid(): + fieldVal.Set(methodValue) + case ft.Elem().Kind() == reflect.Ptr && + typePath(ft.Elem().Elem()) == "net/mail.Address": + addrs, err := mail.ParseAddressList(key.String()) + if err != nil { + return err + } + fieldVal.Set(reflect.ValueOf(addrs)) + case ft.Elem().Kind() == reflect.String: + delim := fieldType.Tag.Get("delim") + fieldVal.Set(reflect.ValueOf(key.Strings(delim))) + default: + panic(fmt.Sprintf("unsupported type []%s", typePath(ft.Elem()))) + } + case reflect.Struct: + if method.IsValid() { + fieldVal.Set(methodValue) + } else { + panic(fmt.Sprintf("unsupported type %s", typePath(ft))) + } + case reflect.Ptr: + if method.IsValid() { + fieldVal.Set(methodValue) + } else { + switch typePath(ft.Elem()) { + case "net/mail.Address": + addr, err := mail.ParseAddress(key.String()) + if err != nil { + return err + } + fieldVal.Set(reflect.ValueOf(addr)) + case "regexp.Regexp": + r, err := regexp.Compile(key.String()) + if err != nil { + return err + } + fieldVal.Set(reflect.ValueOf(r)) + case "text/template.Template": + t, err := templates.ParseTemplate(key.String(), key.String()) + if err != nil { + return err + } + fieldVal.Set(reflect.ValueOf(t)) + default: + panic(fmt.Sprintf("unsupported type %s", typePath(ft))) + } + } + default: + panic(fmt.Sprintf("unsupported type %s", typePath(ft))) + } + return nil +} + +func getParseMethod( + section *ini.Section, key *ini.Key, + struc reflect.Value, typ reflect.StructField, +) reflect.Value { + methodName, found := typ.Tag.Lookup("parse") + if !found { + return reflect.Value{} + } + method := struc.MethodByName(methodName) + if !method.IsValid() { + panic(fmt.Sprintf("(*%s).%s: method not found", + struc, methodName)) + } + + if method.Type().NumIn() != 2 || + method.Type().In(0) != reflect.TypeOf(section) || + method.Type().In(1) != reflect.TypeOf(key) || + method.Type().NumOut() != 2 { + panic(fmt.Sprintf("(*%s).%s: invalid signature, expected %s", + struc.Elem().Type().Name(), methodName, + "func(*ini.Section, *ini.Key) (any, error)")) + } + + return method +} + +func typePath(t reflect.Type) string { + var prefix string + if t.Kind() == reflect.Ptr { + t = t.Elem() + prefix = "*" + } + return fmt.Sprintf("%s%s.%s", prefix, t.PkgPath(), t.Name()) +} |