package complexity import ( "github.com/99designs/gqlgen/graphql" "github.com/vektah/gqlparser/ast" ) func Calculate(es graphql.ExecutableSchema, op *ast.OperationDefinition, vars map[string]interface{}) int { walker := complexityWalker{ es: es, schema: es.Schema(), vars: vars, } return walker.selectionSetComplexity(op.SelectionSet) } type complexityWalker struct { es graphql.ExecutableSchema schema *ast.Schema vars map[string]interface{} } func (cw complexityWalker) selectionSetComplexity(selectionSet ast.SelectionSet) int { var complexity int for _, selection := range selectionSet { switch s := selection.(type) { case *ast.Field: fieldDefinition := cw.schema.Types[s.Definition.Type.Name()] var childComplexity int switch fieldDefinition.Kind { case ast.Object, ast.Interface, ast.Union: childComplexity = cw.selectionSetComplexity(s.SelectionSet) } args := s.ArgumentMap(cw.vars) var fieldComplexity int if s.ObjectDefinition.Kind == ast.Interface { fieldComplexity = cw.interfaceFieldComplexity(s.ObjectDefinition, s.Name, childComplexity, args) } else { fieldComplexity = cw.fieldComplexity(s.ObjectDefinition.Name, s.Name, childComplexity, args) } complexity = safeAdd(complexity, fieldComplexity) case *ast.FragmentSpread: complexity = safeAdd(complexity, cw.selectionSetComplexity(s.Definition.SelectionSet)) case *ast.InlineFragment: complexity = safeAdd(complexity, cw.selectionSetComplexity(s.SelectionSet)) } } return complexity } func (cw complexityWalker) interfaceFieldComplexity(def *ast.Definition, field string, childComplexity int, args map[string]interface{}) int { // Interfaces don't have their own separate field costs, so they have to assume the worst case. // We iterate over all implementors and choose the most expensive one. maxComplexity := 0 implementors := cw.schema.GetPossibleTypes(def) for _, t := range implementors { fieldComplexity := cw.fieldComplexity(t.Name, field, childComplexity, args) if fieldComplexity > maxComplexity { maxComplexity = fieldComplexity } } return maxComplexity } func (cw complexityWalker) fieldComplexity(object, field string, childComplexity int, args map[string]interface{}) int { if customComplexity, ok := cw.es.Complexity(object, field, childComplexity, args); ok && customComplexity >= childComplexity { return customComplexity } // default complexity calculation return safeAdd(1, childComplexity) } const maxInt = int(^uint(0) >> 1) // safeAdd is a saturating add of a and b that ignores negative operands. // If a + b would overflow through normal Go addition, // it returns the maximum integer value instead. // // Adding complexities with this function prevents attackers from intentionally // overflowing the complexity calculation to allow overly-complex queries. // // It also helps mitigate the impact of custom complexities that accidentally // return negative values. func safeAdd(a, b int) int { // Ignore negative operands. if a < 0 { if b < 0 { return 1 } return b } else if b < 0 { return a } c := a + b if c < a { // Set c to maximum integer instead of overflowing. c = maxInt } return c }