package validator
import (
"context"
"fmt"
"github.com/vektah/gqlparser/ast"
)
// Events holds the observer callbacks registered by validation rules.
//
// During a walk, each callback in a list is invoked in registration order,
// after the walker has finished processing the corresponding node and its
// children.
type Events struct {
	// operationVisitor callbacks run once per operation definition.
	operationVisitor []func(walker *Walker, operation *ast.OperationDefinition)
	// field callbacks run once per field selection.
	field []func(walker *Walker, field *ast.Field)
	// fragment callbacks run once per top-level fragment definition.
	fragment []func(walker *Walker, fragment *ast.FragmentDefinition)
	// inlineFragment callbacks run once per inline fragment selection.
	inlineFragment []func(walker *Walker, inlineFragment *ast.InlineFragment)
	// fragmentSpread callbacks run once per fragment spread selection.
	fragmentSpread []func(walker *Walker, fragmentSpread *ast.FragmentSpread)
	// directive callbacks run once per directive, after its arguments.
	directive []func(walker *Walker, directive *ast.Directive)
	// directiveList callbacks run once per directive list (possibly empty),
	// after every directive in the list has been processed.
	directiveList []func(walker *Walker, directives []*ast.Directive)
	// value callbacks run once per value, after its child values.
	value []func(walker *Walker, value *ast.Value)
}
// OnOperation registers f to run for every operation definition, after the
// operation's directives, variable default values and selection set have been
// walked.
func (e *Events) OnOperation(f func(walker *Walker, operation *ast.OperationDefinition)) {
	e.operationVisitor = append(e.operationVisitor, f)
}

// OnField registers f to run for every field selection, after the field's
// arguments, directives and sub-selections have been walked.
func (e *Events) OnField(f func(walker *Walker, field *ast.Field)) {
	e.field = append(e.field, f)
}

// OnFragment registers f to run for every top-level fragment definition, after
// its directives and selection set have been walked.
func (e *Events) OnFragment(f func(walker *Walker, fragment *ast.FragmentDefinition)) {
	e.fragment = append(e.fragment, f)
}

// OnInlineFragment registers f to run for every inline fragment, after its
// directives and selection set have been walked.
func (e *Events) OnInlineFragment(f func(walker *Walker, inlineFragment *ast.InlineFragment)) {
	e.inlineFragment = append(e.inlineFragment, f)
}

// OnFragmentSpread registers f to run for every fragment spread, after the
// spread's directives and (the first time that fragment is reached within the
// current top-level definition) the referenced fragment's selection set have
// been walked.
func (e *Events) OnFragmentSpread(f func(walker *Walker, fragmentSpread *ast.FragmentSpread)) {
	e.fragmentSpread = append(e.fragmentSpread, f)
}

// OnDirective registers f to run for every directive, after its arguments
// have been walked.
func (e *Events) OnDirective(f func(walker *Walker, directive *ast.Directive)) {
	e.directive = append(e.directive, f)
}

// OnDirectiveList registers f to run once per directive list, including empty
// lists, after every directive in it has been processed.
func (e *Events) OnDirectiveList(f func(walker *Walker, directives []*ast.Directive)) {
	e.directiveList = append(e.directiveList, f)
}

// OnValue registers f to run for every value, after that value's children
// (object fields or list elements) have been walked.
func (e *Events) OnValue(f func(walker *Walker, value *ast.Value)) {
	e.value = append(e.value, f)
}
// Walk traverses every operation and then every fragment definition in
// document. It annotates AST nodes with the schema definitions they resolve to
// and invokes the callbacks registered on observers.
//
// A nil observers is treated as an empty set of callbacks: the document is
// still annotated, but no callbacks run.
func Walk(schema *ast.Schema, document *ast.QueryDocument, observers *Events) {
	if observers == nil {
		// Every walk step ranges over observer slices through w.Observers;
		// a nil *Events would make those dereferences panic.
		observers = &Events{}
	}
	w := Walker{
		Observers: observers,
		Schema:    schema,
		Document:  document,
	}
	w.walk()
}
// Walker carries the state of a single walk over a query document. Observer
// callbacks receive it so they can inspect the schema, the document and the
// operation currently being walked.
type Walker struct {
	// Context is not populated by Walk.
	// NOTE(review): presumably set by callers that build a Walker themselves — confirm.
	Context context.Context
	// Observers holds the callbacks invoked during the walk.
	Observers *Events
	// Schema is used to resolve types, fields and directives.
	Schema *ast.Schema
	// Document is the query document being walked.
	Document *ast.QueryDocument
	// validatedFragmentSpreads records fragment names whose selection sets
	// have already been walked via a spread. It is reset before each
	// top-level operation or fragment definition, and it stops recursion
	// through cyclic spreads.
	validatedFragmentSpreads map[string]bool
	// CurrentOperation is the operation being walked, or nil while walking
	// top-level fragment definitions.
	CurrentOperation *ast.OperationDefinition
}
// walk visits every operation and then every fragment definition of the
// document. The fragment-spread guard is reset before each top-level
// definition, so each definition walks any fragment it spreads at most once.
func (w *Walker) walk() {
	ops := w.Document.Operations
	for i := range ops {
		w.validatedFragmentSpreads = map[string]bool{}
		w.walkOperation(ops[i])
	}

	frags := w.Document.Fragments
	for i := range frags {
		w.validatedFragmentSpreads = map[string]bool{}
		w.walkFragment(frags[i])
	}
}
// walkOperation annotates and walks a single operation definition, in this
// order:
//   - it resolves each variable definition (and its default value) to a
//     schema type;
//   - it picks the root type and directive location from the operation kind
//     (an empty kind is treated as a query);
//   - it walks the directives, the variable default values and the selection
//     set;
//   - it notifies the operation observers.
//
// CurrentOperation is set for the duration of the walk and cleared afterwards.
func (w *Walker) walkOperation(operation *ast.OperationDefinition) {
	w.CurrentOperation = operation

	for _, vd := range operation.VariableDefinitions {
		typeDef := w.Schema.Types[vd.Type.Name()]
		vd.Definition = typeDef
		if dv := vd.DefaultValue; dv != nil {
			dv.ExpectedType = vd.Type
			dv.Definition = typeDef
		}
	}

	var (
		root *ast.Definition
		loc  ast.DirectiveLocation
	)
	switch operation.Operation {
	case ast.Mutation:
		root, loc = w.Schema.Mutation, ast.LocationMutation
	case ast.Subscription:
		root, loc = w.Schema.Subscription, ast.LocationSubscription
	case ast.Query, "":
		root, loc = w.Schema.Query, ast.LocationQuery
	}

	w.walkDirectives(root, operation.Directives, loc)

	for _, vd := range operation.VariableDefinitions {
		if vd.DefaultValue == nil {
			continue
		}
		w.walkValue(vd.DefaultValue)
	}

	// root may be nil (e.g. the schema defines no type for this operation
	// kind); walkSelection tolerates a nil parent definition.
	w.walkSelectionSet(root, operation.SelectionSet)

	for _, observe := range w.Observers.operationVisitor {
		observe(w, operation)
	}
	w.CurrentOperation = nil
}
// walkFragment resolves a top-level fragment definition's type condition. It
// walks the fragment's directives and selection set against that type, then
// notifies the fragment observers.
func (w *Walker) walkFragment(it *ast.FragmentDefinition) {
	typeDef := w.Schema.Types[it.TypeCondition]
	it.Definition = typeDef

	w.walkDirectives(typeDef, it.Directives, ast.LocationFragmentDefinition)
	w.walkSelectionSet(typeDef, it.SelectionSet)

	for _, observe := range w.Observers.fragment {
		observe(w, it)
	}
}
// walkDirectives handles each directive in order:
//   - it annotates the directive with its schema definition (nil if unknown),
//     the enclosing parentDef and the location it was applied at;
//   - it walks the directive's arguments;
//   - it notifies the directive observers.
//
// Afterwards the directive-list observers receive the whole slice, even when
// it is empty.
func (w *Walker) walkDirectives(parentDef *ast.Definition, directives []*ast.Directive, location ast.DirectiveLocation) {
	for _, d := range directives {
		schemaDir := w.Schema.Directives[d.Name]
		d.Definition = schemaDir
		d.ParentDefinition = parentDef
		d.Location = location

		for _, a := range d.Arguments {
			// Arguments of unknown directives are still walked, just
			// without an expected type.
			var ad *ast.ArgumentDefinition
			if schemaDir != nil {
				ad = schemaDir.Arguments.ForName(a.Name)
			}
			w.walkArgument(ad, a)
		}

		for _, observe := range w.Observers.directive {
			observe(w, d)
		}
	}

	for _, observe := range w.Observers.directiveList {
		observe(w, directives)
	}
}
// walkValue walks a single value in post-order:
//   - variable references are resolved against the current operation and
//     marked as used;
//   - expected types are propagated into the children of object and list
//     values, which are then walked depth-first;
//   - finally the value observers are notified.
func (w *Walker) walkValue(value *ast.Value) {
	switch value.Kind {
	case ast.Variable:
		// With no current operation (top-level fragments) there is nothing to
		// resolve the variable against.
		if w.CurrentOperation == nil {
			break
		}
		value.VariableDefinition = w.CurrentOperation.VariableDefinitions.ForName(value.Raw)
		if value.VariableDefinition != nil {
			value.VariableDefinition.Used = true
		}

	case ast.ObjectValue:
		for _, child := range value.Children {
			if value.Definition != nil {
				if fieldDef := value.Definition.Fields.ForName(child.Name); fieldDef != nil {
					child.Value.ExpectedType = fieldDef.Type
					child.Value.Definition = w.Schema.Types[fieldDef.Type.Name()]
				}
			}
			w.walkValue(child.Value)
		}

	case ast.ListValue:
		for _, child := range value.Children {
			// List elements inherit the element type and the list's resolved
			// definition.
			if value.ExpectedType != nil && value.ExpectedType.Elem != nil {
				child.Value.ExpectedType = value.ExpectedType.Elem
				child.Value.Definition = value.Definition
			}
			w.walkValue(child.Value)
		}
	}

	for _, observe := range w.Observers.value {
		observe(w, value)
	}
}
// walkArgument walks an argument's value. When argDef is known, it first
// records the argument's declared type and resolved schema type on that
// value. A nil argDef (unknown argument) leaves the value unannotated.
func (w *Walker) walkArgument(argDef *ast.ArgumentDefinition, arg *ast.Argument) {
	v := arg.Value
	if argDef != nil {
		v.ExpectedType = argDef.Type
		v.Definition = w.Schema.Types[argDef.Type.Name()]
	}
	w.walkValue(v)
}
// walkSelectionSet walks each selection in order against parentDef. parentDef
// may be nil when the enclosing type could not be resolved.
func (w *Walker) walkSelectionSet(parentDef *ast.Definition, it ast.SelectionSet) {
	for i := range it {
		w.walkSelection(parentDef, it[i])
	}
}
// walkSelection annotates a single selection with the schema definitions it
// resolves to and walks its arguments, directives and nested selections. It
// then notifies the observers registered for the selection's concrete kind.
// parentDef is the type the selection is made on and may be nil. It panics on
// a selection type it does not recognize.
func (w *Walker) walkSelection(parentDef *ast.Definition, it ast.Selection) {
	switch sel := it.(type) {
	case *ast.Field:
		var fieldDef *ast.FieldDefinition
		switch {
		case sel.Name == "__typename":
			// __typename is not looked up in parentDef.Fields; it gets a
			// synthetic String definition instead.
			fieldDef = &ast.FieldDefinition{
				Name: "__typename",
				Type: ast.NamedType("String", nil),
			}
		case parentDef != nil:
			fieldDef = parentDef.Fields.ForName(sel.Name)
		}
		sel.Definition = fieldDef
		sel.ObjectDefinition = parentDef

		// fieldType is the resolved return type of the field; nested
		// selections and the field's directives are walked against it.
		var fieldType *ast.Definition
		if fieldDef != nil {
			fieldType = w.Schema.Types[fieldDef.Type.Name()]
		}

		for _, arg := range sel.Arguments {
			var argDef *ast.ArgumentDefinition
			if fieldDef != nil {
				argDef = fieldDef.Arguments.ForName(arg.Name)
			}
			w.walkArgument(argDef, arg)
		}

		w.walkDirectives(fieldType, sel.Directives, ast.LocationField)
		w.walkSelectionSet(fieldType, sel.SelectionSet)

		for _, observe := range w.Observers.field {
			observe(w, sel)
		}

	case *ast.InlineFragment:
		sel.ObjectDefinition = parentDef

		// Without a type condition the fragment applies to the enclosing type.
		target := parentDef
		if sel.TypeCondition != "" {
			target = w.Schema.Types[sel.TypeCondition]
		}

		w.walkDirectives(target, sel.Directives, ast.LocationInlineFragment)
		w.walkSelectionSet(target, sel.SelectionSet)

		for _, observe := range w.Observers.inlineFragment {
			observe(w, sel)
		}

	case *ast.FragmentSpread:
		frag := w.Document.Fragments.ForName(sel.Name)
		sel.Definition = frag
		sel.ObjectDefinition = parentDef

		var target *ast.Definition
		if frag != nil {
			target = w.Schema.Types[frag.TypeCondition]
		}

		w.walkDirectives(target, sel.Directives, ast.LocationFragmentSpread)

		// Walk each spread fragment's selection set at most once per
		// top-level definition; this prevents infinite recursion through
		// cyclic fragment spreads.
		if frag != nil && !w.validatedFragmentSpreads[frag.Name] {
			w.validatedFragmentSpreads[frag.Name] = true
			w.walkSelectionSet(target, frag.SelectionSet)
		}

		for _, observe := range w.Observers.fragmentSpread {
			observe(w, sel)
		}

	default:
		panic(fmt.Errorf("unsupported %T", sel))
	}
}