package codegen
import (
"bytes"
"fmt"
"strconv"
"strings"
"text/template"
"unicode"
"github.com/vektah/gqlparser/ast"
)
// GoFieldType records how a GraphQL field is bound to Go code.
type GoFieldType int

// GoFieldType values. GoFieldUndefined is the zero value.
const (
	GoFieldUndefined GoFieldType = 0 // no Go method or variable binding recorded
	GoFieldMethod    GoFieldType = 1 // bound to a Go method (see Field.IsMethod)
	GoFieldVariable  GoFieldType = 2 // bound to a Go variable/struct field (see Field.IsVariable)
)
// Object describes a GraphQL object type for which code is generated.
type Object struct {
	*NamedType
	Fields             []Field  // the object's fields; scanned by HasResolvers and IsConcurrent
	Satisfies          []string // additional GraphQL type names listed by Implementors (presumably interfaces/unions — confirm)
	ResolverInterface  *Ref
	Root               bool // when set, generated resolver signatures and call args omit the obj parameter (presumably the operation root types — confirm)
	DisableConcurrency bool // when set, no field of this object reports IsConcurrent
	Stream             bool // when set, resolver results are declared as "<-chan" of the field's type
}
// Field describes one field of an Object and how it is bound to Go code.
type Field struct {
	*Type
	Description    string          // Description of a field
	GQLName        string          // The name of the field in graphql
	GoFieldType    GoFieldType     // How the field is bound in go (method or variable), if at all
	GoReceiverName string          // The name of method & var receiver in go, if any
	GoFieldName    string          // The name of the method or var in go, if any; empty means the field needs a resolver (see IsResolver)
	Args           []FieldArgument // A list of arguments to be passed to this field
	ForceResolver  bool            // NOTE(review): presumably requests that a resolver method be generated for this field — confirm at usage sites
	NoErr          bool            // For fields bound to a go method: NOTE(review): the name suggests true means the method does NOT return an error as its second result — confirm at usage sites
	Object         *Object         // A link back to the parent object
	Default        interface{}     // The default value
}
// FieldArgument describes one argument accepted by a Field.
type FieldArgument struct {
	*Type
	GQLName   string      // The name of the argument in graphql; also the key used to read it from the args map
	GoVarName string      // The name of the var in go; used as the parameter name in generated resolver signatures
	Object    *Object     // A link back to the parent object; may be nil (Stream checks for this)
	Default   interface{} // The default value
}
// Objects is a list of generated objects; use ByName to look one up by GraphQL type name.
type Objects []*Object
// Implementors renders a Go []string literal holding this object's GraphQL
// type name followed by every name in Satisfies, each quoted with
// strconv.Quote, e.g. `[]string{"User", "Node"}`.
func (o *Object) Implementors() string {
	quoted := make([]string, 0, len(o.Satisfies)+1)
	quoted = append(quoted, strconv.Quote(o.GQLType))
	for _, name := range o.Satisfies {
		quoted = append(quoted, strconv.Quote(name))
	}
	return "[]string{" + strings.Join(quoted, ", ") + "}"
}
// HasResolvers reports whether at least one field of o must be implemented
// through a resolver (see Field.IsResolver).
func (o *Object) HasResolvers() bool {
	for i := range o.Fields {
		if o.Fields[i].IsResolver() {
			return true
		}
	}
	return false
}
// IsConcurrent reports whether any field of o is resolved concurrently
// (see Field.IsConcurrent).
func (o *Object) IsConcurrent() bool {
	for i := range o.Fields {
		if o.Fields[i].IsConcurrent() {
			return true
		}
	}
	return false
}
// IsReserved reports whether the object's GraphQL type name starts with "__",
// a prefix the GraphQL spec reserves for its introspection system.
func (o *Object) IsReserved() bool {
	const reservedPrefix = "__"
	return strings.HasPrefix(o.GQLType, reservedPrefix)
}
// IsResolver reports whether f has no Go method or variable binding
// (GoFieldName is empty) and therefore needs a generated resolver.
func (f *Field) IsResolver() bool {
	return len(f.GoFieldName) == 0
}
// IsReserved reports whether the field's GraphQL name starts with "__",
// a prefix the GraphQL spec reserves for introspection (e.g. __typename).
func (f *Field) IsReserved() bool {
	const reservedPrefix = "__"
	return strings.HasPrefix(f.GQLName, reservedPrefix)
}
// IsMethod reports whether f is bound to a Go method.
func (f *Field) IsMethod() bool {
	kind := f.GoFieldType
	return kind == GoFieldMethod
}
// IsVariable reports whether f is bound to a Go variable or struct field.
func (f *Field) IsVariable() bool {
	kind := f.GoFieldType
	return kind == GoFieldVariable
}
// IsConcurrent reports whether f is a resolver field whose parent object has
// not disabled concurrency. f.Object is only consulted for resolver fields.
func (f *Field) IsConcurrent() bool {
	if !f.IsResolver() {
		return false
	}
	return !f.Object.DisableConcurrency
}
// GoNameExported returns the field's GraphQL name as an exported Go
// identifier: first letter upper-cased, then normalized by lintName.
func (f *Field) GoNameExported() string {
	capitalized := ucFirst(f.GQLName)
	return lintName(capitalized)
}
// GoNameUnexported returns the field's GraphQL name normalized by lintName.
// NOTE(review): the first letter is not lower-cased, so the result is only
// unexported when GQLName already starts with a lowercase letter.
func (f *Field) GoNameUnexported() string {
	name := f.GQLName
	return lintName(name)
}
// ShortInvocation returns the Go call expression used to invoke this field's
// resolver, "<GQLType>().<ExportedName>(<CallArgs>)", or "" when the field is
// not a resolver.
func (f *Field) ShortInvocation() string {
	if f.IsResolver() {
		return f.Object.GQLType + "()." + f.GoNameExported() + "(" + f.CallArgs() + ")"
	}
	return ""
}
// ArgsFunc returns the name used for this field's generated arguments
// function, "field_<GQLType>_<GQLName>_args", or "" when the field takes no
// arguments.
func (f *Field) ArgsFunc() string {
	if len(f.Args) > 0 {
		return strings.Join([]string{"field", f.Object.GQLType, f.GQLName, "args"}, "_")
	}
	return ""
}
// ResolverType returns the resolver call expression
// "<GQLType>().<ExportedName>(<CallArgs>)", or "" when the field is not a
// resolver. The output is identical to ShortInvocation, which it delegates to
// so the two can no longer drift apart.
// NOTE(review): despite its name this yields a call expression, not a type.
func (f *Field) ResolverType() string {
	return f.ShortInvocation()
}
// ShortResolverDeclaration returns the Go method signature for this field's
// resolver named by its exported Go name, e.g.
// `Name(ctx context.Context, obj *T, id string) (Result, error)`.
// It returns "" when the field is not a resolver.
func (f *Field) ShortResolverDeclaration() string {
	if !f.IsResolver() {
		return ""
	}
	return f.resolverSignatureNamed(f.GoNameExported())
}

// ResolverDeclaration is like ShortResolverDeclaration, but the method name is
// "<GQLType>_<unexported Go name>", e.g. `Query_user(ctx context.Context, ...)`.
// It returns "" when the field is not a resolver.
func (f *Field) ResolverDeclaration() string {
	if !f.IsResolver() {
		return ""
	}
	return f.resolverSignatureNamed(f.Object.GQLType + "_" + f.GoNameUnexported())
}

// resolverSignatureNamed renders
// "<name>(ctx context.Context[, obj *<Object>][, <arg> <ArgType>...]) (<Result>, error)".
// obj is omitted for root objects, and the result is wrapped in "<-chan"
// for streaming objects. Shared by both declaration forms above.
func (f *Field) resolverSignatureNamed(name string) string {
	var b strings.Builder
	b.WriteString(name)
	b.WriteString("(ctx context.Context")
	if !f.Object.Root {
		fmt.Fprintf(&b, ", obj *%s", f.Object.FullName())
	}
	for _, arg := range f.Args {
		fmt.Fprintf(&b, ", %s %s", arg.GoVarName, arg.Signature())
	}

	result := f.Signature()
	if f.Object.Stream {
		result = "<-chan " + result
	}
	fmt.Fprintf(&b, ") (%s, error)", result)
	return b.String()
}
// ComplexitySignature returns the Go func type of this field's complexity
// function: `func(childComplexity int[, <arg> <ArgType>...]) int`.
func (f *Field) ComplexitySignature() string {
	var b strings.Builder
	// Plain write: the original passed a constant with no verbs to fmt.Sprintf.
	b.WriteString("func(childComplexity int")
	for _, arg := range f.Args {
		fmt.Fprintf(&b, ", %s %s", arg.GoVarName, arg.Signature())
	}
	b.WriteString(") int")
	return b.String()
}
// ComplexityArgs returns a comma-separated list of expressions that read each
// argument from the generated `args` map and type-assert it to the argument's
// Go type, e.g. `args["id"].(string)`. It returns "" when there are no args.
func (f *Field) ComplexityArgs() string {
	exprs := make([]string, len(f.Args))
	for i, arg := range f.Args {
		exprs[i] = fmt.Sprintf("args[%s].(%s)", strconv.Quote(arg.GQLName), arg.Signature())
	}
	return strings.Join(exprs, ", ")
}
// CallArgs returns the argument list used when invoking this field. Resolver
// fields get a leading "ctx" (plus "obj" for non-root objects). It is followed
// by the same type-asserted argument expressions that ComplexityArgs produces.
func (f *Field) CallArgs() string {
	var parts []string
	if f.IsResolver() {
		parts = []string{"ctx"}
		if !f.Object.Root {
			parts = append(parts, "obj")
		}
	}
	// ComplexityArgs is empty exactly when the field has no arguments.
	if argExprs := f.ComplexityArgs(); argExprs != "" {
		parts = append(parts, argExprs)
	}
	return strings.Join(parts, ", ")
}
// WriteJson returns Go source that marshals the resolved value held in `res`
// into a graphql.Marshaler. It starts doWriteJson with the field's full
// modifier list at depth 1. This lives in Go rather than in a template because
// it is recursive and takes many arguments.
func (f *Field) WriteJson() string {
	const (
		rootExpr  = "res"
		rootDepth = 1
	)
	return f.doWriteJson(rootExpr, f.Type.Modifiers, f.ASTType, false, rootDepth)
}
// doWriteJson recursively builds Go source text that converts the Go
// expression val into a graphql.Marshaler. Each call consumes the first entry
// of remainingMods:
//   - modPtr: emit a nil check that returns graphql.Null (reporting
//     "must not be null" when astType is NonNull), then recurse with isPtr set;
//   - modList: emit a loop that fills a graphql.Array, recursing on the element
//     expression val[idxN] with astType.Elem;
//   - no modifier left: marshal a scalar via f.Marshal, otherwise call the
//     generated ec._<GQLType> method.
//
// isPtr reports whether val is currently a pointer expression. depth gives
// the generated locals (arrN, idxN) distinct names per nesting level; depth 1
// is the outermost call made by WriteJson.
func (f *Field) doWriteJson(val string, remainingMods []string, astType *ast.Type, isPtr bool, depth int) string {
	switch {
	// Pointer level: guard against nil before going further.
	case len(remainingMods) > 0 && remainingMods[0] == modPtr:
		return tpl(`
if {{.val}} == nil {
{{- if .nonNull }}
if !ec.HasError(rctx) {
ec.Errorf(ctx, "must not be null")
}
{{- end }}
return graphql.Null
}
{{.next }}`, map[string]interface{}{
			"val":     val,
			"nonNull": astType.NonNull,
			// Same expression, now known to be a pointer.
			"next": f.doWriteJson(val, remainingMods[1:], astType, true, depth+1),
		})
	// List level: render each element into arrN. For non-scalar elements, each
	// element is rendered in its own goroutine tracked by `wg`, unless the list
	// has exactly one element (isLen1), in which case it is rendered inline.
	case len(remainingMods) > 0 && remainingMods[0] == modList:
		if isPtr {
			val = "*" + val
		}
		var arr = "arr" + strconv.Itoa(depth)
		var index = "idx" + strconv.Itoa(depth)
		// ResolverContext.Result takes the element's address only when this list
		// is the last modifier and val was not a pointer.
		var usePtr bool
		if len(remainingMods) == 1 && !isPtr {
			usePtr = true
		}
		// `wg` is declared and awaited only when depth == 1 ("top"). Nested
		// non-scalar lists add to the captured outer `wg` and return their array
		// before their goroutines finish; the top-level wg.Wait() covers them.
		// NOTE(review): if a modPtr precedes the first modList, depth is already 2
		// here, so the generated code would reference an undeclared `wg` —
		// confirm that combination cannot occur.
		return tpl(`
{{.arr}} := make(graphql.Array, len({{.val}}))
{{ if and .top (not .isScalar) }} var wg sync.WaitGroup {{ end }}
{{ if not .isScalar }}
isLen1 := len({{.val}}) == 1
if !isLen1 {
wg.Add(len({{.val}}))
}
{{ end }}
for {{.index}} := range {{.val}} {
{{- if not .isScalar }}
{{.index}} := {{.index}}
rctx := &graphql.ResolverContext{
Index: &{{.index}},
Result: {{ if .usePtr }}&{{end}}{{.val}}[{{.index}}],
}
ctx := graphql.WithResolverContext(ctx, rctx)
f := func({{.index}} int) {
if !isLen1 {
defer wg.Done()
}
{{.arr}}[{{.index}}] = func() graphql.Marshaler {
{{ .next }}
}()
}
if isLen1 {
f({{.index}})
} else {
go f({{.index}})
}
{{ else }}
{{.arr}}[{{.index}}] = func() graphql.Marshaler {
{{ .next }}
}()
{{- end}}
}
{{ if and .top (not .isScalar) }} wg.Wait() {{ end }}
return {{.arr}}`, map[string]interface{}{
			"val":      val,
			"arr":      arr,
			"index":    index,
			"top":      depth == 1,
			"arrayLen": len(val), // NOTE(review): not referenced by the template; this is the length of the expression text, not of the list
			"isScalar": f.IsScalar,
			"usePtr":   usePtr,
			"next":     f.doWriteJson(val+"["+index+"]", remainingMods[1:], astType.Elem, false, depth+1),
		})
	// Leaf scalar: dereference if needed and use the type's marshaler.
	case f.IsScalar:
		if isPtr {
			val = "*" + val
		}
		return f.Marshal(val)
	// Leaf object: the generated ec._<GQLType> method is passed a pointer.
	default:
		if !isPtr {
			val = "&" + val
		}
		return tpl(`
return ec._{{.type}}(ctx, field.Selections, {{.val}})`, map[string]interface{}{
			"type": f.GQLType,
			"val":  val,
		})
	}
}
// Stream reports whether the argument's parent object is a streaming object.
// An argument with no parent object is never streaming.
func (f *FieldArgument) Stream() bool {
	if f.Object == nil {
		return false
	}
	return f.Object.Stream
}
// ByName returns the first object whose GraphQL type name matches name,
// or nil if none does.
// NOTE(review): the comparison is case-insensitive (strings.EqualFold), although
// GraphQL names are case-sensitive; callers may rely on this, so it is kept.
func (os Objects) ByName(name string) *Object {
	for _, candidate := range os {
		if strings.EqualFold(candidate.GQLType, name) {
			return candidate
		}
	}
	return nil
}
// tpl parses src as an inline text/template, executes it with vars, and
// returns the rendered text. Parse and execution errors both panic. Callers in
// this file pass fixed template literals, so a failure is a programming error.
func tpl(src string, vars map[string]interface{}) string {
	parsed, err := template.New("inline").Parse(src)
	if err != nil {
		panic(err)
	}
	var out bytes.Buffer
	if err := parsed.Execute(&out, vars); err != nil {
		panic(err)
	}
	return out.String()
}
// ucFirst returns s with its first rune upper-cased; the empty string is
// returned unchanged.
func ucFirst(s string) string {
	runes := []rune(s)
	if len(runes) > 0 {
		runes[0] = unicode.ToUpper(runes[0])
	}
	return string(runes)
}
// lintName returns name rewritten according to golint's naming rules, or name
// unchanged if it already conforms.
//
// The name is split into words at lower->non-lower transitions and at
// underscores. Underscores that follow another character are dropped, except
// that a single underscore is kept between two digits. A word found in
// commonInitialisms is written in one case: lowercase if it is the first word
// and starts lowercase, uppercase otherwise. Any other all-lowercase word after
// the first gets its first letter upper-cased. For example, "user_id" and
// "userId" both become "userID", and "Url" becomes "URL".
//
// Copied from https://github.com/golang/lint/blob/06c8688daad7faa9da5a0c2f163a3d14aac986ca/lint.go#L679
func lintName(name string) (should string) {
	// Fast path for simple cases: "_" and all lowercase.
	if name == "_" {
		return name
	}
	allLower := true
	for _, r := range name {
		if !unicode.IsLower(r) {
			allLower = false
			break
		}
	}
	if allLower {
		return name
	}
	// Split camelCase at any lower->upper transition, and split on underscores.
	// Check each word for common initialisms.
	runes := []rune(name)
	w, i := 0, 0 // index of start of word, scan
	// runes may shrink inside the loop (underscore removal), so the bound is
	// re-evaluated on every iteration.
	for i+1 <= len(runes) {
		eow := false // whether we hit the end of a word
		if i+1 == len(runes) {
			eow = true
		} else if runes[i+1] == '_' {
			// underscore; shift the remainder forward over any run of underscores
			eow = true
			n := 1
			for i+n+1 < len(runes) && runes[i+n+1] == '_' {
				n++
			}
			// Leave at most one underscore if the underscore is between two digits
			if i+n+1 < len(runes) && unicode.IsDigit(runes[i]) && unicode.IsDigit(runes[i+n+1]) {
				n--
			}
			copy(runes[i+1:], runes[i+n+1:])
			runes = runes[:len(runes)-n]
		} else if unicode.IsLower(runes[i]) && !unicode.IsLower(runes[i+1]) {
			// lower->non-lower
			eow = true
		}
		i++
		if !eow {
			continue
		}
		// [w,i) is a word.
		word := string(runes[w:i])
		if u := strings.ToUpper(word); commonInitialisms[u] {
			// Keep consistent case, which is lowercase only at the start.
			if w == 0 && unicode.IsLower(runes[w]) {
				u = strings.ToLower(u)
			}
			// All the common initialisms are ASCII,
			// so we can replace the bytes exactly.
			copy(runes[w:], []rune(u))
		} else if w > 0 && strings.ToLower(word) == word {
			// already all lowercase, and not the first word, so uppercase the first character.
			runes[w] = unicode.ToUpper(runes[w])
		}
		w = i
	}
	return string(runes)
}
// commonInitialisms is the set of initialisms lintName keeps in a single case
// (for example "Id" becomes "ID"). Only add entries that are highly unlikely to
// be ordinary words: "ID" is fine (Freudian code is rare), but "AND" is not.
var commonInitialisms = func() map[string]bool {
	words := []string{
		"ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID",
		"HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS",
		"RAM", "RHS", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP",
		"TLS", "TTL", "UDP", "UI", "UID", "UUID", "URI", "URL",
		"UTF8", "VM", "XML", "XMPP", "XSRF", "XSS",
	}
	set := make(map[string]bool, len(words))
	for _, word := range words {
		set[word] = true
	}
	return set
}()