package validator
import (
"bytes"
"fmt"
"reflect"
"github.com/vektah/gqlparser/ast"
. "github.com/vektah/gqlparser/validator"
)
func init() {
	AddRule("OverlappingFieldsCanBeMerged", func(observers *Events, addError AddErrFunc) {
		// Algorithm:
		//
		// Conflicts occur when two fields exist in a query which will produce the
		// same response name, but represent differing values, thus creating a
		// conflict. The algorithm below finds all conflicts via making a series of
		// comparisons between fields. In order to compare as few fields as
		// possible, this makes a series of comparisons "within" sets of fields and
		// "between" sets of fields.
		//
		// Given any selection set, a collection produces both a set of fields by
		// also including all inline fragments, as well as a list of fragments
		// referenced by fragment spreads.
		//
		// A) Each selection set represented in the document first compares
		//    "within" its collected set of fields, finding any conflicts between
		//    every pair of overlapping fields.
		//    Note: This is the *only time* that the fields "within" a set are
		//    compared to each other. After this only fields "between" sets are
		//    compared.
		//
		// B) Also, if any fragment is referenced in a selection set, then a
		//    comparison is made "between" the original set of fields and the
		//    referenced fragment.
		//
		// C) Also, if multiple fragments are referenced, then comparisons are made
		//    "between" each referenced fragment.
		//
		// D) When comparing "between" a set of fields and a referenced fragment,
		//    first a comparison is made between each field in the original set of
		//    fields and each field in the referenced set of fields.
		//
		// E) Also, if any fragment is referenced in the referenced selection set,
		//    then a comparison is made "between" the original set of fields and
		//    the referenced fragment (recursively referring to step D).
		//
		// F) When comparing "between" two fragments, first a comparison is made
		//    between each field in the first referenced set of fields and each
		//    field in the second referenced set of fields.
		//
		// G) Also, any fragments referenced by the first must be compared to the
		//    second, and any fragments referenced by the second must be compared
		//    to the first (recursively referring to step F).
		//
		// H) When comparing two fields, if both have selection sets, then a
		//    comparison is made "between" both selection sets, first comparing the
		//    set of fields in the first selection set with the set of fields in
		//    the second.
		//
		// I) Also, if any fragment is referenced in either selection set, then a
		//    comparison is made "between" the other set of fields and the
		//    referenced fragment.
		//
		// J) Also, if two fragments are referenced in both selection sets, then a
		//    comparison is made "between" the two fragments.
		m := &overlappingFieldsCanBeMergedManager{
			comparedFragmentPairs: pairSet{data: make(map[string]map[string]bool)},
		}

		// report points the manager at the current walker, searches selectionSet
		// for conflicts and forwards each one found to addError.
		report := func(walker *Walker, selectionSet ast.SelectionSet) {
			m.walker = walker
			for _, conflict := range m.findConflictsWithinSelectionSet(selectionSet) {
				conflict.addFieldsConflictMessage(addError)
			}
		}

		observers.OnOperation(func(walker *Walker, operation *ast.OperationDefinition) {
			report(walker, operation.SelectionSet)
		})
		observers.OnField(func(walker *Walker, field *ast.Field) {
			if walker.CurrentOperation == nil {
				// When checking both Operation and Fragment, errors are duplicated
				// when processing a FragmentDefinition referenced from an Operation.
				return
			}
			report(walker, field.SelectionSet)
		})
		observers.OnInlineFragment(func(walker *Walker, inlineFragment *ast.InlineFragment) {
			report(walker, inlineFragment.SelectionSet)
		})
		observers.OnFragment(func(walker *Walker, fragment *ast.FragmentDefinition) {
			report(walker, fragment.SelectionSet)
		})
	})
}
// pairSet records which pairs of fragment spreads (keyed by fragment name)
// have already been compared, together with the mutual-exclusivity flag the
// comparison was made under. Every pair is stored in both directions.
type pairSet struct {
	data map[string]map[string]bool
}

// Add records the unordered pair (a, b) with the given exclusivity flag,
// overwriting any flag previously stored for the same pair.
func (ps *pairSet) Add(a *ast.FragmentSpread, b *ast.FragmentSpread, areMutuallyExclusive bool) {
	for _, names := range [2][2]string{{a.Name, b.Name}, {b.Name, a.Name}} {
		from, to := names[0], names[1]
		inner := ps.data[from]
		if inner == nil {
			inner = make(map[string]bool)
			ps.data[from] = inner
		}
		inner[to] = areMutuallyExclusive
	}
}

// Has reports whether (a, b) was already compared in a way that covers a
// comparison under areMutuallyExclusive.
func (ps *pairSet) Has(a *ast.FragmentSpread, b *ast.FragmentSpread, areMutuallyExclusive bool) bool {
	// Indexing a missing outer key yields a nil map, and indexing a nil map is
	// a safe lookup that reports not-found.
	storedExclusive, found := ps.data[a.Name][b.Name]
	if !found {
		return false
	}
	// A comparison made without exclusivity is stricter than one made with
	// it, so it covers both queries; an exclusive one only covers exclusive
	// queries.
	return areMutuallyExclusive || !storedExclusive
}
// sequentialFieldsMap groups fields by response name while remembering the
// order in which each response name was first pushed, so that iteration is
// deterministic (iteration over a plain Go map is not).
type sequentialFieldsMap struct {
	// seq lists response names in first-insertion order.
	seq []string
	// data maps a response name to every field pushed under it.
	data map[string][]*ast.Field
}

// fieldIterateEntry pairs a response name with the fields sharing it.
type fieldIterateEntry struct {
	ResponseName string
	Fields       []*ast.Field
}

// Push appends field under responseName. The first push of a name also
// appends it to seq. data must be initialized before calling.
func (m *sequentialFieldsMap) Push(responseName string, field *ast.Field) {
	if _, seen := m.data[responseName]; !seen {
		m.seq = append(m.seq, responseName)
	}
	m.data[responseName] = append(m.data[responseName], field)
}

// Get returns the fields stored under responseName and whether any exist.
func (m *sequentialFieldsMap) Get(responseName string) ([]*ast.Field, bool) {
	group, present := m.data[responseName]
	return group, present
}

// Iterator returns the field groups in first-insertion order of their names.
func (m *sequentialFieldsMap) Iterator() [][]*ast.Field {
	groups := make([][]*ast.Field, len(m.seq))
	for i, name := range m.seq {
		groups[i] = m.data[name]
	}
	return groups
}

// KeyValueIterator returns (response name, fields) entries in first-insertion
// order.
func (m *sequentialFieldsMap) KeyValueIterator() []*fieldIterateEntry {
	entries := make([]*fieldIterateEntry, len(m.seq))
	for i, name := range m.seq {
		entries[i] = &fieldIterateEntry{
			ResponseName: name,
			Fields:       m.data[name],
		}
	}
	return entries
}
// conflictMessageContainer accumulates the conflicts found by one comparison.
type conflictMessageContainer struct {
	Conflicts []*ConflictMessage
}

// ConflictMessage describes why fields sharing a response name cannot be
// merged. A direct conflict carries its reason in Message; a conflict caused
// by nested sub-fields carries the nested conflicts in SubMessage instead.
type ConflictMessage struct {
	Message      string
	ResponseName string
	// Names is not set anywhere in this file.
	Names      []string
	SubMessage []*ConflictMessage
	// Position is where the error is reported.
	Position *ast.Position
}

// String writes the human-readable reason for the conflict into buf. For
// nested conflicts the reasons of every sub-conflict are joined with " and ".
func (m *ConflictMessage) String(buf *bytes.Buffer) {
	if len(m.SubMessage) == 0 {
		buf.WriteString(m.Message)
		return
	}
	for i, sub := range m.SubMessage {
		if i > 0 {
			buf.WriteString(" and ")
		}
		buf.WriteString(`subfields "`)
		buf.WriteString(sub.ResponseName)
		buf.WriteString(`" conflict because `)
		sub.String(buf)
	}
}

// addFieldsConflictMessage reports the conflict through addError, positioned
// at m.Position.
func (m *ConflictMessage) addFieldsConflictMessage(addError AddErrFunc) {
	var reason bytes.Buffer
	m.String(&reason)

	msg := Message(`Fields "%s" conflict because %s. Use different aliases on the fields to fetch both if this was intentional.`, m.ResponseName, reason.String())
	addError(msg, At(m.Position))
}
// overlappingFieldsCanBeMergedManager holds the state shared by the
// OverlappingFieldsCanBeMerged rule's observers.
type overlappingFieldsCanBeMergedManager struct {
	// walker is the walker of the callback currently being handled; the rule's
	// observers assign it before starting a search.
	walker *Walker

	// comparedFragmentPairs memoizes fragment-vs-fragment comparisons. It is
	// created once with the manager and never reset.
	comparedFragmentPairs pairSet

	// comparedFragments memoizes which fragments were already compared against
	// the current field map. Callers replace it before each new search.
	comparedFragments map[string]bool
}
// findConflictsWithinSelectionSet returns every conflict reachable from
// selectionSet:
//   - (A) between the set's own fields;
//   - (B) between those fields and each spread fragment;
//   - (C) between every pair of spread fragments.
//
// It returns nil for an empty selection set.
func (m *overlappingFieldsCanBeMergedManager) findConflictsWithinSelectionSet(selectionSet ast.SelectionSet) []*ConflictMessage {
	if len(selectionSet) == 0 {
		return nil
	}

	fieldsMap, spreads := getFieldsAndFragmentNames(selectionSet)
	var found conflictMessageContainer

	// (A) Find all conflicts "within" the fields of this selection set.
	// Note: this is the *only place* collectConflictsWithin is called.
	m.collectConflictsWithin(&found, fieldsMap)

	// Start a fresh memo for this selection set's field-vs-fragment comparisons.
	m.comparedFragments = make(map[string]bool)
	for i := range spreads {
		// (B) Compare this selection set's fields with the fields of each
		// spread fragment.
		m.collectConflictsBetweenFieldsAndFragment(&found, false, fieldsMap, spreads[i])

		// (C) Compare this fragment with every later fragment in the list, so
		// each unordered pair of distinct positions is visited exactly once.
		for j := i + 1; j < len(spreads); j++ {
			m.collectConflictsBetweenFragments(&found, false, spreads[i], spreads[j])
		}
	}
	return found.Conflicts
}
// collectConflictsBetweenFieldsAndFragment appends to conflicts every conflict
// between fieldsMap and the fields of fragmentSpread's definition (step D).
// It then recurses into fragments spread inside that definition (step E).
// Each fragment name is visited at most once per m.comparedFragments memo.
func (m *overlappingFieldsCanBeMergedManager) collectConflictsBetweenFieldsAndFragment(conflicts *conflictMessageContainer, areMutuallyExclusive bool, fieldsMap *sequentialFieldsMap, fragmentSpread *ast.FragmentSpread) {
	name := fragmentSpread.Name
	if m.comparedFragments[name] {
		return
	}
	m.comparedFragments[name] = true

	fragment := fragmentSpread.Definition
	if fragment == nil {
		// Unresolved spread: nothing to compare against.
		return
	}

	fragmentFields, nestedSpreads := getFieldsAndFragmentNames(fragment.SelectionSet)

	// Never compare a fragment's fields to themselves.
	if reflect.DeepEqual(fieldsMap, fragmentFields) {
		return
	}

	// (D) Compare the provided fields with the fragment's own fields.
	m.collectConflictsBetween(conflicts, areMutuallyExclusive, fieldsMap, fragmentFields)

	// (E) Compare the provided fields with every fragment the fragment spreads.
	for _, nested := range nestedSpreads {
		m.collectConflictsBetweenFieldsAndFragment(conflicts, areMutuallyExclusive, fieldsMap, nested)
	}
}
// collectConflictsBetweenFragments appends to conflicts every conflict between
// the fields of two fragments (step F). It then compares each fragment with
// the fragments nested in the other (step G). Pairs already recorded in
// m.comparedFragmentPairs are skipped, which also stops recursion on cycles.
func (m *overlappingFieldsCanBeMergedManager) collectConflictsBetweenFragments(conflicts *conflictMessageContainer, areMutuallyExclusive bool, fragmentSpreadA *ast.FragmentSpread, fragmentSpreadB *ast.FragmentSpread) {
	var compare func(left, right *ast.FragmentSpread)
	compare = func(left, right *ast.FragmentSpread) {
		// A fragment is never compared with itself.
		if left.Name == right.Name {
			return
		}
		// Record the pair before recursing so that a cycle of spreads
		// terminates.
		if m.comparedFragmentPairs.Has(left, right, areMutuallyExclusive) {
			return
		}
		m.comparedFragmentPairs.Add(left, right, areMutuallyExclusive)

		if left.Definition == nil || right.Definition == nil {
			return
		}

		leftFields, leftSpreads := getFieldsAndFragmentNames(left.Definition.SelectionSet)
		rightFields, rightSpreads := getFieldsAndFragmentNames(right.Definition.SelectionSet)

		// (F) Compare the two fragments' own fields (nested fragments excluded).
		m.collectConflictsBetween(conflicts, areMutuallyExclusive, leftFields, rightFields)

		// (G) Compare the left fragment with fragments nested in the right one...
		for _, nested := range rightSpreads {
			compare(left, nested)
		}
		// ...and the right fragment with fragments nested in the left one.
		for _, nested := range leftSpreads {
			compare(nested, right)
		}
	}
	compare(fragmentSpreadA, fragmentSpreadB)
}
// findConflictsBetweenSubSelectionSets compares the sub-selections of two
// fields that share a response name. It returns nil when no conflict is
// found. Otherwise it returns a container holding the conflicts from:
//   - (H) field-vs-field comparison;
//   - (I) field-vs-fragment comparison in both directions;
//   - (J) fragment-vs-fragment comparison.
func (m *overlappingFieldsCanBeMergedManager) findConflictsBetweenSubSelectionSets(areMutuallyExclusive bool, selectionSetA ast.SelectionSet, selectionSetB ast.SelectionSet) *conflictMessageContainer {
	// This function is reached recursively (through findConflict) while a
	// caller is still iterating with its own m.comparedFragments memo. The
	// loops below replace that memo, so restore the caller's memo on return.
	// Without this, the caller resumes with a foreign memo and repeats
	// comparisons, producing duplicate errors. With cyclic fragment spreads
	// it can recurse without bound.
	callerComparedFragments := m.comparedFragments
	defer func() { m.comparedFragments = callerComparedFragments }()

	var conflicts conflictMessageContainer

	fieldsMapA, fragmentSpreadsA := getFieldsAndFragmentNames(selectionSetA)
	fieldsMapB, fragmentSpreadsB := getFieldsAndFragmentNames(selectionSetB)

	// (H) First, collect all conflicts between these two collections of fields.
	m.collectConflictsBetween(&conflicts, areMutuallyExclusive, fieldsMapA, fieldsMapB)

	// (I) Then collect conflicts between the first collection of fields and
	// those referenced by each fragment name associated with the second.
	for _, fragmentSpread := range fragmentSpreadsB {
		m.comparedFragments = make(map[string]bool)
		m.collectConflictsBetweenFieldsAndFragment(&conflicts, areMutuallyExclusive, fieldsMapA, fragmentSpread)
	}

	// (I) Then collect conflicts between the second collection of fields and
	// those referenced by each fragment name associated with the first.
	for _, fragmentSpread := range fragmentSpreadsA {
		m.comparedFragments = make(map[string]bool)
		m.collectConflictsBetweenFieldsAndFragment(&conflicts, areMutuallyExclusive, fieldsMapB, fragmentSpread)
	}

	// (J) Also collect conflicts between any fragment names by the first and
	// fragment names by the second. This compares each item in the first set
	// of names to each item in the second set of names.
	for _, fragmentSpreadA := range fragmentSpreadsA {
		for _, fragmentSpreadB := range fragmentSpreadsB {
			m.collectConflictsBetweenFragments(&conflicts, areMutuallyExclusive, fragmentSpreadA, fragmentSpreadB)
		}
	}

	if len(conflicts.Conflicts) == 0 {
		return nil
	}
	return &conflicts
}
// collectConflictsWithin compares every unordered pair of fields that share a
// response name in fieldsMap. It appends each conflict found to conflicts.
// Fields are never treated as mutually exclusive at this level.
func (m *overlappingFieldsCanBeMergedManager) collectConflictsWithin(conflicts *conflictMessageContainer, fieldsMap *sequentialFieldsMap) {
	for _, sameName := range fieldsMap.Iterator() {
		for i := 0; i < len(sameName); i++ {
			for j := i + 1; j < len(sameName); j++ {
				if conflict := m.findConflict(false, sameName[i], sameName[j]); conflict != nil {
					conflicts.Conflicts = append(conflicts.Conflicts, conflict)
				}
			}
		}
	}
}
// collectConflictsBetween compares every field of fieldsMapA with every field
// of fieldsMapB that has the same response name. It appends each conflict
// found to conflicts. Fields within the same map are not compared here.
func (m *overlappingFieldsCanBeMergedManager) collectConflictsBetween(conflicts *conflictMessageContainer, parentFieldsAreMutuallyExclusive bool, fieldsMapA *sequentialFieldsMap, fieldsMapB *sequentialFieldsMap) {
	for _, entry := range fieldsMapA.KeyValueIterator() {
		counterparts, shared := fieldsMapB.Get(entry.ResponseName)
		if !shared {
			continue
		}
		for _, a := range entry.Fields {
			for _, b := range counterparts {
				if conflict := m.findConflict(parentFieldsAreMutuallyExclusive, a, b); conflict != nil {
					conflicts.Conflicts = append(conflicts.Conflicts, conflict)
				}
			}
		}
	}
}
// findConflict decides whether fieldA and fieldB, which share a response name,
// can be merged. It returns nil when they can; otherwise it returns a
// ConflictMessage positioned at fieldB. Fields without a resolved definition
// or parent object definition are skipped.
func (m *overlappingFieldsCanBeMergedManager) findConflict(parentFieldsAreMutuallyExclusive bool, fieldA *ast.Field, fieldB *ast.Field) *ConflictMessage {
	if fieldA.Definition == nil || fieldA.ObjectDefinition == nil {
		return nil
	}
	if fieldB.Definition == nil || fieldB.ObjectDefinition == nil {
		return nil
	}

	// Fields on two different object types can never be selected for the same
	// result object, so for them only the response shape has to agree.
	areMutuallyExclusive := parentFieldsAreMutuallyExclusive ||
		(fieldA.ObjectDefinition.Name != fieldB.ObjectDefinition.Name &&
			fieldA.ObjectDefinition.Kind == ast.Object &&
			fieldB.ObjectDefinition.Kind == ast.Object)

	responseName := fieldA.Alias
	if responseName == "" {
		responseName = fieldA.Name
	}

	// direct builds a conflict that carries its own reason text.
	direct := func(reason string) *ConflictMessage {
		return &ConflictMessage{
			ResponseName: responseName,
			Message:      reason,
			Position:     fieldB.Position,
		}
	}

	if !areMutuallyExclusive {
		switch {
		case fieldA.Name != fieldB.Name:
			// Two aliases must refer to the same field.
			return direct(fmt.Sprintf(`%s and %s are different fields`, fieldA.Name, fieldB.Name))
		case !sameArguments(fieldA.Arguments, fieldB.Arguments):
			// Two field calls must have the same arguments.
			return direct("they have differing arguments")
		}
	}

	typeA, typeB := fieldA.Definition.Type, fieldB.Definition.Type
	if doTypesConflict(m.walker, typeA, typeB) {
		return direct(fmt.Sprintf(`they return conflicting types %s and %s`, typeA.String(), typeB.String()))
	}

	// Finally compare the sub-selections of both fields.
	nested := m.findConflictsBetweenSubSelectionSets(areMutuallyExclusive, fieldA.SelectionSet, fieldB.SelectionSet)
	if nested == nil {
		return nil
	}
	return &ConflictMessage{
		ResponseName: responseName,
		SubMessage:   nested.Conflicts,
		Position:     fieldB.Position,
	}
}
// sameArguments reports whether two argument lists are identical. They must
// have the same length, and every argument in args1 must have an argument in
// args2 with the same name and the same value. Argument order does not matter.
//
// The previous implementation compared each argument in args1 with every
// argument in args2 and failed on the first name mismatch. As a result, any
// two identical fields taking two or more arguments were reported as having
// differing arguments.
func sameArguments(args1 []*ast.Argument, args2 []*ast.Argument) bool {
	if len(args1) != len(args2) {
		return false
	}
	for _, arg1 := range args1 {
		var arg2 *ast.Argument
		for _, candidate := range args2 {
			if candidate.Name == arg1.Name {
				arg2 = candidate
				break
			}
		}
		if arg2 == nil || !sameValue(arg1.Value, arg2.Value) {
			return false
		}
	}
	return true
}
// sameValue reports whether two argument values are syntactically identical.
// They must have the same kind and the same raw text. For list and object
// values, which keep their content in Children rather than in Raw, the
// children must match pairwise in order, by name and recursively by value.
// Two nil values are equal; a nil and a non-nil value are not.
func sameValue(value1 *ast.Value, value2 *ast.Value) bool {
	if value1 == nil || value2 == nil {
		return value1 == value2
	}
	if value1.Kind != value2.Kind {
		return false
	}
	if value1.Raw != value2.Raw {
		return false
	}
	// Without this, every list or object literal compares equal to any other
	// of the same kind, e.g. f(x: [1]) vs f(x: [2]).
	if len(value1.Children) != len(value2.Children) {
		return false
	}
	for i, child1 := range value1.Children {
		child2 := value2.Children[i]
		if child1.Name != child2.Name || !sameValue(child1.Value, child2.Value) {
			return false
		}
	}
	return true
}
// doTypesConflict reports whether two field return types produce
// incompatible response shapes. Types conflict when:
//   - their non-null/list wrapping differs at any level, or
//   - after unwrapping, either named type is a scalar or enum and the two
//     named types differ.
//
// Two composite named types never conflict here; their sub-selections are
// compared separately by the caller.
func doTypesConflict(walker *Walker, type1 *ast.Type, type2 *ast.Type) bool {
	// Nullability applies to the outermost level of a type: for `[Int]!` the
	// list itself is non-null. It must be compared before unwrapping list
	// elements. Otherwise `[Int]!` and `[Int]` would be accepted as the same
	// shape.
	if type1.NonNull != type2.NonNull {
		return true
	}
	if type1.Elem != nil || type2.Elem != nil {
		// A list only merges with another list whose element types merge.
		if type1.Elem == nil || type2.Elem == nil {
			return true
		}
		return doTypesConflict(walker, type1.Elem, type2.Elem)
	}

	t1 := walker.Schema.Types[type1.NamedType]
	t2 := walker.Schema.Types[type2.NamedType]
	if t1 == nil || t2 == nil {
		// The schema lookup failed. Don't guess at a conflict, and don't
		// dereference nil below.
		return false
	}

	// If either side is a leaf (scalar or enum), both sides must be exactly
	// the same type. Checking only "both are leaves" would miss scalar vs
	// object mismatches.
	leaf1 := t1.Kind == ast.Scalar || t1.Kind == ast.Enum
	leaf2 := t2.Kind == ast.Scalar || t2.Kind == ast.Enum
	if leaf1 || leaf2 {
		return t1.Name != t2.Name
	}
	return false
}
// getFieldsAndFragmentNames flattens selectionSet into two results:
//   - the fields it selects, grouped by response name (the alias when
//     present, otherwise the field name);
//   - the named fragment spreads it contains, in document order.
//
// Inline fragments are expanded in place. Named fragment spreads are
// collected but not expanded.
func getFieldsAndFragmentNames(selectionSet ast.SelectionSet) (*sequentialFieldsMap, []*ast.FragmentSpread) {
	fields := &sequentialFieldsMap{data: make(map[string][]*ast.Field)}
	var spreads []*ast.FragmentSpread

	var collect func(set ast.SelectionSet)
	collect = func(set ast.SelectionSet) {
		for _, selection := range set {
			switch node := selection.(type) {
			case *ast.FragmentSpread:
				spreads = append(spreads, node)
			case *ast.InlineFragment:
				collect(node.SelectionSet)
			case *ast.Field:
				name := node.Alias
				if name == "" {
					name = node.Name
				}
				fields.Push(name, node)
			}
		}
	}
	collect(selectionSet)

	return fields, spreads
}