aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/99designs/gqlgen/codegen/util.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/99designs/gqlgen/codegen/util.go')
-rw-r--r--vendor/github.com/99designs/gqlgen/codegen/util.go68
1 files changed, 44 insertions, 24 deletions
diff --git a/vendor/github.com/99designs/gqlgen/codegen/util.go b/vendor/github.com/99designs/gqlgen/codegen/util.go
index 1849f100..cc6246fd 100644
--- a/vendor/github.com/99designs/gqlgen/codegen/util.go
+++ b/vendor/github.com/99designs/gqlgen/codegen/util.go
@@ -105,6 +105,12 @@ func findMethod(typ *types.Named, name string) *types.Func {
return nil
}
+func equalFieldName(source, target string) bool {
+ source = strings.Replace(source, "_", "", -1)
+ target = strings.Replace(target, "_", "", -1)
+ return strings.EqualFold(source, target)
+}
+
// findField attempts to match the name to a struct field with the following
// priorites:
// 1. If struct tag is passed then struct tag has highest priority
@@ -120,7 +126,7 @@ func findField(typ *types.Struct, name, structTag string) (*types.Var, error) {
if structTag != "" {
tags := reflect.StructTag(typ.Tag(i))
if val, ok := tags.Lookup(structTag); ok {
- if strings.EqualFold(val, name) {
+ if equalFieldName(val, name) {
if foundField != nil && foundFieldWasTag {
return nil, errors.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", structTag, val)
}
@@ -132,17 +138,16 @@ func findField(typ *types.Struct, name, structTag string) (*types.Var, error) {
}
if field.Anonymous() {
- if named, ok := field.Type().(*types.Struct); ok {
- f, err := findField(named, name, structTag)
- if err != nil && !strings.HasPrefix(err.Error(), "no field named") {
- return nil, err
- }
- if f != nil && foundField == nil {
- foundField = f
- }
+
+ fieldType := field.Type()
+
+ if ptr, ok := fieldType.(*types.Pointer); ok {
+ fieldType = ptr.Elem()
}
- if named, ok := field.Type().Underlying().(*types.Struct); ok {
+ // Type.Underlying() returns itself for all types except types.Named, where it returns a struct type.
+ // It should be safe to always call.
+ if named, ok := fieldType.Underlying().(*types.Struct); ok {
f, err := findField(named, name, structTag)
if err != nil && !strings.HasPrefix(err.Error(), "no field named") {
return nil, err
@@ -157,7 +162,7 @@ func findField(typ *types.Struct, name, structTag string) (*types.Var, error) {
continue
}
- if strings.EqualFold(field.Name(), name) && foundField == nil {
+ if equalFieldName(field.Name(), name) && foundField == nil { // aqui!
foundField = field
}
}
@@ -198,7 +203,7 @@ func (b BindErrors) Error() string {
return strings.Join(errs, "\n\n")
}
-func bindObject(t types.Type, object *Object, imports *Imports, structTag string) BindErrors {
+func bindObject(t types.Type, object *Object, structTag string) BindErrors {
var errs BindErrors
for i := range object.Fields {
field := &object.Fields[i]
@@ -208,13 +213,13 @@ func bindObject(t types.Type, object *Object, imports *Imports, structTag string
}
// first try binding to a method
- methodErr := bindMethod(imports, t, field)
+ methodErr := bindMethod(t, field)
if methodErr == nil {
continue
}
// otherwise try binding to a var
- varErr := bindVar(imports, t, field, structTag)
+ varErr := bindVar(t, field, structTag)
if varErr != nil {
errs = append(errs, BindError{
@@ -229,7 +234,7 @@ func bindObject(t types.Type, object *Object, imports *Imports, structTag string
return errs
}
-func bindMethod(imports *Imports, t types.Type, field *Field) error {
+func bindMethod(t types.Type, field *Field) error {
namedType, ok := t.(*types.Named)
if !ok {
return fmt.Errorf("not a named type")
@@ -250,13 +255,25 @@ func bindMethod(imports *Imports, t types.Type, field *Field) error {
} else if sig.Results().Len() != 2 {
return fmt.Errorf("method has wrong number of args")
}
- newArgs, err := matchArgs(field, sig.Params())
+ params := sig.Params()
+ // If the first argument is the context, remove it from the comparison and set
+ // the MethodHasContext flag so that the context will be passed to this model's method
+ if params.Len() > 0 && params.At(0).Type().String() == "context.Context" {
+ field.MethodHasContext = true
+ vars := make([]*types.Var, params.Len()-1)
+ for i := 1; i < params.Len(); i++ {
+ vars[i-1] = params.At(i)
+ }
+ params = types.NewTuple(vars...)
+ }
+
+ newArgs, err := matchArgs(field, params)
if err != nil {
return err
}
result := sig.Results().At(0)
- if err := validateTypeBinding(imports, field, result.Type()); err != nil {
+ if err := validateTypeBinding(field, result.Type()); err != nil {
return errors.Wrap(err, "method has wrong return type")
}
@@ -268,7 +285,7 @@ func bindMethod(imports *Imports, t types.Type, field *Field) error {
return nil
}
-func bindVar(imports *Imports, t types.Type, field *Field, structTag string) error {
+func bindVar(t types.Type, field *Field, structTag string) error {
underlying, ok := t.Underlying().(*types.Struct)
if !ok {
return fmt.Errorf("not a struct")
@@ -283,7 +300,7 @@ func bindVar(imports *Imports, t types.Type, field *Field, structTag string) err
return err
}
- if err := validateTypeBinding(imports, field, structField.Type()); err != nil {
+ if err := validateTypeBinding(field, structField.Type()); err != nil {
return errors.Wrap(err, "field has wrong type")
}
@@ -316,22 +333,21 @@ nextArg:
return newArgs, nil
}
-func validateTypeBinding(imports *Imports, field *Field, goType types.Type) error {
+func validateTypeBinding(field *Field, goType types.Type) error {
gqlType := normalizeVendor(field.Type.FullSignature())
goTypeStr := normalizeVendor(goType.String())
- if goTypeStr == gqlType || "*"+goTypeStr == gqlType || goTypeStr == "*"+gqlType {
+ if equalTypes(goTypeStr, gqlType) {
field.Type.Modifiers = modifiersFromGoType(goType)
return nil
}
// deal with type aliases
underlyingStr := normalizeVendor(goType.Underlying().String())
- if underlyingStr == gqlType || "*"+underlyingStr == gqlType || underlyingStr == "*"+gqlType {
+ if equalTypes(underlyingStr, gqlType) {
field.Type.Modifiers = modifiersFromGoType(goType)
pkg, typ := pkgAndType(goType.String())
- imp := imports.findByPath(pkg)
- field.AliasedType = &Ref{GoType: typ, Import: imp}
+ field.AliasedType = &Ref{GoType: typ, Package: pkg}
return nil
}
@@ -365,3 +381,7 @@ func normalizeVendor(pkg string) string {
parts := strings.Split(pkg, "/vendor/")
return modifiers + parts[len(parts)-1]
}
+
+func equalTypes(goType string, gqlType string) bool {
+ return goType == gqlType || "*"+goType == gqlType || goType == "*"+gqlType || strings.Replace(goType, "[]*", "[]", -1) == gqlType
+}