aboutsummaryrefslogblamecommitdiffstats
path: root/vendor/github.com/99designs/gqlgen/codegen/field.go
blob: fab26f2b9a9c2245682c8ff10bcc50369ad50df2 (plain) (tree)





























































































































































































































































































                                                                                                                                                                





















                                                                                













































































                                                                                                                     


                                                                                                                                     





                                       
                                                






                                                  

                                          







                                                                                                                                        
package codegen

import (
	"fmt"
	"go/types"
	"log"
	"reflect"
	"strconv"
	"strings"

	"github.com/99designs/gqlgen/codegen/config"
	"github.com/99designs/gqlgen/codegen/templates"
	"github.com/pkg/errors"
	"github.com/vektah/gqlparser/ast"
)

// Field is the code generator's view of a single GraphQL field definition
// together with the Go binding chosen for it. It embeds the schema's
// *ast.FieldDefinition and is populated by buildField / bindField.
// TypeReference is the type reference obtained from the Binder for the field's
// schema type (and, when bound, the Go type of the bind target).
type Field struct {
	*ast.FieldDefinition

	TypeReference    *config.TypeReference
	GoFieldType      GoFieldType      // How the field is accessed in Go (variable, method or map), if any
	GoReceiverName   string           // The name of method & var receiver in go, if any
	GoFieldName      string           // The name of the method or var in go, if any
	IsResolver       bool             // Does this field need a resolver
	Args             []*FieldArgument // A list of arguments to be passed to this field
	MethodHasContext bool             // If this is bound to a go method, does the method also take a context
	NoErr            bool             // If this is bound to a go method, true when that method returns a single result (no error)
	Object           *Object          // A link back to the parent object
	Default          interface{}      // The default value
	Directives       []*Directive     // Directives built from the schema field's directive list
}

// buildField creates the codegen Field for the schema definition field that
// belongs to obj.
//
// It collects the field's directives, evaluates its default value (if any),
// builds every argument and then tries to bind the field to Go code via
// bindField. A binding failure is not fatal: the error is logged and the field
// is marked as needing a resolver. Resolver fields whose type reference is a
// non-pointer struct are switched to a pointer type reference.
//
// An error is returned when the directives, the default value or any argument
// cannot be built.
func (b *builder) buildField(obj *Object, field *ast.FieldDefinition) (*Field, error) {
	directives, err := b.getDirectives(field.Directives)
	if err != nil {
		return nil, err
	}

	result := &Field{
		FieldDefinition: field,
		Object:          obj,
		Directives:      directives,
		GoFieldName:     templates.ToGo(field.Name),
		GoFieldType:     GoFieldVariable,
		GoReceiverName:  "obj",
	}

	if def := field.DefaultValue; def != nil {
		value, valueErr := def.Value(nil)
		if valueErr != nil {
			return nil, errors.Errorf("default value %s is not valid: %s", field.Name, valueErr.Error())
		}
		result.Default = value
	}

	for _, argDef := range field.Arguments {
		built, argErr := b.buildArg(obj, argDef)
		if argErr != nil {
			return nil, argErr
		}
		result.Args = append(result.Args, built)
	}

	// Failing to bind is recoverable: fall back to a generated resolver.
	if bindErr := b.bindField(obj, result); bindErr != nil {
		result.IsResolver = true
		log.Println(bindErr.Error())
	}

	needsPointer := result.IsResolver && !result.TypeReference.IsPtr() && result.TypeReference.IsStruct()
	if needsPointer {
		result.TypeReference = b.Binder.PointerTo(result.TypeReference)
	}

	return result, nil
}

// bindField decides how field f of obj is resolved and records the decision on
// f (GoFieldType, GoReceiverName, GoFieldName, IsResolver, NoErr,
// MethodHasContext and TypeReference).
//
// It returns nil when a binding was chosen. It returns an error when no bind
// target could be found or the target's arguments/type do not match; the
// caller (buildField) reacts to that by falling back to a resolver.
func (b *builder) bindField(obj *Object, f *Field) error {
	// Whichever path returns, make sure f.TypeReference ends up populated. If
	// no branch below set it, derive it from the schema type alone (nil Go
	// type). A failure here panics instead of being returned.
	defer func() {
		if f.TypeReference == nil {
			tr, err := b.Binder.TypeReference(f.Type, nil)
			if err != nil {
				panic(err)
			}
			f.TypeReference = tr
		}
	}()

	// Cases are checked in order; the first match wins.
	switch {
	case f.Name == "__schema":
		// Introspection field: bound to a method on the "ec" receiver.
		f.GoFieldType = GoFieldMethod
		f.GoReceiverName = "ec"
		f.GoFieldName = "introspectSchema"
		return nil
	case f.Name == "__type":
		// Introspection field: bound to a method on the "ec" receiver.
		f.GoFieldType = GoFieldMethod
		f.GoReceiverName = "ec"
		f.GoFieldName = "introspectType"
		return nil
	case obj.Root:
		// Every field of a root object is served by a resolver.
		f.IsResolver = true
		return nil
	case b.Config.Models[obj.Name].Fields[f.Name].Resolver:
		// The model configuration explicitly asks for a resolver.
		f.IsResolver = true
		return nil
	case obj.Type == config.MapType:
		f.GoFieldType = GoFieldMap
		return nil
	case b.Config.Models[obj.Name].Fields[f.Name].FieldName != "":
		// Configured Go name override; no return here, binding continues below
		// using the overridden name.
		f.GoFieldName = b.Config.Models[obj.Name].Fields[f.Name].FieldName
	}

	// NOTE(review): unchecked type assertion — this panics if obj.Type is not a
	// *types.Named at this point; confirm callers guarantee that.
	target, err := b.findBindTarget(obj.Type.(*types.Named), f.GoFieldName)
	if err != nil {
		return err
	}

	// Computed before the nil check below, so ObjectPosition may receive a nil
	// target — presumably it tolerates that; verify against Binder.
	pos := b.Binder.ObjectPosition(target)

	switch target := target.(type) {
	case nil:
		// Nothing on the Go type matched the field name.
		objPos := b.Binder.TypePosition(obj.Type)
		return fmt.Errorf(
			"%s:%d adding resolver method for %s.%s, nothing matched",
			objPos.Filename,
			objPos.Line,
			obj.Name,
			f.Name,
		)

	case *types.Func:
		sig := target.Type().(*types.Signature)
		// Exactly one result means the method returns no error; two results
		// are accepted as-is; any other count is rejected. Note the message
		// says "args" although the check is on the number of results.
		if sig.Results().Len() == 1 {
			f.NoErr = true
		} else if sig.Results().Len() != 2 {
			return fmt.Errorf("method has wrong number of args")
		}
		params := sig.Params()
		// If the first argument is the context, remove it from the comparison and set
		// the MethodHasContext flag so that the context will be passed to this model's method
		if params.Len() > 0 && params.At(0).Type().String() == "context.Context" {
			f.MethodHasContext = true
			vars := make([]*types.Var, params.Len()-1)
			for i := 1; i < params.Len(); i++ {
				vars[i-1] = params.At(i)
			}
			params = types.NewTuple(vars...)
		}

		// Match the remaining method parameters against the field's arguments;
		// errors are prefixed with the target's source position.
		if err = b.bindArgs(f, params); err != nil {
			return errors.Wrapf(err, "%s:%d", pos.Filename, pos.Line)
		}

		// The first result is the value whose Go type must fit the schema type.
		result := sig.Results().At(0)
		tr, err := b.Binder.TypeReference(f.Type, result.Type())
		if err != nil {
			return err
		}

		// success, args and return type match. Bind to method
		f.GoFieldType = GoFieldMethod
		f.GoReceiverName = "obj"
		f.GoFieldName = target.Name()
		f.TypeReference = tr

		return nil

	case *types.Var:
		// Struct field: only its Go type needs to fit the schema type.
		tr, err := b.Binder.TypeReference(f.Type, target.Type())
		if err != nil {
			return err
		}

		// success, bind to var
		f.GoFieldType = GoFieldVariable
		f.GoReceiverName = "obj"
		f.GoFieldName = target.Name()
		f.TypeReference = tr

		return nil
	default:
		// findBindTarget returned an object kind this function does not handle.
		panic(fmt.Errorf("unknown bind target %T for %s", target, f.Name))
	}
}

// findBindTarget looks for the Go object on named that a field called name
// should bind to, using the following priorities:
//  1. Any exported method whose name matches (case-insensitively)
//  2. Any field with a matching struct tag (see config.StructTag)
//  3. Any field with a matching name
//  4. The same logic again for embedded fields
//
// Steps 2-4 are delegated to findBindStructTarget. The result is (nil, nil)
// when nothing matches, and an error when no method matched and the underlying
// type of named is not a struct.
func (b *builder) findBindTarget(named *types.Named, name string) (types.Object, error) {
	for idx, count := 0, named.NumMethods(); idx < count; idx++ {
		if m := named.Method(idx); m.Exported() && strings.EqualFold(m.Name(), name) {
			return m, nil
		}
	}

	if strukt, isStruct := named.Underlying().(*types.Struct); isStruct {
		return b.findBindStructTarget(strukt, name)
	}
	return nil, fmt.Errorf("not a struct")
}

// findBindStructTarget searches the exported fields of strukt for the one a
// field called name should bind to. Candidates are tried in this order:
//  1. If b.Config.StructTag is set, a field whose tag value for that key
//     matches name. More than one such field is reported as an error.
//  2. A field whose Go name matches name.
//  3. Embedded fields (a pointer is dereferenced first), searched recursively
//     in declaration order.
//
// Name matching is delegated to equalFieldName. The result is (nil, nil) when
// nothing matches. It panics if an embedded field's type is neither a named
// type nor a struct.
func (b *builder) findBindStructTarget(strukt *types.Struct, name string) (types.Object, error) {
	// struct tags have the highest priority
	if b.Config.StructTag != "" {
		var foundField *types.Var
		for i := 0; i < strukt.NumFields(); i++ {
			field := strukt.Field(i)
			if !field.Exported() {
				continue
			}
			tags := reflect.StructTag(strukt.Tag(i))
			if val, ok := tags.Lookup(b.Config.StructTag); ok && equalFieldName(val, name) {
				// Keep scanning after the first hit so duplicates are detected.
				if foundField != nil {
					return nil, errors.Errorf("tag %s is ambiguous; multiple fields have the same tag value of %s", b.Config.StructTag, val)
				}

				foundField = field
			}
		}
		// Only return when set: a nil *types.Var wrapped in types.Object would
		// look non-nil to callers.
		if foundField != nil {
			return foundField, nil
		}
	}

	// Then matching field names
	for i := 0; i < strukt.NumFields(); i++ {
		field := strukt.Field(i)
		if !field.Exported() {
			continue
		}
		if equalFieldName(field.Name(), name) {
			return field, nil
		}
	}

	// Then look in embedded structs
	for i := 0; i < strukt.NumFields(); i++ {
		field := strukt.Field(i)
		if !field.Exported() || !field.Anonymous() {
			continue
		}

		fieldType := field.Type()
		if ptr, ok := fieldType.(*types.Pointer); ok {
			fieldType = ptr.Elem()
		}

		var (
			found types.Object
			err   error
		)
		switch embedded := fieldType.(type) {
		case *types.Named:
			found, err = b.findBindTarget(embedded, name)
		case *types.Struct:
			found, err = b.findBindStructTarget(embedded, name)
		default:
			panic(fmt.Errorf("unknown embedded field type %T", field.Type()))
		}
		if err != nil {
			return nil, err
		}
		if found != nil {
			return found, nil
		}
	}

	return nil, nil
}

// HasDirectives reports whether ImplDirectives returns at least one directive
// for this field.
func (f *Field) HasDirectives() bool {
	impl := f.ImplDirectives()
	return len(impl) != 0
}

// DirectiveObjName returns the literal string "nil" for fields of a root
// object and the field's Go receiver name for every other field.
func (f *Field) DirectiveObjName() string {
	if !f.Object.Root {
		return f.GoReceiverName
	}
	return "nil"
}

// ImplDirectives returns the field's non-builtin directives that are valid at
// the field's location: the input field definition location when the parent
// object is an input type, the field definition location otherwise. The
// result is nil when no directive qualifies.
func (f *Field) ImplDirectives() []*Directive {
	location := ast.LocationFieldDefinition
	if f.Object.IsInputType() {
		location = ast.LocationInputFieldDefinition
	}

	var impl []*Directive
	for _, dir := range f.Directives {
		if dir.Builtin || !dir.IsLocation(location) {
			continue
		}
		impl = append(impl, dir)
	}
	return impl
}

// IsReserved reports whether the schema field name starts with a double
// underscore.
func (f *Field) IsReserved() bool {
	const reservedPrefix = "__"
	return strings.HasPrefix(f.Name, reservedPrefix)
}

// IsMethod reports whether the field's Go field type is GoFieldMethod.
func (f *Field) IsMethod() bool {
	kind := f.GoFieldType
	return kind == GoFieldMethod
}

// IsVariable reports whether the field's Go field type is GoFieldVariable.
func (f *Field) IsVariable() bool {
	kind := f.GoFieldType
	return kind == GoFieldVariable
}

// IsMap reports whether the field's Go field type is GoFieldMap.
func (f *Field) IsMap() bool {
	kind := f.GoFieldType
	return kind == GoFieldMap
}

// IsConcurrent reports false whenever the parent object has DisableConcurrency
// set. Otherwise it reports true if the field is a resolver or its bound
// method takes a context.
func (f *Field) IsConcurrent() bool {
	concurrencyAllowed := !f.Object.DisableConcurrency
	return concurrencyAllowed && (f.MethodHasContext || f.IsResolver)
}

// GoNameUnexported returns the schema field name converted by
// templates.ToGoPrivate.
func (f *Field) GoNameUnexported() string {
	schemaName := f.Name
	return templates.ToGoPrivate(schemaName)
}

// ShortInvocation renders the call expression
// "<ObjectName>().<GoFieldName>(<CallArgs>)" for this field.
func (f *Field) ShortInvocation() string {
	return f.Object.Definition.Name + "()." + f.GoFieldName + "(" + f.CallArgs() + ")"
}

// ArgsFunc returns the identifier "field_<ObjectName>_<fieldName>_args" for a
// field that takes arguments, and the empty string for a field without any.
func (f *Field) ArgsFunc() string {
	if len(f.Args) > 0 {
		return fmt.Sprintf("field_%s_%s_args", f.Object.Definition.Name, f.Name)
	}
	return ""
}

// ResolverType renders "<ObjectName>().<GoFieldName>(<CallArgs>)" for resolver
// fields and returns the empty string for every other field.
func (f *Field) ResolverType() string {
	if f.IsResolver {
		owner, method := f.Object.Definition.Name, f.GoFieldName
		return fmt.Sprintf("%s().%s(%s)", owner, method, f.CallArgs())
	}
	return ""
}

// ShortResolverDeclaration renders the parameter list and result list of the
// field's resolver method, e.g. "(ctx context.Context, obj *T, a A) (R, error)".
//
// The obj parameter is omitted for fields of a root object, and the result
// type is wrapped in a receive-only channel when the parent object is a
// stream. Type names are looked up through templates.CurrentImports.
func (f *Field) ShortResolverDeclaration() string {
	var decl strings.Builder
	decl.WriteString("(ctx context.Context")

	if !f.Object.Root {
		fmt.Fprintf(&decl, ", obj *%s", templates.CurrentImports.LookupType(f.Object.Type))
	}
	for _, arg := range f.Args {
		fmt.Fprintf(&decl, ", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO))
	}

	returnType := templates.CurrentImports.LookupType(f.TypeReference.GO)
	if f.Object.Stream {
		returnType = "<-chan " + returnType
	}

	fmt.Fprintf(&decl, ") (%s, error)", returnType)
	return decl.String()
}

// ComplexitySignature renders the Go function type of the field's complexity
// function: "func(childComplexity int, <arg> <type>, ...) int", with one
// parameter per field argument. Argument type names are looked up through
// templates.CurrentImports.
func (f *Field) ComplexitySignature() string {
	var sig strings.Builder
	// Constant prefix: written directly, no formatting needed.
	sig.WriteString("func(childComplexity int")
	for _, arg := range f.Args {
		fmt.Fprintf(&sig, ", %s %s", arg.VarName, templates.CurrentImports.LookupType(arg.TypeReference.GO))
	}
	sig.WriteString(") int")
	return sig.String()
}

// ComplexityArgs renders one type-asserted lookup per field argument, in the
// form args["<name>"].(<type>), joined by ", ". The result is the empty string
// for a field without arguments.
func (f *Field) ComplexityArgs() string {
	rendered := make([]string, 0, len(f.Args))
	for _, arg := range f.Args {
		goType := templates.CurrentImports.LookupType(arg.TypeReference.GO)
		rendered = append(rendered, fmt.Sprintf("args[%s].(%s)", strconv.Quote(arg.Name), goType))
	}

	return strings.Join(rendered, ", ")
}

// CallArgs renders the comma-separated argument list used to invoke the
// field's Go target.
//
// Resolver fields receive "rctx" first, followed by "obj" unless the parent
// object is a root object. Non-resolver fields whose bound method takes a
// context receive "ctx". After that comes one type-asserted lookup per field
// argument, in the form args["<name>"].(<type>).
func (f *Field) CallArgs() string {
	call := make([]string, 0, len(f.Args)+2)

	switch {
	case f.IsResolver && f.Object.Root:
		call = append(call, "rctx")
	case f.IsResolver:
		call = append(call, "rctx", "obj")
	case f.MethodHasContext:
		call = append(call, "ctx")
	}

	for _, arg := range f.Args {
		goType := templates.CurrentImports.LookupType(arg.TypeReference.GO)
		call = append(call, fmt.Sprintf("args[%s].(%s)", strconv.Quote(arg.Name), goType))
	}

	return strings.Join(call, ", ")
}