aboutsummaryrefslogblamecommitdiffstats
path: root/vendor/github.com/99designs/gqlgen/complexity/complexity.go
blob: d5b46bf451a4c4811c65b6825a981a508b1beef4 (plain) (tree)







































































































                                                                                                                                              
package complexity

import (
	"github.com/99designs/gqlgen/graphql"
	"github.com/vektah/gqlparser/ast"
)

// Calculate returns the complexity score of op.
//
// The score is the saturating sum of the costs of op's top-level selections.
// Child selection sets and fragments are scored recursively. Per-field costs
// may be overridden by es.Complexity. vars supplies the variable values used
// when resolving each field's arguments.
func Calculate(es graphql.ExecutableSchema, op *ast.OperationDefinition, vars map[string]interface{}) int {
	var w complexityWalker
	w.es = es
	w.schema = es.Schema()
	w.vars = vars
	return w.selectionSetComplexity(op.SelectionSet)
}

// complexityWalker holds the read-only inputs shared by every step of a
// complexity calculation over one operation.
type complexityWalker struct {
	// es is queried for custom per-field costs via its Complexity method.
	es graphql.ExecutableSchema

	// schema is es.Schema(), used to look up field result types and the
	// possible concrete types of interfaces.
	schema *ast.Schema

	// vars holds the operation's variable values, used to resolve field
	// arguments before they are handed to custom complexity functions.
	vars map[string]interface{}
}

// selectionSetComplexity returns the saturating sum of the costs of every
// selection in selectionSet.
//
// Each kind of selection is scored as follows:
//   - Fields are scored with fieldComplexity. Fields declared on an interface
//     use interfaceFieldComplexity, which takes the worst case over the
//     interface's possible types.
//   - Fragment spreads and inline fragments contribute the complexity of
//     their own selection sets.
//
// Any other selection kind contributes nothing.
func (cw complexityWalker) selectionSetComplexity(selectionSet ast.SelectionSet) int {
	total := 0
	for _, sel := range selectionSet {
		var cost int
		switch node := sel.(type) {
		case *ast.FragmentSpread:
			cost = cw.selectionSetComplexity(node.Definition.SelectionSet)

		case *ast.InlineFragment:
			cost = cw.selectionSetComplexity(node.SelectionSet)

		case *ast.Field:
			// Only composite result types (object, interface, union) have a
			// sub-selection to score; every other kind leaves child at zero.
			var child int
			switch cw.schema.Types[node.Definition.Type.Name()].Kind {
			case ast.Object, ast.Interface, ast.Union:
				child = cw.selectionSetComplexity(node.SelectionSet)
			}

			args := node.ArgumentMap(cw.vars)
			if parent := node.ObjectDefinition; parent.Kind == ast.Interface {
				cost = cw.interfaceFieldComplexity(parent, node.Name, child, args)
			} else {
				cost = cw.fieldComplexity(parent.Name, node.Name, child, args)
			}

		default:
			continue
		}
		total = safeAdd(total, cost)
	}
	return total
}

// interfaceFieldComplexity scores field on the interface def pessimistically.
//
// An interface carries no field costs of its own, so the field is scored
// against every possible concrete type of def and the largest score wins.
// The result is never negative. It is 0 when def has no possible types.
func (cw complexityWalker) interfaceFieldComplexity(def *ast.Definition, field string, childComplexity int, args map[string]interface{}) int {
	worst := 0
	for _, impl := range cw.schema.GetPossibleTypes(def) {
		if c := cw.fieldComplexity(impl.Name, field, childComplexity, args); c > worst {
			worst = c
		}
	}
	return worst
}

// fieldComplexity returns the cost of selecting field on the type named
// object.
//
// A custom value reported by es.Complexity is used only when it is at least
// childComplexity. This prevents a custom function from hiding the cost of
// the field's sub-selection. In every other case the default cost is used:
// 1 for the field itself plus childComplexity, added with saturation.
func (cw complexityWalker) fieldComplexity(object, field string, childComplexity int, args map[string]interface{}) int {
	custom, found := cw.es.Complexity(object, field, childComplexity, args)
	if !found || custom < childComplexity {
		return safeAdd(1, childComplexity)
	}
	return custom
}

// maxInt is the largest value representable by the platform's int type.
const maxInt = int(^uint(0) >> 1)

// safeAdd adds two complexity values, saturating at maxInt instead of
// wrapping around.
//
// Negative operands are treated as absent:
//   - If exactly one operand is negative, the other is returned unchanged.
//   - If both are negative, the result is 1.
//
// Saturation stops attackers from deliberately overflowing the running total
// to sneak an overly complex query under the limit. Discarding negatives
// limits the damage from custom complexity functions that mistakenly return
// negative values.
func safeAdd(a, b int) int {
	switch {
	case a < 0 && b < 0:
		return 1
	case a < 0:
		return b
	case b < 0:
		return a
	case b > maxInt-a:
		// a is non-negative here, so maxInt-a cannot overflow; a+b would.
		return maxInt
	default:
		return a + b
	}
}