package templates

import (
	"bytes"
	"fmt"
	"go/types"
	"io/ioutil"
	"os"
	"path/filepath"
	"reflect"
	"runtime"
	"sort"
	"strconv"
	"strings"
	"text/template"
	"unicode"

	"github.com/99designs/gqlgen/internal/imports"
	"github.com/pkg/errors"
)

// CurrentImports keeps track of all the import declarations that are needed during the execution of a plugin.
// This is done with a global because subtemplates currently get called in functions. Let's aim to remove this eventually.
var CurrentImports *Imports

// Options specify various parameters to rendering a template.
type Options struct {
	// PackageName is a helper that specifies the package header declaration.
	// In other words, when you write the template you don't need to specify `package X`
	// at the top of the file. By providing PackageName in the Options, the Render
	// function will do that for you.
	PackageName string
	// Template is a string of the entire template that
	// will be parsed and rendered. If it's empty,
	// the plugin processor will look for .gotpl files
	// in the same directory of where you wrote the plugin.
	Template string
	// Filename is the name of the file that will be
	// written to the system disk once the template is rendered.
	Filename string
	// RegionTags, when true, makes Render surround the output of every
	// executed template with `// region` and `// endregion` marker comments
	// labelled with the template name.
	RegionTags bool
	// GeneratedHeader, when true, makes Render start the file with the
	// "Code generated by github.com/99designs/gqlgen, DO NOT EDIT." comment.
	GeneratedHeader bool
	// Data will be passed to the template execution.
	Data interface{}
	// Funcs holds extra template functions. Render copies them over the
	// defaults returned by Funcs(), so an entry with the same name as a
	// default replaces it.
	Funcs template.FuncMap
}

// Render renders a gql plugin template from the given Options. Render is an
// abstraction of the text/template package that makes it easier to write gqlgen
// plugins. If Options.Template is empty, the Render function will look for `.gotpl`
// files inside the directory where you wrote the plugin.
func Render(cfg Options) error { if CurrentImports != nil { panic(fmt.Errorf("recursive or concurrent call to RenderToFile detected")) } CurrentImports = &Imports{destDir: filepath.Dir(cfg.Filename)} // load path relative to calling source file _, callerFile, _, _ := runtime.Caller(1) rootDir := filepath.Dir(callerFile) funcs := Funcs() for n, f := range cfg.Funcs { funcs[n] = f } t := template.New("").Funcs(funcs) var roots []string if cfg.Template != "" { var err error t, err = t.New("template.gotpl").Parse(cfg.Template) if err != nil { return errors.Wrap(err, "error with provided template") } roots = append(roots, "template.gotpl") } else { // load all the templates in the directory err := filepath.Walk(rootDir, func(path string, info os.FileInfo, err error) error { if err != nil { return err } name := filepath.ToSlash(strings.TrimPrefix(path, rootDir+string(os.PathSeparator))) if !strings.HasSuffix(info.Name(), ".gotpl") { return nil } b, err := ioutil.ReadFile(path) if err != nil { return err } t, err = t.New(name).Parse(string(b)) if err != nil { return errors.Wrap(err, cfg.Filename) } roots = append(roots, name) return nil }) if err != nil { return errors.Wrap(err, "locating templates") } } // then execute all the important looking ones in order, adding them to the same file sort.Slice(roots, func(i, j int) bool { // important files go first if strings.HasSuffix(roots[i], "!.gotpl") { return true } if strings.HasSuffix(roots[j], "!.gotpl") { return false } return roots[i] < roots[j] }) var buf bytes.Buffer for _, root := range roots { if cfg.RegionTags { buf.WriteString("\n// region " + center(70, "*", " "+root+" ") + "\n") } err := t.Lookup(root).Execute(&buf, cfg.Data) if err != nil { return errors.Wrap(err, root) } if cfg.RegionTags { buf.WriteString("\n// endregion " + center(70, "*", " "+root+" ") + "\n") } } var result bytes.Buffer if cfg.GeneratedHeader { result.WriteString("// Code generated by github.com/99designs/gqlgen, DO NOT EDIT.\n\n") } 
result.WriteString("package ") result.WriteString(cfg.PackageName) result.WriteString("\n\n") result.WriteString("import (\n") result.WriteString(CurrentImports.String()) result.WriteString(")\n") _, err := buf.WriteTo(&result) if err != nil { return err } CurrentImports = nil return write(cfg.Filename, result.Bytes()) } func center(width int, pad string, s string) string { if len(s)+2 > width { return s } lpad := (width - len(s)) / 2 rpad := width - (lpad + len(s)) return strings.Repeat(pad, lpad) + s + strings.Repeat(pad, rpad) } func Funcs() template.FuncMap { return template.FuncMap{ "ucFirst": ucFirst, "lcFirst": lcFirst, "quote": strconv.Quote, "rawQuote": rawQuote, "dump": Dump, "ref": ref, "ts": TypeIdentifier, "call": Call, "prefixLines": prefixLines, "notNil": notNil, "reserveImport": CurrentImports.Reserve, "lookupImport": CurrentImports.Lookup, "go": ToGo, "goPrivate": ToGoPrivate, "add": func(a, b int) int { return a + b }, "render": func(filename string, tpldata interface{}) (*bytes.Buffer, error) { return render(resolveName(filename, 0), tpldata) }, } } func ucFirst(s string) string { if s == "" { return "" } r := []rune(s) r[0] = unicode.ToUpper(r[0]) return string(r) } func lcFirst(s string) string { if s == "" { return "" } r := []rune(s) r[0] = unicode.ToLower(r[0]) return string(r) } func isDelimiter(c rune) bool { return c == '-' || c == '_' || unicode.IsSpace(c) } func ref(p types.Type) string { return CurrentImports.LookupType(p) } var pkgReplacer = strings.NewReplacer( "/", "ᚋ", ".", "ᚗ", "-", "ᚑ", ) func TypeIdentifier(t types.Type) string { res := "" for { switch it := t.(type) { case *types.Pointer: t.Underlying() res += "ᚖ" t = it.Elem() case *types.Slice: res += "ᚕ" t = it.Elem() case *types.Named: res += pkgReplacer.Replace(it.Obj().Pkg().Path()) res += "ᚐ" res += it.Obj().Name() return res case *types.Basic: res += it.Name() return res case *types.Map: res += "map" return res case *types.Interface: res += "interface" return res 
default: panic(fmt.Errorf("unexpected type %T", it)) } } } func Call(p *types.Func) string { pkg := CurrentImports.Lookup(p.Pkg().Path()) if pkg != "" { pkg += "." } if p.Type() != nil { // make sure the returned type is listed in our imports. ref(p.Type().(*types.Signature).Results().At(0).Type()) } return pkg + p.Name() } func ToGo(name string) string { runes := make([]rune, 0, len(name)) wordWalker(name, func(info *wordInfo) { word := info.Word if info.MatchCommonInitial { word = strings.ToUpper(word) } else if !info.HasCommonInitial { if strings.ToUpper(word) == word || strings.ToLower(word) == word { // FOO or foo → Foo // FOo → FOo word = ucFirst(strings.ToLower(word)) } } runes = append(runes, []rune(word)...) }) return string(runes) } func ToGoPrivate(name string) string { runes := make([]rune, 0, len(name)) first := true wordWalker(name, func(info *wordInfo) { word := info.Word if first { if strings.ToUpper(word) == word || strings.ToLower(word) == word { // ID → id, CAMEL → camel word = strings.ToLower(info.Word) } else { // ITicket → iTicket word = lcFirst(info.Word) } first = false } else if info.MatchCommonInitial { word = strings.ToUpper(word) } else if !info.HasCommonInitial { word = ucFirst(strings.ToLower(word)) } runes = append(runes, []rune(word)...) }) return sanitizeKeywords(string(runes)) } type wordInfo struct { Word string MatchCommonInitial bool HasCommonInitial bool } // This function is based on the following code. 
// https://github.com/golang/lint/blob/06c8688daad7faa9da5a0c2f163a3d14aac986ca/lint.go#L679 func wordWalker(str string, f func(*wordInfo)) { runes := []rune(str) w, i := 0, 0 // index of start of word, scan hasCommonInitial := false for i+1 <= len(runes) { eow := false // whether we hit the end of a word if i+1 == len(runes) { eow = true } else if isDelimiter(runes[i+1]) { // underscore; shift the remainder forward over any run of underscores eow = true n := 1 for i+n+1 < len(runes) && isDelimiter(runes[i+n+1]) { n++ } // Leave at most one underscore if the underscore is between two digits if i+n+1 < len(runes) && unicode.IsDigit(runes[i]) && unicode.IsDigit(runes[i+n+1]) { n-- } copy(runes[i+1:], runes[i+n+1:]) runes = runes[:len(runes)-n] } else if unicode.IsLower(runes[i]) && !unicode.IsLower(runes[i+1]) { // lower->non-lower eow = true } i++ // [w,i) is a word. word := string(runes[w:i]) if !eow && commonInitialisms[word] && !unicode.IsLower(runes[i]) { // through // split IDFoo → ID, Foo // but URLs → URLs } else if !eow { if commonInitialisms[word] { hasCommonInitial = true } continue } matchCommonInitial := false if commonInitialisms[strings.ToUpper(word)] { hasCommonInitial = true matchCommonInitial = true } f(&wordInfo{ Word: word, MatchCommonInitial: matchCommonInitial, HasCommonInitial: hasCommonInitial, }) hasCommonInitial = false w = i } } var keywords = []string{ "break", "default", "func", "interface", "select", "case", "defer", "go", "map", "struct", "chan", "else", "goto", "package", "switch", "const", "fallthrough", "if", "range", "type", "continue", "for", "import", "return", "var", "_", } // sanitizeKeywords prevents collisions with go keywords for arguments to resolver functions func sanitizeKeywords(name string) string { for _, k := range keywords { if name == k { return name + "Arg" } } return name } // commonInitialisms is a set of common initialisms. // Only add entries that are highly unlikely to be non-initialisms. 
// For instance, "ID" is fine (Freudian code is rare), but "AND" is not. var commonInitialisms = map[string]bool{ "ACL": true, "API": true, "ASCII": true, "CPU": true, "CSS": true, "DNS": true, "EOF": true, "GUID": true, "HTML": true, "HTTP": true, "HTTPS": true, "ID": true, "IP": true, "JSON": true, "LHS": true, "QPS": true, "RAM": true, "RHS": true, "RPC": true, "SLA": true, "SMTP": true, "SQL": true, "SSH": true, "TCP": true, "TLS": true, "TTL": true, "UDP": true, "UI": true, "UID": true, "UUID": true, "URI": true, "URL": true, "UTF8": true, "VM": true, "XML": true, "XMPP": true, "XSRF": true, "XSS": true, } func rawQuote(s string) string { return "`" + strings.Replace(s, "`", "`+\"`\"+`", -1) + "`" } func notNil(field string, data interface{}) bool { v := reflect.ValueOf(data) if v.Kind() == reflect.Ptr { v = v.Elem() } if v.Kind() != reflect.Struct { return false } val := v.FieldByName(field) return val.IsValid() && !val.IsNil() } func Dump(val interface{}) string { switch val := val.(type) { case int: return strconv.Itoa(val) case int64: return fmt.Sprintf("%d", val) case float64: return fmt.Sprintf("%f", val) case string: return strconv.Quote(val) case bool: return strconv.FormatBool(val) case nil: return "nil" case []interface{}: var parts []string for _, part := range val { parts = append(parts, Dump(part)) } return "[]interface{}{" + strings.Join(parts, ",") + "}" case map[string]interface{}: buf := bytes.Buffer{} buf.WriteString("map[string]interface{}{") var keys []string for key := range val { keys = append(keys, key) } sort.Strings(keys) for _, key := range keys { data := val[key] buf.WriteString(strconv.Quote(key)) buf.WriteString(":") buf.WriteString(Dump(data)) buf.WriteString(",") } buf.WriteString("}") return buf.String() default: panic(fmt.Errorf("unsupported type %T", val)) } } func prefixLines(prefix, s string) string { return prefix + strings.Replace(s, "\n", "\n"+prefix, -1) } func resolveName(name string, skip int) string { if name[0] == 
'.' { // load path relative to calling source file _, callerFile, _, _ := runtime.Caller(skip + 1) return filepath.Join(filepath.Dir(callerFile), name[1:]) } // load path relative to this directory _, callerFile, _, _ := runtime.Caller(0) return filepath.Join(filepath.Dir(callerFile), name) } func render(filename string, tpldata interface{}) (*bytes.Buffer, error) { t := template.New("").Funcs(Funcs()) b, err := ioutil.ReadFile(filename) if err != nil { return nil, err } t, err = t.New(filepath.Base(filename)).Parse(string(b)) if err != nil { panic(err) } buf := &bytes.Buffer{} return buf, t.Execute(buf, tpldata) } func write(filename string, b []byte) error { err := os.MkdirAll(filepath.Dir(filename), 0755) if err != nil { return errors.Wrap(err, "failed to create directory") } formatted, err := imports.Prune(filename, b) if err != nil { fmt.Fprintf(os.Stderr, "gofmt failed on %s: %s\n", filepath.Base(filename), err.Error()) formatted = b } err = ioutil.WriteFile(filename, formatted, 0644) if err != nil { return errors.Wrapf(err, "failed to write %s", filename) } return nil }