package templates

import (


// CurrentImports records every import declaration required while a plugin is
// being rendered. It lives in a package-level variable only because
// sub-templates are currently invoked from plain functions that have no other
// way to reach it; the goal is to remove this global eventually.
var (
	CurrentImports *Imports
)

// Options specify various parameters to rendering a template.
type Options struct {
	// PackageName is a helper that specifies the package header declaration.
	// In other words, when you write the template you don't need to specify `package X`
	// at the top of the file. By providing PackageName in the Options, the Render
	// function will do that for you.
	PackageName string
	// Template is a string of the entire template that
	// will be parsed and rendered. If it's empty,
	// Render looks for .gotpl files under the directory
	// of the Go source file that called Render.
	Template string
	// Filename is the name of the file that will be
	// written to the system disk once the template is rendered.
	Filename string
	// RegionTags, when true, wraps the output of each root template in
	// "// region" / "// endregion" comment markers.
	RegionTags bool
	// GeneratedHeader, when true, prefixes the output with the standard
	// "Code generated ... DO NOT EDIT." comment.
	GeneratedHeader bool
	// Data will be passed to the template execution.
	Data interface{}
	// Funcs holds extra template functions; entries override the defaults
	// returned by Funcs() when names collide.
	Funcs template.FuncMap
}

// Render renders a gql plugin template from the given Options. Render is an
// abstraction of the text/template package that makes it easier to write gqlgen
// plugins. If Options.Template is empty, the Render function will look for `.gotpl`
// files inside the directory where you wrote the plugin.
func Render(cfg Options) error {
	if CurrentImports != nil {
		panic(fmt.Errorf("recursive or concurrent call to RenderToFile detected"))
	CurrentImports = &Imports{destDir: filepath.Dir(cfg.Filename)}

	// load path relative to calling source file
	_, callerFile, _, _ := runtime.Caller(1)
	rootDir := filepath.Dir(callerFile)

	funcs := Funcs()
	for n, f := range cfg.Funcs {
		funcs[n] = f
	t := template.New("").Funcs(funcs)

	var roots []string
	if cfg.Template != "" {
		var err error
		t, err = t.New("template.gotpl").Parse(cfg.Template)
		if err != nil {
			return errors.Wrap(err, "error with provided template")
		roots = append(roots, "template.gotpl")
	} else {
		// load all the templates in the directory
		err := filepath.Walk(rootDir, func(path string, info os.FileInfo, err error) error {
			if err != nil {
				return err
			name := filepath.ToSlash(strings.TrimPrefix(path, rootDir+string(os.PathSeparator)))
			if !strings.HasSuffix(info.Name(), ".gotpl") {
				return nil
			b, err := ioutil.ReadFile(path)
			if err != nil {
				return err

			t, err = t.New(name).Parse(string(b))
			if err != nil {
				return errors.Wrap(err, cfg.Filename)

			roots = append(roots, name)

			return nil
		if err != nil {
			return errors.Wrap(err, "locating templates")

	// then execute all the important looking ones in order, adding them to the same file
	sort.Slice(roots, func(i, j int) bool {
		// important files go first
		if strings.HasSuffix(roots[i], "!.gotpl") {
			return true
		if strings.HasSuffix(roots[j], "!.gotpl") {
			return false
		return roots[i] < roots[j]
	var buf bytes.Buffer
	for _, root := range roots {
		if cfg.RegionTags {
			buf.WriteString("\n// region    " + center(70, "*", " "+root+" ") + "\n")
		err := t.Lookup(root).Execute(&buf, cfg.Data)
		if err != nil {
			return errors.Wrap(err, root)
		if cfg.RegionTags {
			buf.WriteString("\n// endregion " + center(70, "*", " "+root+" ") + "\n")

	var result bytes.Buffer
	if cfg.GeneratedHeader {
		result.WriteString("// Code generated by github.com/99designs/gqlgen, DO NOT EDIT.\n\n")
	result.WriteString("package ")
	result.WriteString("import (\n")
	_, err := buf.WriteTo(&result)
	if err != nil {
		return err
	CurrentImports = nil

	return write(cfg.Filename, result.Bytes())

// center returns s padded on both sides with pad so that it spans width
// characters; any odd leftover goes on the right. If s plus one pad on each
// side would not fit, s is returned unchanged. Widths are counted in runes so
// non-ASCII names line up correctly. pad is assumed to be a single character.
func center(width int, pad string, s string) string {
	n := len([]rune(s))
	if n+2 > width {
		return s
	}
	lpad := (width - n) / 2
	rpad := width - (lpad + n)
	return strings.Repeat(pad, lpad) + s + strings.Repeat(pad, rpad)
}

func Funcs() template.FuncMap {
	return template.FuncMap{
		"ucFirst":       ucFirst,
		"lcFirst":       lcFirst,
		"quote":         strconv.Quote,
		"rawQuote":      rawQuote,
		"dump":          Dump,
		"ref":           ref,
		"ts":            TypeIdentifier,
		"call":          Call,
		"prefixLines":   prefixLines,
		"notNil":        notNil,
		"reserveImport": CurrentImports.Reserve,
		"lookupImport":  CurrentImports.Lookup,
		"go":            ToGo,
		"goPrivate":     ToGoPrivate,
		"add": func(a, b int) int {
			return a + b
		"render": func(filename string, tpldata interface{}) (*bytes.Buffer, error) {
			return render(resolveName(filename, 0), tpldata)

// ucFirst returns s with its first rune converted to upper case.
// The empty string is returned unchanged.
func ucFirst(s string) string {
	if s == "" {
		return ""
	}
	r := []rune(s)
	r[0] = unicode.ToUpper(r[0])
	return string(r)
}

// lcFirst returns s with its first rune converted to lower case.
// The empty string is returned unchanged.
func lcFirst(s string) string {
	if s == "" {
		return ""
	}

	r := []rune(s)
	r[0] = unicode.ToLower(r[0])
	return string(r)
}

// isDelimiter reports whether c separates words in a name: a hyphen,
// an underscore, or any Unicode white space.
func isDelimiter(c rune) bool {
	return c == '-' || c == '_' || unicode.IsSpace(c)
}

func ref(p types.Type) string {
	return CurrentImports.LookupType(p)

// pkgReplacer rewrites the characters of an import path that are not allowed
// in Go identifiers ('/', '.', '-') into look-alike Unicode letters.
var pkgReplacer = strings.NewReplacer(
	"/", "ᚋ",
	".", "ᚗ",
	"-", "ᚑ",
)

// TypeIdentifier encodes t as a string usable inside a Go identifier.
// Pointers contribute "ᚖ" and slices "ᚕ" before their element type. A named
// type is written as its mangled package path, "ᚐ", then its name; predeclared
// named types such as error have no package and yield just their name. Basic
// types yield their name, and every map or interface yields "map" or
// "interface" respectively. Any other kind of type (chan, func, struct, array,
// ...) panics.
func TypeIdentifier(t types.Type) string {
	var b strings.Builder
	for {
		switch it := t.(type) {
		case *types.Pointer:
			b.WriteString("ᚖ")
			t = it.Elem()
		case *types.Slice:
			b.WriteString("ᚕ")
			t = it.Elem()
		case *types.Named:
			// Predeclared named types (error) have a nil package.
			if pkg := it.Obj().Pkg(); pkg != nil {
				b.WriteString(pkgReplacer.Replace(pkg.Path()))
				b.WriteString("ᚐ")
			}
			b.WriteString(it.Obj().Name())
			return b.String()
		case *types.Basic:
			b.WriteString(it.Name())
			return b.String()
		case *types.Map:
			b.WriteString("map")
			return b.String()
		case *types.Interface:
			b.WriteString("interface")
			return b.String()
		default:
			// Without this label an unhandled type would loop forever.
			panic(fmt.Errorf("unexpected type %T", it))
		}
	}
}

func Call(p *types.Func) string {
	pkg := CurrentImports.Lookup(p.Pkg().Path())

	if pkg != "" {
		pkg += "."

	if p.Type() != nil {
		// make sure the returned type is listed in our imports.

	return pkg + p.Name()

func ToGo(name string) string {
	runes := make([]rune, 0, len(name))

	wordWalker(name, func(info *wordInfo) {
		word := info.Word
		if info.MatchCommonInitial {
			word = strings.ToUpper(word)
		} else if !info.HasCommonInitial {
			if strings.ToUpper(word) == word || strings.ToLower(word) == word {
				// FOO or foo → Foo
				// FOo → FOo
				word = ucFirst(strings.ToLower(word))
		runes = append(runes, []rune(word)...)

	return string(runes)

func ToGoPrivate(name string) string {
	runes := make([]rune, 0, len(name))

	first := true
	wordWalker(name, func(info *wordInfo) {
		word := info.Word
		if first {
			if strings.ToUpper(word) == word || strings.ToLower(word) == word {
				// ID → id, CAMEL → camel
				word = strings.ToLower(info.Word)
			} else {
				// ITicket → iTicket
				word = lcFirst(info.Word)
			first = false
		} else if info.MatchCommonInitial {
			word = strings.ToUpper(word)
		} else if !info.HasCommonInitial {
			word = ucFirst(strings.ToLower(word))
		runes = append(runes, []rune(word)...)

	return sanitizeKeywords(string(runes))

// wordInfo describes one word emitted by wordWalker.
type wordInfo struct {
	// Word is the word's text with surrounding delimiters removed.
	Word string
	// MatchCommonInitial is true when the upper-cased Word is itself a
	// common initialism (e.g. "id" → "ID").
	MatchCommonInitial bool
	// HasCommonInitial is true when a common initialism was seen while
	// scanning this word, including when Word itself matches one.
	HasCommonInitial bool
}

// This function is based on the following code.
// https://github.com/golang/lint/blob/06c8688daad7faa9da5a0c2f163a3d14aac986ca/lint.go#L679
func wordWalker(str string, f func(*wordInfo)) {
	runes := []rune(str)
	w, i := 0, 0 // index of start of word, scan
	hasCommonInitial := false
	for i+1 <= len(runes) {
		eow := false // whether we hit the end of a word
		if i+1 == len(runes) {
			eow = true
		} else if isDelimiter(runes[i+1]) {
			// underscore; shift the remainder forward over any run of underscores
			eow = true
			n := 1
			for i+n+1 < len(runes) && isDelimiter(runes[i+n+1]) {

			// Leave at most one underscore if the underscore is between two digits
			if i+n+1 < len(runes) && unicode.IsDigit(runes[i]) && unicode.IsDigit(runes[i+n+1]) {

			copy(runes[i+1:], runes[i+n+1:])
			runes = runes[:len(runes)-n]
		} else if unicode.IsLower(runes[i]) && !unicode.IsLower(runes[i+1]) {
			// lower->non-lower
			eow = true

		// [w,i) is a word.
		word := string(runes[w:i])
		if !eow && commonInitialisms[word] && !unicode.IsLower(runes[i]) {
			// through
			// split IDFoo → ID, Foo
			// but URLs → URLs
		} else if !eow {
			if commonInitialisms[word] {
				hasCommonInitial = true

		matchCommonInitial := false
		if commonInitialisms[strings.ToUpper(word)] {
			hasCommonInitial = true
			matchCommonInitial = true

			Word:               word,
			MatchCommonInitial: matchCommonInitial,
			HasCommonInitial:   hasCommonInitial,
		hasCommonInitial = false
		w = i

// keywords lists the names that cannot be used as a readable identifier in
// generated code: the 25 Go keywords plus the blank identifier "_", which
// can be declared but never read back.
var keywords = []string{
	"break",
	"case",
	"chan",
	"const",
	"continue",
	"default",
	"defer",
	"else",
	"fallthrough",
	"for",
	"func",
	"go",
	"goto",
	"if",
	"import",
	"interface",
	"map",
	"package",
	"range",
	"return",
	"select",
	"struct",
	"switch",
	"type",
	"var",
	"_",
}

// sanitizeKeywords prevents collisions with go keywords for arguments to resolver functions:
// a name exactly equal to one of keywords gets "Arg" appended; any other name
// is returned unchanged.
func sanitizeKeywords(name string) string {
	for _, k := range keywords {
		if name == k {
			return name + "Arg"
		}
	}
	return name
}

// commonInitialisms is a set of common initialisms. Keys are upper case;
// wordWalker looks words up both as written and upper-cased.
// Only add entries that are highly unlikely to be non-initialisms.
// For instance, "ID" is fine (Freudian code is rare), but "AND" is not.
var commonInitialisms = map[string]bool{
	"ACL":   true,
	"API":   true,
	"ASCII": true,
	"CPU":   true,
	"CSS":   true,
	"DNS":   true,
	"EOF":   true,
	"GUID":  true,
	"HTML":  true,
	"HTTP":  true,
	"HTTPS": true,
	"ID":    true,
	"IP":    true,
	"JSON":  true,
	"LHS":   true,
	"QPS":   true,
	"RAM":   true,
	"RHS":   true,
	"RPC":   true,
	"SLA":   true,
	"SMTP":  true,
	"SQL":   true,
	"SSH":   true,
	"TCP":   true,
	"TLS":   true,
	"TTL":   true,
	"UDP":   true,
	"UI":    true,
	"UID":   true,
	"UUID":  true,
	"URI":   true,
	"URL":   true,
	"UTF8":  true,
	"VM":    true,
	"XML":   true,
	"XMPP":  true,
	"XSRF":  true,
	"XSS":   true,
}

// rawQuote returns a Go expression evaluating to s, written as a raw string
// literal. Backquotes cannot appear inside a raw literal, so each one is
// emitted by closing the literal, concatenating "`", and reopening it.
func rawQuote(s string) string {
	return "`" + strings.Replace(s, "`", "`+\"`\"+`", -1) + "`"
}

// notNil reports whether the struct (or pointer to struct) data has a field
// named field whose value is not nil. It returns false when data is not a
// struct, is a nil pointer, or has no such field. Fields of a kind that cannot
// be nil (int, string, struct, ...) are reported as not nil; the previous
// implementation panicked on them because reflect.Value.IsNil only accepts
// nillable kinds.
func notNil(field string, data interface{}) bool {
	v := reflect.ValueOf(data)

	if v.Kind() == reflect.Ptr {
		v = v.Elem()
	}
	if v.Kind() != reflect.Struct {
		return false
	}

	val := v.FieldByName(field)
	if !val.IsValid() {
		return false
	}

	switch val.Kind() {
	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
		reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
		return !val.IsNil()
	default:
		return true
	}
}

// Dump renders val as a Go expression for embedding in generated source.
// Supported values: int, int64, float64, string, bool, nil, []interface{} and
// map[string]interface{} (recursively). Map entries are written in sorted key
// order so the output is deterministic. Any other type panics.
//
// NOTE(review): float64 uses "%f", which keeps only six decimal places
// (1e-7 → "0.000000"); changing it would alter generated output, so it is
// left as is for now.
func Dump(val interface{}) string {
	switch val := val.(type) {
	case int:
		return strconv.Itoa(val)
	case int64:
		return strconv.FormatInt(val, 10)
	case float64:
		return fmt.Sprintf("%f", val)
	case string:
		return strconv.Quote(val)
	case bool:
		return strconv.FormatBool(val)
	case nil:
		return "nil"
	case []interface{}:
		parts := make([]string, 0, len(val))
		for _, part := range val {
			parts = append(parts, Dump(part))
		}
		return "[]interface{}{" + strings.Join(parts, ",") + "}"
	case map[string]interface{}:
		buf := bytes.Buffer{}
		buf.WriteString("map[string]interface{}{")
		keys := make([]string, 0, len(val))
		for key := range val {
			keys = append(keys, key)
		}
		// map iteration order is random; sort for reproducible output
		sort.Strings(keys)

		for _, key := range keys {
			data := val[key]

			buf.WriteString(strconv.Quote(key))
			buf.WriteString(":")
			buf.WriteString(Dump(data))
			buf.WriteString(",")
		}
		buf.WriteString("}")
		return buf.String()
	default:
		panic(fmt.Errorf("unsupported type %T", val))
	}
}

// prefixLines inserts prefix at the start of s and after every "\n" in s.
func prefixLines(prefix, s string) string {
	return prefix + strings.Replace(s, "\n", "\n"+prefix, -1)
}

// resolveName turns a template name into a file path. A name starting with
// "." is resolved against the directory of the source file skip+1 frames up
// the stack from resolveName (skip=0 means resolveName's direct caller);
// any other name, including "", is resolved against the directory of the file
// containing resolveName. The previous version indexed name[0] and panicked
// on an empty name.
func resolveName(name string, skip int) string {
	if strings.HasPrefix(name, ".") {
		// load path relative to calling source file
		_, callerFile, _, _ := runtime.Caller(skip + 1)
		return filepath.Join(filepath.Dir(callerFile), name[1:])
	}

	// load path relative to this directory
	_, callerFile, _, _ := runtime.Caller(0)
	return filepath.Join(filepath.Dir(callerFile), name)
}

func render(filename string, tpldata interface{}) (*bytes.Buffer, error) {
	t := template.New("").Funcs(Funcs())

	b, err := ioutil.ReadFile(filename)
	if err != nil {
		return nil, err

	t, err = t.New(filepath.Base(filename)).Parse(string(b))
	if err != nil {

	buf := &bytes.Buffer{}
	return buf, t.Execute(buf, tpldata)

func write(filename string, b []byte) error {
	err := os.MkdirAll(filepath.Dir(filename), 0755)
	if err != nil {
		return errors.Wrap(err, "failed to create directory")

	formatted, err := imports.Prune(filename, b)
	if err != nil {
		fmt.Fprintf(os.Stderr, "gofmt failed on %s: %s\n", filepath.Base(filename), err.Error())
		formatted = b

	err = ioutil.WriteFile(filename, formatted, 0644)
	if err != nil {
		return errors.Wrapf(err, "failed to write %s", filename)

	return nil