// Residue from a web repository viewer (navigation links dropped).
// path: root/vendor/github.com/vektah/gqlgen/codegen/object.go
// blob: 1c03c0bae3d79efc0052a8d1ffcd1eb61a0faec8 (plain) (tree)













































































































































































































                                                                                                                                         
package codegen

import (
	"bytes"
	"fmt"
	"strconv"
	"strings"
	"text/template"
	"unicode"
)

// Object is the code generator's description of one GraphQL object type:
// its named type, its fields, and the flags that shape the generated
// resolver code.
type Object struct {
	*NamedType

	// Fields lists the fields declared on this object.
	Fields             []Field
	// Satisfies holds the extra names that Implementors emits after the
	// object's own GQLType.
	Satisfies          []string
	// Root, when set, makes resolver declarations and calls omit the obj
	// parameter.
	Root               bool
	// DisableConcurrency, when set, makes IsConcurrent report false for
	// every field of this object.
	DisableConcurrency bool
	// Stream, when set, makes resolver declarations return a receive-only
	// channel ("<-chan T") of the field's type.
	Stream             bool
}

// Field describes a single field of an Object and how it is bound to Go code.
//
// A field is served by a generated resolver when ForceResolver is set or when
// neither GoMethodName nor GoVarName is populated (see IsResolver).
type Field struct {
	*Type

	GQLName       string          // The name of the field in graphql
	GoMethodName  string          // The name of the method in go, if any
	GoVarName     string          // The name of the var in go, if any
	Args          []FieldArgument // A list of arguments to be passed to this field
	ForceResolver bool            // Always emit a resolver method, even if a go method or var is bound
	NoErr         bool            // If this is bound to a go method, does that method have an error as the second argument
	Object        *Object         // A link back to the parent object
	Default       interface{}     // The default value
}

// FieldArgument describes one argument of a Field: its type, its GraphQL
// name, and the Go variable name used for it in generated resolver
// signatures and invocations.
type FieldArgument struct {
	*Type

	GQLName   string      // The name of the argument in graphql
	GoVarName string      // The name of the var in go
	Object    *Object     // A link back to the parent object
	Default   interface{} // The default value
}

// Objects is a list of object definitions; ByName looks one up by its
// GraphQL type name.
type Objects []*Object

// Implementors renders, as Go source text, a []string literal whose first
// element is this object's own GraphQL type name, followed by every name in
// Satisfies. Each element is quoted with strconv.Quote.
func (o *Object) Implementors() string {
	quoted := make([]string, 0, len(o.Satisfies)+1)
	quoted = append(quoted, strconv.Quote(o.GQLType))
	for _, name := range o.Satisfies {
		quoted = append(quoted, strconv.Quote(name))
	}
	return "[]string{" + strings.Join(quoted, ", ") + "}"
}

// HasResolvers reports whether at least one field of the object is a
// resolver according to Field.IsResolver.
func (o *Object) HasResolvers() bool {
	for i := range o.Fields {
		if field := &o.Fields[i]; field.IsResolver() {
			return true
		}
	}
	return false
}

// IsResolver reports whether the field is served by a resolver: either it is
// explicitly forced, or it has neither a Go method nor a Go var bound to it.
func (f *Field) IsResolver() bool {
	if f.ForceResolver {
		return true
	}
	return f.GoMethodName == "" && f.GoVarName == ""
}

// IsConcurrent reports whether the field is a resolver whose parent object
// has not disabled concurrency.
func (f *Field) IsConcurrent() bool {
	if !f.IsResolver() {
		return false
	}
	return !f.Object.DisableConcurrency
}
// ShortInvocation renders the Go call expression for this field's resolver in
// the form `<ObjectType>().<FieldName>(ctx[, obj][, args...])`, where the
// field name has its first byte upper-cased. It returns "" for fields that
// are not resolvers.
func (f *Field) ShortInvocation() string {
	if !f.IsResolver() {
		return ""
	}

	method := strings.ToUpper(f.GQLName[:1]) + f.GQLName[1:]

	var call strings.Builder
	call.WriteString(f.Object.GQLType)
	call.WriteString("().")
	call.WriteString(method)
	call.WriteString("(ctx")
	// obj is omitted when the parent object is marked Root.
	if !f.Object.Root {
		call.WriteString(", obj")
	}
	for i := range f.Args {
		call.WriteString(", ")
		call.WriteString(f.Args[i].GoVarName)
	}
	call.WriteString(")")
	return call.String()
}
// ShortResolverDeclaration returns ResolverDeclaration with the leading
// "<ObjectType>_" prefix removed and the first byte of the remainder
// upper-cased. It returns "" for fields that are not resolvers.
func (f *Field) ShortResolverDeclaration() string {
	if !f.IsResolver() {
		return ""
	}
	full := f.ResolverDeclaration()
	prefix := f.Object.GQLType + "_"
	short := strings.TrimPrefix(full, prefix)
	return strings.ToUpper(short[:1]) + short[1:]
}

// ResolverDeclaration renders the Go method declaration for this field's
// resolver, in the form
//
//	<ObjectType>_<field>(ctx context.Context[, obj *<Object>][, <arg> <type>...]) (<result>, error)
//
// The obj parameter is omitted for Root objects, and the result type is
// wrapped in "<-chan " when the parent object is a Stream. It returns "" for
// fields that are not resolvers.
func (f *Field) ResolverDeclaration() string {
	if !f.IsResolver() {
		return ""
	}

	var decl strings.Builder
	fmt.Fprintf(&decl, "%s_%s(ctx context.Context", f.Object.GQLType, f.GQLName)
	if !f.Object.Root {
		fmt.Fprintf(&decl, ", obj *%s", f.Object.FullName())
	}
	for i := range f.Args {
		arg := &f.Args[i]
		fmt.Fprintf(&decl, ", %s %s", arg.GoVarName, arg.Signature())
	}

	ret := f.Signature()
	if f.Object.Stream {
		ret = "<-chan " + ret
	}
	fmt.Fprintf(&decl, ") (%s, error)", ret)

	return decl.String()
}

// CallArgs renders the comma-separated argument list used when invoking the
// field's implementation. When no Go method is bound, the list starts with
// ctx (and obj, unless the parent object is Root). Every declared argument
// is then read from the args map with a type assertion to its signature.
func (f *Field) CallArgs() string {
	params := make([]string, 0, len(f.Args)+2)

	if f.GoMethodName == "" {
		params = append(params, "ctx")
		if !f.Object.Root {
			params = append(params, "obj")
		}
	}

	for i := range f.Args {
		arg := &f.Args[i]
		key := strconv.Quote(arg.GQLName)
		params = append(params, "args["+key+"].("+arg.Signature()+")")
	}

	return strings.Join(params, ", ")
}

// WriteJson renders the Go source that marshals this field's result, which
// the generated code holds in a variable named "res".
//
// This should be in the template, but it's recursive and has a bunch of args.
func (f *Field) WriteJson() string {
	const (
		rootVal   = "res"
		rootDepth = 1
	)
	mods := f.Type.Modifiers
	return f.doWriteJson(rootVal, mods, false, rootDepth)
}

// doWriteJson recursively builds the Go source text that marshals the
// expression val, handling the head of remainingMods at each level.
//
//   - val is the Go expression (as source text) for the value to marshal.
//   - remainingMods are the type modifiers still to unwrap, processed head-first.
//   - isPtr records that the previous level was a pointer modifier, so val
//     still has to be dereferenced (or must not be re-addressed).
//   - depth is used to derive the arrN / idxN variable names for lists and is
//     incremented on every recursive step.
func (f *Field) doWriteJson(val string, remainingMods []string, isPtr bool, depth int) string {
	switch {
	// Pointer level: emit a nil guard that returns graphql.Null, then continue
	// with the remaining modifiers, remembering that val is a pointer.
	case len(remainingMods) > 0 && remainingMods[0] == modPtr:
		return fmt.Sprintf("if %s == nil { return graphql.Null }\n%s", val, f.doWriteJson(val, remainingMods[1:], true, depth+1))

	// List level: emit a loop that appends one marshaler per element, each
	// produced by an immediately-invoked closure that pushes the element index
	// onto the resolver context for the duration of the element's marshaling.
	case len(remainingMods) > 0 && remainingMods[0] == modList:
		if isPtr {
			// NOTE(review): the element expression built below becomes
			// "*val[idx]", which Go parses as *(val[idx]) rather than
			// (*val)[idx] — confirm whether pointer-to-list can reach here.
			val = "*" + val
		}
		var arr = "arr" + strconv.Itoa(depth)
		var index = "idx" + strconv.Itoa(depth)

		return tpl(`{{.arr}} := graphql.Array{}
			for {{.index}} := range {{.val}} {
				{{.arr}} = append({{.arr}}, func() graphql.Marshaler {
					rctx := graphql.GetResolverContext(ctx)
					rctx.PushIndex({{.index}})
					defer rctx.Pop()
					{{ .next }} 
				}())
			}
			return {{.arr}}`, map[string]interface{}{
			"val":   val,
			"arr":   arr,
			"index": index,
			"next":  f.doWriteJson(val+"["+index+"]", remainingMods[1:], false, depth+1),
		})

	// Scalar leaf: dereference if needed and delegate to the type's Marshal.
	case f.IsScalar:
		if isPtr {
			val = "*" + val
		}
		return f.Marshal(val)

	// Non-scalar leaf: the generated ec._<GQLType> method is passed a pointer,
	// so take the address unless val already is one.
	default:
		if !isPtr {
			val = "&" + val
		}
		return fmt.Sprintf("return ec._%s(ctx, field.Selections, %s)", f.GQLType, val)
	}
}

// ByName returns the first object whose GraphQL type name matches name,
// compared case-insensitively with strings.EqualFold, or nil when there is
// no match.
func (objs Objects) ByName(name string) *Object {
	for _, candidate := range objs {
		if strings.EqualFold(candidate.GQLType, name) {
			return candidate
		}
	}
	return nil
}

// tpl parses its first argument as a text/template named "inline", executes
// it with vars as the data, and returns the rendered text. It panics if the
// template fails to parse or to execute.
func tpl(tpl string, vars map[string]interface{}) string {
	parsed, err := template.New("inline").Parse(tpl)
	if err != nil {
		panic(err)
	}

	var out bytes.Buffer
	if err := parsed.Execute(&out, vars); err != nil {
		panic(err)
	}
	return out.String()
}

// ucFirst returns s with its first rune converted to upper case via
// unicode.ToUpper. An empty string is returned unchanged.
func ucFirst(s string) string {
	runes := []rune(s)
	if len(runes) == 0 {
		return ""
	}
	runes[0] = unicode.ToUpper(runes[0])
	return string(runes)
}