diff options
Diffstat (limited to 'vendor/github.com/99designs/gqlgen/codegen/util.go')
-rw-r--r-- | vendor/github.com/99designs/gqlgen/codegen/util.go | 68 |
1 files changed, 44 insertions, 24 deletions
diff --git a/vendor/github.com/99designs/gqlgen/codegen/util.go b/vendor/github.com/99designs/gqlgen/codegen/util.go index 1849f100..cc6246fd 100644 --- a/vendor/github.com/99designs/gqlgen/codegen/util.go +++ b/vendor/github.com/99designs/gqlgen/codegen/util.go @@ -105,6 +105,12 @@ func findMethod(typ *types.Named, name string) *types.Func { return nil } +func equalFieldName(source, target string) bool { + source = strings.Replace(source, "_", "", -1) + target = strings.Replace(target, "_", "", -1) + return strings.EqualFold(source, target) +} + // findField attempts to match the name to a struct field with the following // priorites: // 1. If struct tag is passed then struct tag has highest priority @@ -120,7 +126,7 @@ func findField(typ *types.Struct, name, structTag string) (*types.Var, error) { if structTag != "" { tags := reflect.StructTag(typ.Tag(i)) if val, ok := tags.Lookup(structTag); ok { - if strings.EqualFold(val, name) { + if equalFieldName(val, name) { if foundField != nil && foundFieldWasTag { return nil, errors.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", structTag, val) } @@ -132,17 +138,16 @@ func findField(typ *types.Struct, name, structTag string) (*types.Var, error) { } if field.Anonymous() { - if named, ok := field.Type().(*types.Struct); ok { - f, err := findField(named, name, structTag) - if err != nil && !strings.HasPrefix(err.Error(), "no field named") { - return nil, err - } - if f != nil && foundField == nil { - foundField = f - } + + fieldType := field.Type() + + if ptr, ok := fieldType.(*types.Pointer); ok { + fieldType = ptr.Elem() } - if named, ok := field.Type().Underlying().(*types.Struct); ok { + // Type.Underlying() returns itself for all types except types.Named, where it returns a struct type. + // It should be safe to always call. + if named, ok := fieldType.Underlying().(*types.Struct); ok { f, err := findField(named, name, structTag) if err != nil && !strings.HasPrefix(err.Error(), "no field named") { return nil, err @@ -157,7 +162,7 @@ func findField(typ *types.Struct, name, structTag string) (*types.Var, error) { continue } - if strings.EqualFold(field.Name(), name) && foundField == nil { + if equalFieldName(field.Name(), name) && foundField == nil { // aqui! foundField = field } } @@ -198,7 +203,7 @@ func (b BindErrors) Error() string { return strings.Join(errs, "\n\n") } -func bindObject(t types.Type, object *Object, imports *Imports, structTag string) BindErrors { +func bindObject(t types.Type, object *Object, structTag string) BindErrors { var errs BindErrors for i := range object.Fields { field := &object.Fields[i] @@ -208,13 +213,13 @@ func bindObject(t types.Type, object *Object, imports *Imports, structTag string } // first try binding to a method - methodErr := bindMethod(imports, t, field) + methodErr := bindMethod(t, field) if methodErr == nil { continue } // otherwise try binding to a var - varErr := bindVar(imports, t, field, structTag) + varErr := bindVar(t, field, structTag) if varErr != nil { errs = append(errs, BindError{ @@ -229,7 +234,7 @@ func bindObject(t types.Type, object *Object, imports *Imports, structTag string return errs } -func bindMethod(imports *Imports, t types.Type, field *Field) error { +func bindMethod(t types.Type, field *Field) error { namedType, ok := t.(*types.Named) if !ok { return fmt.Errorf("not a named type") @@ -250,13 +255,25 @@ func bindMethod(imports *Imports, t types.Type, field *Field) error { } else if sig.Results().Len() != 2 { return fmt.Errorf("method has wrong number of args") } - newArgs, err := matchArgs(field, sig.Params()) + params := sig.Params() + // If the first argument is the context, remove it from the comparison and set + // the MethodHasContext flag so that the context will be passed to this model's method + if params.Len() > 0 && params.At(0).Type().String() == "context.Context" { + field.MethodHasContext = true + vars := make([]*types.Var, params.Len()-1) + for i := 1; i < params.Len(); i++ { + vars[i-1] = params.At(i) + } + params = types.NewTuple(vars...) + } + + newArgs, err := matchArgs(field, params) if err != nil { return err } result := sig.Results().At(0) - if err := validateTypeBinding(imports, field, result.Type()); err != nil { + if err := validateTypeBinding(field, result.Type()); err != nil { return errors.Wrap(err, "method has wrong return type") } @@ -268,7 +285,7 @@ func bindMethod(imports *Imports, t types.Type, field *Field) error { return nil } -func bindVar(imports *Imports, t types.Type, field *Field, structTag string) error { +func bindVar(t types.Type, field *Field, structTag string) error { underlying, ok := t.Underlying().(*types.Struct) if !ok { return fmt.Errorf("not a struct") @@ -283,7 +300,7 @@ func bindVar(imports *Imports, t types.Type, field *Field, structTag string) err return err } - if err := validateTypeBinding(imports, field, structField.Type()); err != nil { + if err := validateTypeBinding(field, structField.Type()); err != nil { return errors.Wrap(err, "field has wrong type") } @@ -316,22 +333,21 @@ nextArg: return newArgs, nil } -func validateTypeBinding(imports *Imports, field *Field, goType types.Type) error { +func validateTypeBinding(field *Field, goType types.Type) error { gqlType := normalizeVendor(field.Type.FullSignature()) goTypeStr := normalizeVendor(goType.String()) - if goTypeStr == gqlType || "*"+goTypeStr == gqlType || goTypeStr == "*"+gqlType { + if equalTypes(goTypeStr, gqlType) { field.Type.Modifiers = modifiersFromGoType(goType) return nil } // deal with type aliases underlyingStr := normalizeVendor(goType.Underlying().String()) - if underlyingStr == gqlType || "*"+underlyingStr == gqlType || underlyingStr == "*"+gqlType { + if equalTypes(underlyingStr, gqlType) { field.Type.Modifiers = modifiersFromGoType(goType) pkg, typ := pkgAndType(goType.String()) - imp := imports.findByPath(pkg) - field.AliasedType = &Ref{GoType: typ, Import: imp} + field.AliasedType = &Ref{GoType: typ, Package: pkg} return nil } @@ -365,3 +381,7 @@ func normalizeVendor(pkg string) string { parts := strings.Split(pkg, "/vendor/") return modifiers + parts[len(parts)-1] } + +func equalTypes(goType string, gqlType string) bool { + return goType == gqlType || "*"+goType == gqlType || goType == "*"+gqlType || strings.Replace(goType, "[]*", "[]", -1) == gqlType +} |