package codegen import ( "fmt" "go/types" "log" "reflect" "strconv" "strings" "github.com/99designs/gqlgen/codegen/config" "github.com/99designs/gqlgen/codegen/templates" "github.com/pkg/errors" "github.com/vektah/gqlparser/ast" ) type Field struct { *ast.FieldDefinition TypeReference *config.TypeReference GoFieldType GoFieldType // The field type in go, if any GoReceiverName string // The name of method & var receiver in go, if any GoFieldName string // The name of the method or var in go, if any IsResolver bool // Does this field need a resolver Args []*FieldArgument // A list of arguments to be passed to this field MethodHasContext bool // If this is bound to a go method, does the method also take a context NoErr bool // If this is bound to a go method, does that method have an error as the second argument Object *Object // A link back to the parent object Default interface{} // The default value Directives []*Directive } func (b *builder) buildField(obj *Object, field *ast.FieldDefinition) (*Field, error) { dirs, err := b.getDirectives(field.Directives) if err != nil { return nil, err } f := Field{ FieldDefinition: field, Object: obj, Directives: dirs, GoFieldName: templates.ToGo(field.Name), GoFieldType: GoFieldVariable, GoReceiverName: "obj", } if field.DefaultValue != nil { var err error f.Default, err = field.DefaultValue.Value(nil) if err != nil { return nil, errors.Errorf("default value %s is not valid: %s", field.Name, err.Error()) } } for _, arg := range field.Arguments { newArg, err := b.buildArg(obj, arg) if err != nil { return nil, err } f.Args = append(f.Args, newArg) } if err = b.bindField(obj, &f); err != nil { f.IsResolver = true log.Println(err.Error()) } if f.IsResolver && !f.TypeReference.IsPtr() && f.TypeReference.IsStruct() { f.TypeReference = b.Binder.PointerTo(f.TypeReference) } return &f, nil } func (b *builder) bindField(obj *Object, f *Field) error { defer func() { if f.TypeReference == nil { tr, err := b.Binder.TypeReference(f.Type, nil) if err != nil { panic(err) } f.TypeReference = tr } }() switch { case f.Name == "__schema": f.GoFieldType = GoFieldMethod f.GoReceiverName = "ec" f.GoFieldName = "introspectSchema" return nil case f.Name == "__type": f.GoFieldType = GoFieldMethod f.GoReceiverName = "ec" f.GoFieldName = "introspectType" return nil case obj.Root: f.IsResolver = true return nil case b.Config.Models[obj.Name].Fields[f.Name].Resolver: f.IsResolver = true return nil case obj.Type == config.MapType: f.GoFieldType = GoFieldMap return nil case b.Config.Models[obj.Name].Fields[f.Name].FieldName != "": f.GoFieldName = b.Config.Models[obj.Name].Fields[f.Name].FieldName } target, err := b.findBindTarget(obj.Type.(*types.Named), f.GoFieldName) if err != nil { return err } pos := b.Binder.ObjectPosition(target) switch target := target.(type) { case nil: objPos := b.Binder.TypePosition(obj.Type) return fmt.Errorf( "%s:%d adding resolver method for %s.%s, nothing matched", objPos.Filename, objPos.Line, obj.Name, f.Name, ) case *types.Func: sig := target.Type().(*types.Signature) if sig.Results().Len() == 1 { f.NoErr = true } else if sig.Results().Len() != 2 { return fmt.Errorf("method has wrong number of args") } params := sig.Params() // If the first argument is the context, remove it from the comparison and set // the MethodHasContext flag so that the context will be passed to this model's method if params.Len() > 0 && params.At(0).Type().String() == "context.Context" { f.MethodHasContext = true vars := make([]*types.Var, params.Len()-1) for i := 1; i < params.Len(); i++ { vars[i-1] = params.At(i) } params = types.NewTuple(vars...) } if err = b.bindArgs(f, params); err != nil { return errors.Wrapf(err, "%s:%d", pos.Filename, pos.Line) } result := sig.Results().At(0) tr, err := b.Binder.TypeReference(f.Type, result.Type()) if err != nil { return err } // success, args and return type match. Bind to method f.GoFieldType = GoFieldMethod f.GoReceiverName = "obj" f.GoFieldName = target.Name() f.TypeReference = tr return nil case *types.Var: tr, err := b.Binder.TypeReference(f.Type, target.Type()) if err != nil { return err } // success, bind to var f.GoFieldType = GoFieldVariable f.GoReceiverName = "obj" f.GoFieldName = target.Name() f.TypeReference = tr return nil default: panic(fmt.Errorf("unknown bind target %T for %s", target, f.Name)) } } // findField attempts to match the name to a struct field with the following // priorites: // 1. Any method with a matching name // 2. Any Fields with a struct tag (see config.StructTag) // 3. Any fields with a matching name // 4. Same logic again for embedded fields func (b *builder) findBindTarget(named *types.Named, name string) (types.Object, error) { for i := 0; i < named.NumMethods(); i++ { method := named.Method(i) if !method.Exported() { continue } if !strings.EqualFold(method.Name(), name) { continue } return method, nil } strukt, ok := named.Underlying().(*types.Struct) if !ok { return nil, fmt.Errorf("not a struct") } return b.findBindStructTarget(strukt, name) } func (b *builder) findBindStructTarget(strukt *types.Struct, name string) (types.Object, error) { // struct tags have the highest priority if b.Config.StructTag != "" { var foundField *types.Var for i := 0; i < strukt.NumFields(); i++ { field := strukt.Field(i) if !field.Exported() { continue } tags := reflect.StructTag(strukt.Tag(i)) if val, ok := tags.Lookup(b.Config.StructTag); ok && equalFieldName(val, name) { if foundField != nil { return nil, errors.Errorf("tag %s is ambigious; multiple fields have the same tag value of %s", b.Config.StructTag, val) } foundField = field } } if foundField != nil { return foundField, nil } } // Then matching field names for i := 0; i < strukt.NumFields(); i++ { field := strukt.Field(i) if !field.Exported() { continue } if equalFieldName(field.Name(), name) { // aqui! return field, nil } } // Then look in embedded structs for i := 0; i < strukt.NumFields(); i++ { field := strukt.Field(i) if !field.Exported() { continue } if !field.Anonymous() { continue } fieldType := field.Type() if ptr, ok := fieldType.(*types.Pointer); ok { fieldType = ptr.Elem() } switch fieldType := fieldType.(type) { case *types.Named: f, err := b.findBindTarget(fieldType, name) if err != nil { return nil, err } if f != nil { return f, nil } case *types.Struct: f, err := b.findBindStructTarget(fieldType, name) if err != nil { return nil, err } if f != nil { return f, nil } default: panic(fmt.Errorf("unknown embedded field type %T", field.Type())) } } return nil, nil } func (f *Field) HasDirectives() bool { return len(f.ImplDirectives()) > 0 } func (f *Field) DirectiveObjName() string { if f.Object.Root { return "nil" } return f.GoReceiverName } func (f *Field) ImplDirectives() []*Directive { var d []*Directive loc := ast.LocationFieldDefinition if f.Object.IsInputType() { loc = ast.LocationInputFieldDefinition } for i := range f.Directives { if !f.Directives[i].Builtin && f.Directives[i].IsLocation(loc) { d = append(d, f.Directives[i]) } } return d } func (f *Field) IsReserved() bool { return strings.HasPrefix(f.Name, "__") } func (f *Field) IsMethod() bool { return f.GoFieldType == GoFieldMethod } func (f *Field) IsVariable() bool { return f.GoFieldType == GoFieldVariable } func (f *Field) IsMap() bool { return f.GoFieldType == GoFieldMap } func (f *Field) IsConcurrent() bool { if f.Object.DisableConcurrency { return false } return f.MethodHasContext || f.IsResolver } func (f *Field) GoNameUnexported() string { return templates.ToGoPrivate(f.Name) } func (f *Field) ShortInvocation() string { return fmt.Sprintf("%s().%s(%s)", f.Object.Definition.Name, f.GoFieldName, f.CallArgs()) } func (f *Field) ArgsFunc() string { if len(f.Args) == 0 { return "" } return "field_" + f.Object.Definition.Name + "_" + f.Name + "_args" } func (f *Field) ResolverType() string { if !f.IsResolver { return "" } return fmt.Sprintf("%s().%s(%s)", f.Object.Definition.Name, f.GoFieldName, f.CallArgs()) } func (f *Field) ShortResolverDeclaration() string { res := "(ctx context.Context" if !f.Object.Root { res += fmt.Sprintf(", obj *%s", templates.CurrentImports.LookupType(f.Object.Type)) } for _, arg := range f.Args { res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO)) } result := templates.CurrentImports.LookupType(f.TypeReference.GO) if f.Object.Stream { result = "<-chan " + result } res += fmt.Sprintf(") (%s, error)", result) return res } func (f *Field) ComplexitySignature() string { res := fmt.Sprintf("func(childComplexity int") for _, arg := range f.Args { res += fmt.Sprintf(", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO)) } res += ") int" return res } func (f *Field) ComplexityArgs() string { args := make([]string, len(f.Args)) for i, arg := range f.Args { args[i] = "args[" + strconv.Quote(arg.Name) + "].(" + templates.CurrentImports.LookupType(arg.TypeReference.GO) + ")" } return strings.Join(args, ", ") } func (f *Field) CallArgs() string { args := make([]string, 0, len(f.Args)+2) if f.IsResolver { args = append(args, "rctx") if !f.Object.Root { args = append(args, "obj") } } else if f.MethodHasContext { args = append(args, "ctx") } for _, arg := range f.Args { args = append(args, "args["+strconv.Quote(arg.Name)+"].("+templates.CurrentImports.LookupType(arg.TypeReference.GO)+")") } return strings.Join(args, ", ") }