package codegen

import (
	"bytes"
	"fmt"
	"strconv"
	"strings"
	"text/template"
	"unicode"
)

// Object describes one GraphQL object type together with the fields that
// code is generated for.
type Object struct {
	*NamedType

	Fields             []Field
	Satisfies          []string // Additional type names listed by Implementors
	Root               bool     // Root objects' resolvers do not receive an obj argument
	DisableConcurrency bool     // When set, no field of this object reports IsConcurrent
	Stream             bool     // When set, resolver results are wrapped in a receive-only channel
}

// Field describes a single field of an Object and how it is bound to Go.
type Field struct {
	*Type

	GQLName       string          // The name of the field in graphql
	GoMethodName  string          // The name of the method in go, if any
	GoVarName     string          // The name of the var in go, if any
	Args          []FieldArgument // A list of arguments to be passed to this field
	ForceResolver bool            // Should we emit a resolver method even if the field is bound
	NoErr         bool            // If this is bound to a go method, does that method have an error as the second argument
	Object        *Object         // A link back to the parent object
	Default       interface{}     // The default value
}

// FieldArgument describes a single argument accepted by a Field.
type FieldArgument struct {
	*Type

	GQLName   string      // The name of the argument in graphql
	GoVarName string      // The name of the var in go
	Object    *Object     // A link back to the parent object
	Default   interface{} // The default value
}

// Objects is a list of Object pointers that can be searched by name.
type Objects []*Object

// Implementors returns Go source for a []string literal that contains the
// object's own GraphQL type name followed by every name in Satisfies, each
// one quoted with strconv.Quote.
func (o *Object) Implementors() string {
	names := make([]string, 0, len(o.Satisfies)+1)
	names = append(names, strconv.Quote(o.GQLType))
	for _, s := range o.Satisfies {
		names = append(names, strconv.Quote(s))
	}
	return "[]string{" + strings.Join(names, ", ") + "}"
}

// HasResolvers reports whether at least one field of the object is a
// resolver (see Field.IsResolver).
func (o *Object) HasResolvers() bool {
	for i := range o.Fields {
		if o.Fields[i].IsResolver() {
			return true
		}
	}
	return false
}

// IsResolver reports whether the field needs a resolver method: either one
// is forced, or the field is bound to neither a Go method nor a Go var.
func (f *Field) IsResolver() bool {
	return f.ForceResolver || (f.GoMethodName == "" && f.GoVarName == "")
}

// IsConcurrent reports whether the field is a resolver on an object that has
// not disabled concurrency.
func (f *Field) IsConcurrent() bool {
	return f.IsResolver() && !f.Object.DisableConcurrency
}

// ShortInvocation returns the Go call expression for the field's resolver in
// the form `Type().Field(ctx[, obj][, args...])`, or "" when the field is not
// a resolver. The obj argument is omitted for root objects.
func (f *Field) ShortInvocation() string {
	if !f.IsResolver() {
		return ""
	}

	args := make([]string, 0, len(f.Args)+2)
	args = append(args, "ctx")
	if !f.Object.Root {
		args = append(args, "obj")
	}
	for _, arg := range f.Args {
		args = append(args, arg.GoVarName)
	}

	// ucFirst is rune-aware and tolerates an empty name, where slicing
	// GQLName[:1] would panic or split a multi-byte character.
	return f.Object.GQLType + "()." + ucFirst(f.GQLName) + "(" + strings.Join(args, ", ") + ")"
}

// ShortResolverDeclaration returns ResolverDeclaration with the leading
// "<GQLType>_" prefix removed and the first letter upper-cased, or "" when
// the field is not a resolver.
func (f *Field) ShortResolverDeclaration() string {
	if !f.IsResolver() {
		return ""
	}
	decl := strings.TrimPrefix(f.ResolverDeclaration(), f.Object.GQLType+"_")
	return ucFirst(decl)
}

// ResolverDeclaration returns the Go method signature of the field's
// resolver in the form
// `Type_field(ctx context.Context[, obj *T][, args...]) (R, error)`,
// or "" when the field is not a resolver. The obj parameter is omitted for
// root objects, and R becomes `<-chan R` when the parent object is a stream.
func (f *Field) ResolverDeclaration() string {
	if !f.IsResolver() {
		return ""
	}

	params := make([]string, 0, len(f.Args)+2)
	params = append(params, "ctx context.Context")
	if !f.Object.Root {
		params = append(params, fmt.Sprintf("obj *%s", f.Object.FullName()))
	}
	for _, arg := range f.Args {
		params = append(params, arg.GoVarName+" "+arg.Signature())
	}

	result := f.Signature()
	if f.Object.Stream {
		result = "<-chan " + result
	}

	return fmt.Sprintf("%s_%s(%s) (%s, error)",
		f.Object.GQLType, f.GQLName, strings.Join(params, ", "), result)
}

// CallArgs returns the comma-separated argument list used to invoke the
// field. When the field is not bound to a Go method the list starts with ctx
// (and obj for non-root objects); every declared argument follows as a type
// assertion on the args map, e.g. `args["name"].(string)`.
func (f *Field) CallArgs() string {
	args := make([]string, 0, len(f.Args)+2)

	if f.GoMethodName == "" {
		args = append(args, "ctx")
		if !f.Object.Root {
			args = append(args, "obj")
		}
	}

	for _, arg := range f.Args {
		args = append(args, "args["+strconv.Quote(arg.GQLName)+"].("+arg.Signature()+")")
	}

	return strings.Join(args, ", ")
}

// WriteJson returns the Go source that marshals the resolver result `res`
// according to the field's type modifiers.
//
// should be in the template, but its recursive and has a bunch of args
func (f *Field) WriteJson() string {
	return f.doWriteJson("res", f.Type.Modifiers, false, 1)
}

// doWriteJson recursively emits marshalling code for the expression val.
//
// remainingMods holds the type modifiers not yet consumed, isPtr records
// whether val currently denotes a pointer that has already been nil-checked,
// and depth is used to give the generated array and index variables unique
// names at each nesting level.
func (f *Field) doWriteJson(val string, remainingMods []string, isPtr bool, depth int) string {
	switch {
	case len(remainingMods) > 0 && remainingMods[0] == modPtr:
		return fmt.Sprintf("if %s == nil { return graphql.Null }\n%s",
			val, f.doWriteJson(val, remainingMods[1:], true, depth+1))

	case len(remainingMods) > 0 && remainingMods[0] == modList:
		if isPtr {
			// Parenthesise the dereference: index expressions bind tighter
			// than unary *, so "*x[i]" would index the pointer itself rather
			// than the slice it points to.
			val = "(*" + val + ")"
		}
		arr := "arr" + strconv.Itoa(depth)
		index := "idx" + strconv.Itoa(depth)

		return tpl(`{{.arr}} := graphql.Array{}
			for {{.index}} := range {{.val}} {
				{{.arr}} = append({{.arr}}, func() graphql.Marshaler {
					rctx := graphql.GetResolverContext(ctx)
					rctx.PushIndex({{.index}})
					defer rctx.Pop()
					{{ .next }}
				}())
			}
			return {{.arr}}`, map[string]interface{}{
			"val":   val,
			"arr":   arr,
			"index": index,
			"next":  f.doWriteJson(val+"["+index+"]", remainingMods[1:], false, depth+1),
		})

	case f.IsScalar:
		if isPtr {
			val = "*" + val
		}
		return f.Marshal(val)

	default:
		// Object marshallers are handed a pointer, so take the address of
		// anything that is not one already.
		if !isPtr {
			val = "&" + val
		}
		return fmt.Sprintf("return ec._%s(ctx, field.Selections, %s)", f.GQLType, val)
	}
}

// ByName returns the first object whose GraphQL type name matches name,
// compared case-insensitively, or nil when there is no such object.
func (objs Objects) ByName(name string) *Object {
	for _, o := range objs {
		if strings.EqualFold(o.GQLType, name) {
			return o
		}
	}
	return nil
}

// tpl renders the inline text/template source text with vars and returns the
// result. It panics if the template fails to parse or execute.
func tpl(text string, vars map[string]interface{}) string {
	b := &bytes.Buffer{}
	err := template.Must(template.New("inline").Parse(text)).Execute(b, vars)
	if err != nil {
		panic(err)
	}
	return b.String()
}

// ucFirst returns s with its first rune converted to upper case. The empty
// string is returned unchanged.
func ucFirst(s string) string {
	if s == "" {
		return ""
	}
	r := []rune(s)
	r[0] = unicode.ToUpper(r[0])
	return string(r)
}