aboutsummaryrefslogblamecommitdiffstats
path: root/config/parse.go
blob: d836c760193579c99c4364cdd110ae6b5bc71d4a (plain) (tree)








































































































































































































































                                                                                             
package config

import (
	"errors"
	"fmt"
	"reflect"
	"regexp"

	"git.sr.ht/~rjarry/aerc/lib/templates"
	"github.com/emersion/go-message/mail"
	"github.com/go-ini/ini"
)

// MapToStruct populates the struct pointed to by v from the keys of the ini
// section s.
//
// Only fields carrying an `ini:"name"` struct tag are considered; fields
// without an ini tag or tagged `ini:"-"` are skipped. For each considered
// field, the key with that name is looked up in s. When the key is missing
// and useDefaults is true, the field's `default:"..."` tag (if present) is
// added to the section as a new key and used instead; otherwise the field is
// left untouched.
//
// Conversion and parse errors are returned wrapped with their "[section].key"
// location. MapToStruct panics if v is not a pointer to a struct, since that
// is a programming error rather than a configuration error.
func MapToStruct(s *ini.Section, v interface{}, useDefaults bool) error {
	typ := reflect.TypeOf(v)
	ptr := reflect.ValueOf(v)
	if typ.Kind() != reflect.Ptr {
		panic("MapToStruct requires a pointer")
	}
	typ = typ.Elem()
	val := ptr.Elem()
	if typ.Kind() != reflect.Struct {
		panic("MapToStruct requires a pointer to a struct")
	}

	for i := 0; i < typ.NumField(); i++ {
		fieldVal := val.Field(i)
		fieldType := typ.Field(i)

		name := fieldType.Tag.Get("ini")
		if name == "" || name == "-" {
			continue
		}
		key, err := s.GetKey(name)
		if err != nil {
			defValue, found := fieldType.Tag.Lookup("default")
			if !useDefaults || !found {
				continue
			}
			// Register the default as a real key in the section; parse
			// methods receive both the section and the key. A failure here
			// must not be ignored, or a nil key would reach setField.
			key, err = s.NewKey(name, defValue)
			if err != nil {
				return fmt.Errorf("[%s].%s: %w", s.Name(), name, err)
			}
		}
		// Pass the pointer (not the struct value) so that parse methods
		// declared on the pointer receiver can be found.
		err = setField(s, key, ptr, fieldVal, fieldType)
		if err != nil {
			return fmt.Errorf("[%s].%s: %w", s.Name(), name, err)
		}
	}
	return nil
}

// setField stores the value of key into the struct field fieldVal, whose
// static description is fieldType.
//
// If the field carries a `parse:"MethodName"` tag, the named method is looked
// up on struc (the pointer to the enclosing struct) and called with (s, key).
// Its error, if any, is returned. Otherwise its first result is stored in the
// field. Without a parse method, the raw key value is converted according to
// the field's kind:
//
//   - string, bool, signed/unsigned integers and floats use the matching
//     go-ini accessor;
//   - int32 fields tagged `type:"rune"` must hold exactly one character;
//   - time.Duration fields are parsed with key.Duration();
//   - []*mail.Address, *mail.Address, *regexp.Regexp and *template.Template
//     fields are parsed from the key's string value;
//   - string slices are split using the field's `delim` struct tag.
//
// Unsupported field types, and struct fields without a parse method, are
// programming errors and cause a panic.
func setField(
	s *ini.Section, key *ini.Key, struc reflect.Value,
	fieldVal reflect.Value, fieldType reflect.StructField,
) error {
	var methodValue reflect.Value
	method := getParseMethod(s, key, struc, fieldType)
	if method.IsValid() {
		in := []reflect.Value{reflect.ValueOf(s), reflect.ValueOf(key)}
		out := method.Call(in)
		// A nil error result yields a nil Interface(), so err stays nil.
		err, _ := out[1].Interface().(error)
		if err != nil {
			return err
		}
		methodValue = out[0]
	}

	ft := fieldType.Type

	switch ft.Kind() {
	case reflect.String:
		if method.IsValid() {
			fieldVal.SetString(methodValue.String())
		} else {
			fieldVal.SetString(key.String())
		}
	case reflect.Bool:
		if method.IsValid() {
			fieldVal.SetBool(methodValue.Bool())
		} else {
			boolVal, err := key.Bool()
			if err != nil {
				return err
			}
			fieldVal.SetBool(boolVal)
		}
	case reflect.Int32:
		// impossible to differentiate rune from int32, they are aliases
		// this is an ugly hack but there is no alternative...
		if fieldType.Tag.Get("type") == "rune" {
			if method.IsValid() {
				fieldVal.Set(methodValue)
			} else {
				runes := []rune(key.String())
				if len(runes) != 1 {
					return errors.New("value must be 1 character long")
				}
				fieldVal.Set(reflect.ValueOf(runes[0]))
			}
			return nil
		}
		fallthrough
	case reflect.Int64:
		// time.Duration is an int64 underneath, so it is detected by type
		// name. Attempting key.Duration() and falling back on failure would
		// not work, because ParseDuration does not return an error for `0`.
		// When a parse method is present, its result takes precedence and
		// is stored by the generic integer branch below, instead of being
		// silently replaced by key.Duration().
		if !method.IsValid() && ft.PkgPath() == "time" && ft.Name() == "Duration" {
			durationVal, err := key.Duration()
			if err != nil {
				return err
			}
			fieldVal.Set(reflect.ValueOf(durationVal))
			return nil
		}
		fallthrough
	case reflect.Int, reflect.Int8, reflect.Int16:
		if method.IsValid() {
			fieldVal.SetInt(methodValue.Int())
		} else {
			intVal, err := key.Int64()
			if err != nil {
				return err
			}
			fieldVal.SetInt(intVal)
		}
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		if method.IsValid() {
			fieldVal.SetUint(methodValue.Uint())
		} else {
			uintVal, err := key.Uint64()
			if err != nil {
				return err
			}
			fieldVal.SetUint(uintVal)
		}
	case reflect.Float32, reflect.Float64:
		if method.IsValid() {
			fieldVal.SetFloat(methodValue.Float())
		} else {
			floatVal, err := key.Float64()
			if err != nil {
				return err
			}
			fieldVal.SetFloat(floatVal)
		}
	case reflect.Slice, reflect.Array:
		switch {
		case method.IsValid():
			fieldVal.Set(methodValue)
		case ft.Elem().Kind() == reflect.Ptr &&
			typePath(ft.Elem().Elem()) == "net/mail.Address":
			addrs, err := mail.ParseAddressList(key.String())
			if err != nil {
				return err
			}
			fieldVal.Set(reflect.ValueOf(addrs))
		case ft.Elem().Kind() == reflect.String:
			delim := fieldType.Tag.Get("delim")
			fieldVal.Set(reflect.ValueOf(key.Strings(delim)))
		default:
			panic(fmt.Sprintf("unsupported type []%s", typePath(ft.Elem())))
		}
	case reflect.Struct:
		// Struct-typed fields can only be filled through a parse method.
		if method.IsValid() {
			fieldVal.Set(methodValue)
		} else {
			panic(fmt.Sprintf("unsupported type %s", typePath(ft)))
		}
	case reflect.Ptr:
		if method.IsValid() {
			fieldVal.Set(methodValue)
		} else {
			switch typePath(ft.Elem()) {
			case "net/mail.Address":
				addr, err := mail.ParseAddress(key.String())
				if err != nil {
					return err
				}
				fieldVal.Set(reflect.ValueOf(addr))
			case "regexp.Regexp":
				r, err := regexp.Compile(key.String())
				if err != nil {
					return err
				}
				fieldVal.Set(reflect.ValueOf(r))
			case "text/template.Template":
				// The key's value is used as both the template name and
				// its body.
				t, err := templates.ParseTemplate(key.String(), key.String())
				if err != nil {
					return err
				}
				fieldVal.Set(reflect.ValueOf(t))
			default:
				panic(fmt.Sprintf("unsupported type %s", typePath(ft)))
			}
		}
	default:
		panic(fmt.Sprintf("unsupported type %s", typePath(ft)))
	}
	return nil
}

// getParseMethod returns the custom parse method named by the field's
// `parse:"MethodName"` struct tag. The method is looked up on struc, which is
// expected to be a pointer to the enclosing struct, so pointer-receiver
// methods are found. It returns the zero reflect.Value when the field has no
// parse tag.
//
// The method must have the signature
//
//	func(*ini.Section, *ini.Key) (T, error)
//
// where T should match the field's type (setField stores the first result
// into the field). A missing method or a mismatched signature is a
// programming error and causes a panic.
func getParseMethod(
	section *ini.Section, key *ini.Key,
	struc reflect.Value, typ reflect.StructField,
) reflect.Value {
	methodName, found := typ.Tag.Lookup("parse")
	if !found {
		return reflect.Value{}
	}
	// Identify the receiver by its type name. Formatting the reflect.Value
	// itself would dump the struct's contents into the panic message.
	structName := struc.Elem().Type().Name()

	method := struc.MethodByName(methodName)
	if !method.IsValid() {
		panic(fmt.Sprintf("(*%s).%s: method not found",
			structName, methodName))
	}

	mt := method.Type()
	// Require the second result to be exactly `error`. A concrete error
	// type (e.g. *MyErr) would turn a nil return into a non-nil error
	// interface. A non-error type would make setField silently ignore
	// failures.
	errorType := reflect.TypeOf((*error)(nil)).Elem()
	if mt.NumIn() != 2 ||
		mt.In(0) != reflect.TypeOf(section) ||
		mt.In(1) != reflect.TypeOf(key) ||
		mt.NumOut() != 2 ||
		mt.Out(1) != errorType {
		panic(fmt.Sprintf("(*%s).%s: invalid signature, expected %s",
			structName, methodName,
			"func(*ini.Section, *ini.Key) (any, error)"))
	}

	return method
}

func typePath(t reflect.Type) string {
	var prefix string
	if t.Kind() == reflect.Ptr {
		t = t.Elem()
		prefix = "*"
	}
	return fmt.Sprintf("%s%s.%s", prefix, t.PkgPath(), t.Name())
}