aboutsummaryrefslogblamecommitdiffstats
path: root/vendor/github.com/vektah/gqlparser/validator/walk.go
blob: 751ba1f117bf8dda255d616e10a4c485385dd500 (plain) (tree)





























































































































































































































































































                                                                                                                         
package validator

import (
	"context"
	"fmt"

	"github.com/vektah/gqlparser/ast"
)

// Events holds the callbacks registered for a walk, grouped by the kind of
// AST node they observe. Callbacks are added through the On* methods and, for
// each node, invoked in the order they were registered.
type Events struct {
	// Called for each operation after its directives, variable default values
	// and selection set have been walked.
	operationVisitor []func(walker *Walker, operation *ast.OperationDefinition)
	// Called for each field after its arguments, directives and sub-selections
	// have been walked.
	field            []func(walker *Walker, field *ast.Field)
	// Called for each fragment definition after its directives and selection
	// set have been walked.
	fragment         []func(walker *Walker, fragment *ast.FragmentDefinition)
	// Called for each inline fragment after its directives and selection set
	// have been walked.
	inlineFragment   []func(walker *Walker, inlineFragment *ast.InlineFragment)
	// Called for each fragment spread after its directives (and, the first time
	// the fragment is reached, the fragment's selection set) have been walked.
	fragmentSpread   []func(walker *Walker, fragmentSpread *ast.FragmentSpread)
	// Called for each directive after its arguments have been walked.
	directive        []func(walker *Walker, directive *ast.Directive)
	// Called once per directive list (which may be empty), after every
	// directive in it has been visited.
	directiveList    []func(walker *Walker, directives []*ast.Directive)
	// Called for each value after any nested values have been walked.
	value            []func(walker *Walker, value *ast.Value)
}

// OnOperation registers f to run for every operation definition once its
// directives, variable default values and selection set have been walked.
func (e *Events) OnOperation(f func(walker *Walker, operation *ast.OperationDefinition)) {
	e.operationVisitor = append(e.operationVisitor, f)
}

// OnField registers f to run for every field once its arguments, directives
// and sub-selections have been walked.
func (e *Events) OnField(f func(walker *Walker, field *ast.Field)) {
	e.field = append(e.field, f)
}

// OnFragment registers f to run for every fragment definition once its
// directives and selection set have been walked.
func (e *Events) OnFragment(f func(walker *Walker, fragment *ast.FragmentDefinition)) {
	e.fragment = append(e.fragment, f)
}

// OnInlineFragment registers f to run for every inline fragment once its
// directives and selection set have been walked.
func (e *Events) OnInlineFragment(f func(walker *Walker, inlineFragment *ast.InlineFragment)) {
	e.inlineFragment = append(e.inlineFragment, f)
}

// OnFragmentSpread registers f to run for every fragment spread once its
// directives (and, on first encounter, the referenced fragment) have been
// walked.
func (e *Events) OnFragmentSpread(f func(walker *Walker, fragmentSpread *ast.FragmentSpread)) {
	e.fragmentSpread = append(e.fragmentSpread, f)
}

// OnDirective registers f to run for every directive once its arguments have
// been walked.
func (e *Events) OnDirective(f func(walker *Walker, directive *ast.Directive)) {
	e.directive = append(e.directive, f)
}

// OnDirectiveList registers f to run once for every list of directives
// (including empty lists) after each directive in it has been visited.
func (e *Events) OnDirectiveList(f func(walker *Walker, directives []*ast.Directive)) {
	e.directiveList = append(e.directiveList, f)
}

// OnValue registers f to run for every value once its nested values have been
// walked.
func (e *Events) OnValue(f func(walker *Walker, value *ast.Value)) {
	e.value = append(e.value, f)
}

// Walk traverses every operation and then every fragment definition in
// document, annotating nodes with the schema definitions they resolve to and
// invoking the callbacks registered on observers as each node is finished.
func Walk(schema *ast.Schema, document *ast.QueryDocument, observers *Events) {
	walker := &Walker{
		Schema:    schema,
		Document:  document,
		Observers: observers,
	}
	walker.walk()
}

// Walker carries the state of one traversal of a query document: the schema
// and document being walked, the registered observers, and bookkeeping that
// the walk methods maintain while descending the tree.
type Walker struct {
	// Context is not populated by Walk. NOTE(review): presumably set by callers
	// that construct a Walker themselves — confirm before relying on it.
	Context   context.Context
	// Observers holds the callbacks notified as nodes are walked.
	Observers *Events
	// Schema is used to resolve type, field, argument and directive definitions.
	Schema    *ast.Schema
	// Document is the query document whose operations and fragments are walked.
	Document  *ast.QueryDocument

	// validatedFragmentSpreads records the names of fragments whose selection
	// sets have already been expanded during the current top-level definition.
	// walk replaces it before each operation and fragment definition; the
	// fragment-spread handling consults it so a fragment is expanded at most
	// once, which also stops infinite recursion through fragment cycles.
	validatedFragmentSpreads map[string]bool
	// CurrentOperation is the operation being walked, or nil while walking
	// standalone fragment definitions. Variable references in values are
	// resolved against its variable definitions.
	CurrentOperation         *ast.OperationDefinition
}

// walk visits every operation of the document and then every fragment
// definition. The record of already-expanded fragment spreads is replaced with
// a fresh, empty one before each top-level definition so each is traversed
// independently.
func (w *Walker) walk() {
	ops := w.Document.Operations
	for i := range ops {
		w.validatedFragmentSpreads = map[string]bool{}
		w.walkOperation(ops[i])
	}

	frags := w.Document.Fragments
	for i := range frags {
		w.validatedFragmentSpreads = map[string]bool{}
		w.walkFragment(frags[i])
	}
}

// walkOperation binds the operation's variables (and their default values) to
// schema types, then walks its directives, the default values and the
// selection set, and finally notifies the operation observers. The root type
// is chosen from the operation kind; an empty kind is treated as a query.
// CurrentOperation points at the operation for the duration of the walk.
func (w *Walker) walkOperation(operation *ast.OperationDefinition) {
	w.CurrentOperation = operation

	for _, variable := range operation.VariableDefinitions {
		typeDef := w.Schema.Types[variable.Type.Name()]
		variable.Definition = typeDef
		if dv := variable.DefaultValue; dv != nil {
			dv.ExpectedType = variable.Type
			dv.Definition = typeDef
		}
	}

	var rootDef *ast.Definition
	var location ast.DirectiveLocation
	switch operation.Operation {
	case ast.Mutation:
		rootDef, location = w.Schema.Mutation, ast.LocationMutation
	case ast.Subscription:
		rootDef, location = w.Schema.Subscription, ast.LocationSubscription
	case ast.Query, "":
		rootDef, location = w.Schema.Query, ast.LocationQuery
	}

	w.walkDirectives(rootDef, operation.Directives, location)

	for _, variable := range operation.VariableDefinitions {
		if variable.DefaultValue == nil {
			continue
		}
		w.walkValue(variable.DefaultValue)
	}

	w.walkSelectionSet(rootDef, operation.SelectionSet)

	for _, notify := range w.Observers.operationVisitor {
		notify(w, operation)
	}
	w.CurrentOperation = nil
}

// walkFragment resolves a fragment definition's type condition against the
// schema, records it on the fragment, walks the fragment's directives and
// selection set against that type, and then notifies the fragment observers.
func (w *Walker) walkFragment(fragment *ast.FragmentDefinition) {
	typeDef := w.Schema.Types[fragment.TypeCondition]
	fragment.Definition = typeDef

	w.walkDirectives(typeDef, fragment.Directives, ast.LocationFragmentDefinition)
	w.walkSelectionSet(typeDef, fragment.SelectionSet)

	for _, notify := range w.Observers.fragment {
		notify(w, fragment)
	}
}

// walkDirectives annotates each directive with its schema definition, the
// definition of the node it is attached to, and the location it appears in.
// It walks each directive's arguments and notifies the per-directive
// observers. After all directives are handled, the directive-list observers
// receive the full list, even when it is empty.
func (w *Walker) walkDirectives(parentDef *ast.Definition, directives []*ast.Directive, location ast.DirectiveLocation) {
	for _, directive := range directives {
		schemaDir := w.Schema.Directives[directive.Name]
		directive.Definition = schemaDir
		directive.ParentDefinition = parentDef
		directive.Location = location

		for _, arg := range directive.Arguments {
			if schemaDir == nil {
				// No schema definition for this directive, so there is no
				// argument definition to bind the value against.
				w.walkArgument(nil, arg)
				continue
			}
			w.walkArgument(schemaDir.Arguments.ForName(arg.Name), arg)
		}

		for _, notify := range w.Observers.directive {
			notify(w, directive)
		}
	}

	for _, notify := range w.Observers.directiveList {
		notify(w, directives)
	}
}

// walkValue annotates value and, recursively, its nested values:
//   - a variable reference is linked to the variable definition of the same
//     name in the current operation (when one is being walked), and that
//     definition is marked as used;
//   - each object field receives the expected type and definition of the
//     same-named field of the object's own definition, when both are known;
//   - each list item inherits the list's element type and definition, when
//     the list's expected type is a list type.
//
// The value observers are notified after all nested values have been walked.
func (w *Walker) walkValue(value *ast.Value) {
	switch value.Kind {
	case ast.Variable:
		if w.CurrentOperation != nil {
			value.VariableDefinition = w.CurrentOperation.VariableDefinitions.ForName(value.Raw)
			if value.VariableDefinition != nil {
				value.VariableDefinition.Used = true
			}
		}

	case ast.ObjectValue:
		for _, child := range value.Children {
			if value.Definition != nil {
				if fieldDef := value.Definition.Fields.ForName(child.Name); fieldDef != nil {
					child.Value.ExpectedType = fieldDef.Type
					child.Value.Definition = w.Schema.Types[fieldDef.Type.Name()]
				}
			}
			w.walkValue(child.Value)
		}

	case ast.ListValue:
		for _, child := range value.Children {
			if listType := value.ExpectedType; listType != nil && listType.Elem != nil {
				child.Value.ExpectedType = listType.Elem
				child.Value.Definition = value.Definition
			}
			w.walkValue(child.Value)
		}
	}

	for _, notify := range w.Observers.value {
		notify(w, value)
	}
}

// walkArgument binds the argument's value to the type declared by argDef (when
// the argument is known to the schema) and then walks the value.
func (w *Walker) walkArgument(argDef *ast.ArgumentDefinition, arg *ast.Argument) {
	val := arg.Value
	if argDef != nil {
		val.ExpectedType = argDef.Type
		val.Definition = w.Schema.Types[argDef.Type.Name()]
	}
	w.walkValue(val)
}

// walkSelectionSet walks each selection of set, in order, against the same
// parent type definition.
func (w *Walker) walkSelectionSet(parentDef *ast.Definition, set ast.SelectionSet) {
	for i := range set {
		w.walkSelection(parentDef, set[i])
	}
}

// walkSelection dispatches a single selection to the walker for its concrete
// kind. parentDef is the schema definition of the type the selection is made
// on; it is nil when that type could not be resolved.
//
// Any ast.Selection implementation other than *ast.Field,
// *ast.InlineFragment and *ast.FragmentSpread causes a panic.
func (w *Walker) walkSelection(parentDef *ast.Definition, it ast.Selection) {
	switch sel := it.(type) {
	case *ast.Field:
		w.walkField(parentDef, sel)
	case *ast.InlineFragment:
		w.walkInlineFragment(parentDef, sel)
	case *ast.FragmentSpread:
		w.walkFragmentSpread(parentDef, sel)
	default:
		panic(fmt.Errorf("unsupported %T", it))
	}
}

// walkField resolves a field against parentDef, walks its arguments,
// directives and sub-selections, and then notifies the field observers.
func (w *Walker) walkField(parentDef *ast.Definition, field *ast.Field) {
	var def *ast.FieldDefinition
	if field.Name == "__typename" {
		// __typename is resolved without consulting parentDef: it always gets a
		// synthetic String-typed definition.
		def = &ast.FieldDefinition{
			Name: "__typename",
			Type: ast.NamedType("String", nil),
		}
	} else if parentDef != nil {
		def = parentDef.Fields.ForName(field.Name)
	}

	field.Definition = def
	field.ObjectDefinition = parentDef

	// The field's result type is the parent for its directives and
	// sub-selections; it stays nil when the field itself is unknown.
	var nextParentDef *ast.Definition
	if def != nil {
		nextParentDef = w.Schema.Types[def.Type.Name()]
	}

	for _, arg := range field.Arguments {
		var argDef *ast.ArgumentDefinition
		if def != nil {
			argDef = def.Arguments.ForName(arg.Name)
		}
		w.walkArgument(argDef, arg)
	}

	w.walkDirectives(nextParentDef, field.Directives, ast.LocationField)
	w.walkSelectionSet(nextParentDef, field.SelectionSet)

	for _, v := range w.Observers.field {
		v(w, field)
	}
}

// walkInlineFragment walks an inline fragment. When it has a type condition,
// that type replaces parentDef for its directives and selections; otherwise
// parentDef is kept. The inline-fragment observers are notified last.
func (w *Walker) walkInlineFragment(parentDef *ast.Definition, frag *ast.InlineFragment) {
	frag.ObjectDefinition = parentDef

	nextParentDef := parentDef
	if frag.TypeCondition != "" {
		nextParentDef = w.Schema.Types[frag.TypeCondition]
	}

	w.walkDirectives(nextParentDef, frag.Directives, ast.LocationInlineFragment)
	w.walkSelectionSet(nextParentDef, frag.SelectionSet)

	for _, v := range w.Observers.inlineFragment {
		v(w, frag)
	}
}

// walkFragmentSpread links a fragment spread to the named fragment definition
// in the document and walks the spread's directives. The first time a fragment
// is reached within the current top-level definition, it also walks that
// fragment's selection set. The fragment-spread observers are notified last.
func (w *Walker) walkFragmentSpread(parentDef *ast.Definition, spread *ast.FragmentSpread) {
	def := w.Document.Fragments.ForName(spread.Name)
	spread.Definition = def
	spread.ObjectDefinition = parentDef

	var nextParentDef *ast.Definition
	if def != nil {
		nextParentDef = w.Schema.Types[def.TypeCondition]
	}

	w.walkDirectives(nextParentDef, spread.Directives, ast.LocationFragmentSpread)

	// Expand each fragment at most once per top-level definition; this
	// prevents infinite recursion when fragments spread each other in a cycle.
	if def != nil && !w.validatedFragmentSpreads[def.Name] {
		w.validatedFragmentSpreads[def.Name] = true
		w.walkSelectionSet(nextParentDef, def.SelectionSet)
	}

	for _, v := range w.Observers.fragmentSpread {
		v(w, spread)
	}
}