aboutsummaryrefslogtreecommitdiffstats
path: root/config/parse.go
diff options
context:
space:
mode:
Diffstat (limited to 'config/parse.go')
-rw-r--r--config/parse.go233
1 files changed, 233 insertions, 0 deletions
diff --git a/config/parse.go b/config/parse.go
new file mode 100644
index 00000000..d836c760
--- /dev/null
+++ b/config/parse.go
@@ -0,0 +1,233 @@
+package config
+
+import (
+ "errors"
+ "fmt"
+ "reflect"
+ "regexp"
+
+ "git.sr.ht/~rjarry/aerc/lib/templates"
+ "github.com/emersion/go-message/mail"
+ "github.com/go-ini/ini"
+)
+
+func MapToStruct(s *ini.Section, v interface{}, useDefaults bool) error {
+ typ := reflect.TypeOf(v)
+ val := reflect.ValueOf(v)
+ if typ.Kind() == reflect.Ptr {
+ typ = typ.Elem()
+ val = val.Elem()
+ } else {
+ panic("MapToStruct requires a pointer")
+ }
+ if typ.Kind() != reflect.Struct {
+ panic("MapToStruct requires a pointer to a struct")
+ }
+
+ for i := 0; i < typ.NumField(); i++ {
+ fieldVal := val.Field(i)
+ fieldType := typ.Field(i)
+
+ name := fieldType.Tag.Get("ini")
+ if name == "" || name == "-" {
+ continue
+ }
+ key, err := s.GetKey(name)
+ if err != nil {
+ defValue, found := fieldType.Tag.Lookup("default")
+ if useDefaults && found {
+ key, _ = s.NewKey(name, defValue)
+ } else {
+ continue
+ }
+ }
+ err = setField(s, key, reflect.ValueOf(v), fieldVal, fieldType)
+ if err != nil {
+ return fmt.Errorf("[%s].%s: %w", s.Name(), name, err)
+ }
+ }
+ return nil
+}
+
+func setField(
+ s *ini.Section, key *ini.Key, struc reflect.Value,
+ fieldVal reflect.Value, fieldType reflect.StructField,
+) error {
+ var methodValue reflect.Value
+ method := getParseMethod(s, key, struc, fieldType)
+ if method.IsValid() {
+ in := []reflect.Value{reflect.ValueOf(s), reflect.ValueOf(key)}
+ out := method.Call(in)
+ err, _ := out[1].Interface().(error)
+ if err != nil {
+ return err
+ }
+ methodValue = out[0]
+ }
+
+ ft := fieldType.Type
+
+ switch ft.Kind() {
+ case reflect.String:
+ if method.IsValid() {
+ fieldVal.SetString(methodValue.String())
+ } else {
+ fieldVal.SetString(key.String())
+ }
+ case reflect.Bool:
+ if method.IsValid() {
+ fieldVal.SetBool(methodValue.Bool())
+ } else {
+ boolVal, err := key.Bool()
+ if err != nil {
+ return err
+ }
+ fieldVal.SetBool(boolVal)
+ }
+ case reflect.Int32:
+ // impossible to differentiate rune from int32, they are aliases
+ // this is an ugly hack but there is no alternative...
+ if fieldType.Tag.Get("type") == "rune" {
+ if method.IsValid() {
+ fieldVal.Set(methodValue)
+ } else {
+ runes := []rune(key.String())
+ if len(runes) != 1 {
+ return errors.New("value must be 1 character long")
+ }
+ fieldVal.Set(reflect.ValueOf(runes[0]))
+ }
+ return nil
+ }
+ fallthrough
+ case reflect.Int64:
+ // ParseDuration will not return err for `0`, so check the type name
+ if ft.PkgPath() == "time" && ft.Name() == "Duration" {
+ durationVal, err := key.Duration()
+ if err != nil {
+ return err
+ }
+ fieldVal.Set(reflect.ValueOf(durationVal))
+ return nil
+ }
+ fallthrough
+ case reflect.Int, reflect.Int8, reflect.Int16:
+ if method.IsValid() {
+ fieldVal.SetInt(methodValue.Int())
+ } else {
+ intVal, err := key.Int64()
+ if err != nil {
+ return err
+ }
+ fieldVal.SetInt(intVal)
+ }
+ case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+ if method.IsValid() {
+ fieldVal.SetUint(methodValue.Uint())
+ } else {
+ uintVal, err := key.Uint64()
+ if err != nil {
+ return err
+ }
+ fieldVal.SetUint(uintVal)
+ }
+ case reflect.Float32, reflect.Float64:
+ if method.IsValid() {
+ fieldVal.SetFloat(methodValue.Float())
+ } else {
+ floatVal, err := key.Float64()
+ if err != nil {
+ return err
+ }
+ fieldVal.SetFloat(floatVal)
+ }
+ case reflect.Slice, reflect.Array:
+ switch {
+ case method.IsValid():
+ fieldVal.Set(methodValue)
+ case ft.Elem().Kind() == reflect.Ptr &&
+ typePath(ft.Elem().Elem()) == "net/mail.Address":
+ addrs, err := mail.ParseAddressList(key.String())
+ if err != nil {
+ return err
+ }
+ fieldVal.Set(reflect.ValueOf(addrs))
+ case ft.Elem().Kind() == reflect.String:
+ delim := fieldType.Tag.Get("delim")
+ fieldVal.Set(reflect.ValueOf(key.Strings(delim)))
+ default:
+ panic(fmt.Sprintf("unsupported type []%s", typePath(ft.Elem())))
+ }
+ case reflect.Struct:
+ if method.IsValid() {
+ fieldVal.Set(methodValue)
+ } else {
+ panic(fmt.Sprintf("unsupported type %s", typePath(ft)))
+ }
+ case reflect.Ptr:
+ if method.IsValid() {
+ fieldVal.Set(methodValue)
+ } else {
+ switch typePath(ft.Elem()) {
+ case "net/mail.Address":
+ addr, err := mail.ParseAddress(key.String())
+ if err != nil {
+ return err
+ }
+ fieldVal.Set(reflect.ValueOf(addr))
+ case "regexp.Regexp":
+ r, err := regexp.Compile(key.String())
+ if err != nil {
+ return err
+ }
+ fieldVal.Set(reflect.ValueOf(r))
+ case "text/template.Template":
+ t, err := templates.ParseTemplate(key.String(), key.String())
+ if err != nil {
+ return err
+ }
+ fieldVal.Set(reflect.ValueOf(t))
+ default:
+ panic(fmt.Sprintf("unsupported type %s", typePath(ft)))
+ }
+ }
+ default:
+ panic(fmt.Sprintf("unsupported type %s", typePath(ft)))
+ }
+ return nil
+}
+
+func getParseMethod(
+ section *ini.Section, key *ini.Key,
+ struc reflect.Value, typ reflect.StructField,
+) reflect.Value {
+ methodName, found := typ.Tag.Lookup("parse")
+ if !found {
+ return reflect.Value{}
+ }
+ method := struc.MethodByName(methodName)
+ if !method.IsValid() {
+ panic(fmt.Sprintf("(*%s).%s: method not found",
+ struc, methodName))
+ }
+
+ if method.Type().NumIn() != 2 ||
+ method.Type().In(0) != reflect.TypeOf(section) ||
+ method.Type().In(1) != reflect.TypeOf(key) ||
+ method.Type().NumOut() != 2 {
+ panic(fmt.Sprintf("(*%s).%s: invalid signature, expected %s",
+ struc.Elem().Type().Name(), methodName,
+ "func(*ini.Section, *ini.Key) (any, error)"))
+ }
+
+ return method
+}
+
+func typePath(t reflect.Type) string {
+ var prefix string
+ if t.Kind() == reflect.Ptr {
+ t = t.Elem()
+ prefix = "*"
+ }
+ return fmt.Sprintf("%s%s.%s", prefix, t.PkgPath(), t.Name())
+}