aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/99designs/gqlgen/complexity/complexity.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/99designs/gqlgen/complexity/complexity.go')
-rw-r--r--vendor/github.com/99designs/gqlgen/complexity/complexity.go104
1 files changed, 104 insertions, 0 deletions
diff --git a/vendor/github.com/99designs/gqlgen/complexity/complexity.go b/vendor/github.com/99designs/gqlgen/complexity/complexity.go
new file mode 100644
index 00000000..d5b46bf4
--- /dev/null
+++ b/vendor/github.com/99designs/gqlgen/complexity/complexity.go
@@ -0,0 +1,104 @@
+package complexity
+
+import (
+ "github.com/99designs/gqlgen/graphql"
+ "github.com/vektah/gqlparser/ast"
+)
+
+func Calculate(es graphql.ExecutableSchema, op *ast.OperationDefinition, vars map[string]interface{}) int {
+ walker := complexityWalker{
+ es: es,
+ schema: es.Schema(),
+ vars: vars,
+ }
+ return walker.selectionSetComplexity(op.SelectionSet)
+}
+
+type complexityWalker struct {
+ es graphql.ExecutableSchema
+ schema *ast.Schema
+ vars map[string]interface{}
+}
+
+func (cw complexityWalker) selectionSetComplexity(selectionSet ast.SelectionSet) int {
+ var complexity int
+ for _, selection := range selectionSet {
+ switch s := selection.(type) {
+ case *ast.Field:
+ fieldDefinition := cw.schema.Types[s.Definition.Type.Name()]
+ var childComplexity int
+ switch fieldDefinition.Kind {
+ case ast.Object, ast.Interface, ast.Union:
+ childComplexity = cw.selectionSetComplexity(s.SelectionSet)
+ }
+
+ args := s.ArgumentMap(cw.vars)
+ var fieldComplexity int
+ if s.ObjectDefinition.Kind == ast.Interface {
+ fieldComplexity = cw.interfaceFieldComplexity(s.ObjectDefinition, s.Name, childComplexity, args)
+ } else {
+ fieldComplexity = cw.fieldComplexity(s.ObjectDefinition.Name, s.Name, childComplexity, args)
+ }
+ complexity = safeAdd(complexity, fieldComplexity)
+
+ case *ast.FragmentSpread:
+ complexity = safeAdd(complexity, cw.selectionSetComplexity(s.Definition.SelectionSet))
+
+ case *ast.InlineFragment:
+ complexity = safeAdd(complexity, cw.selectionSetComplexity(s.SelectionSet))
+ }
+ }
+ return complexity
+}
+
+func (cw complexityWalker) interfaceFieldComplexity(def *ast.Definition, field string, childComplexity int, args map[string]interface{}) int {
+ // Interfaces don't have their own separate field costs, so they have to assume the worst case.
+ // We iterate over all implementors and choose the most expensive one.
+ maxComplexity := 0
+ implementors := cw.schema.GetPossibleTypes(def)
+ for _, t := range implementors {
+ fieldComplexity := cw.fieldComplexity(t.Name, field, childComplexity, args)
+ if fieldComplexity > maxComplexity {
+ maxComplexity = fieldComplexity
+ }
+ }
+ return maxComplexity
+}
+
+func (cw complexityWalker) fieldComplexity(object, field string, childComplexity int, args map[string]interface{}) int {
+ if customComplexity, ok := cw.es.Complexity(object, field, childComplexity, args); ok && customComplexity >= childComplexity {
+ return customComplexity
+ }
+ // default complexity calculation
+ return safeAdd(1, childComplexity)
+}
+
+const maxInt = int(^uint(0) >> 1)
+
+// safeAdd is a saturating add of a and b that ignores negative operands.
+// If a + b would overflow through normal Go addition,
+// it returns the maximum integer value instead.
+//
+// Adding complexities with this function prevents attackers from intentionally
+// overflowing the complexity calculation to allow overly-complex queries.
+//
+// It also helps mitigate the impact of custom complexities that accidentally
+// return negative values.
+func safeAdd(a, b int) int {
+ // Ignore negative operands.
+ if a < 0 {
+ if b < 0 {
+ return 1
+ }
+ return b
+ } else if b < 0 {
+ return a
+ }
+
+ c := a + b
+ if c < a {
+ // Set c to maximum integer instead of overflowing.
+ c = maxInt
+ }
+ return c
+}