aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/99designs/gqlgen/graphql/context.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/99designs/gqlgen/graphql/context.go')
-rw-r--r--vendor/github.com/99designs/gqlgen/graphql/context.go69
1 files changed, 62 insertions, 7 deletions
diff --git a/vendor/github.com/99designs/gqlgen/graphql/context.go b/vendor/github.com/99designs/gqlgen/graphql/context.go
index f83fa36f..58d3c741 100644
--- a/vendor/github.com/99designs/gqlgen/graphql/context.go
+++ b/vendor/github.com/99designs/gqlgen/graphql/context.go
@@ -12,6 +12,7 @@ import (
type Resolver func(ctx context.Context) (res interface{}, err error)
type FieldMiddleware func(ctx context.Context, next Resolver) (res interface{}, err error)
type RequestMiddleware func(ctx context.Context, next func(ctx context.Context) []byte) []byte
+type ComplexityLimitFunc func(ctx context.Context) int
type RequestContext struct {
RawQuery string
@@ -71,12 +72,10 @@ const (
)
func GetRequestContext(ctx context.Context) *RequestContext {
- val := ctx.Value(request)
- if val == nil {
- return nil
+ if val, ok := ctx.Value(request).(*RequestContext); ok {
+ return val
}
-
- return val.(*RequestContext)
+ return nil
}
func WithRequestContext(ctx context.Context, rc *RequestContext) context.Context {
@@ -95,6 +94,8 @@ type ResolverContext struct {
Index *int
// The result object of resolver
Result interface{}
+ // IsMethod indicates if the resolver is a method
+ IsMethod bool
}
func (r *ResolverContext) Path() []interface{} {
@@ -117,8 +118,10 @@ func (r *ResolverContext) Path() []interface{} {
}
func GetResolverContext(ctx context.Context) *ResolverContext {
- val, _ := ctx.Value(resolver).(*ResolverContext)
- return val
+ if val, ok := ctx.Value(resolver).(*ResolverContext); ok {
+ return val
+ }
+ return nil
}
func WithResolverContext(ctx context.Context, rc *ResolverContext) context.Context {
@@ -132,6 +135,24 @@ func CollectFieldsCtx(ctx context.Context, satisfies []string) []CollectedField
return CollectFields(ctx, resctx.Field.Selections, satisfies)
}
+// CollectAllFields returns a slice of all GraphQL field names that were selected for the current resolver context.
+// The slice will contain the unique set of all field names requested regardless of fragment type conditions.
+func CollectAllFields(ctx context.Context) []string {
+ resctx := GetResolverContext(ctx)
+ collected := CollectFields(ctx, resctx.Field.Selections, nil)
+ uniq := make([]string, 0, len(collected))
+Next:
+ for _, f := range collected {
+ for _, name := range uniq {
+ if name == f.Name {
+ continue Next
+ }
+ }
+ uniq = append(uniq, f.Name)
+ }
+ return uniq
+}
+
// Errorf sends an error string to the client, passing it through the formatter.
func (c *RequestContext) Errorf(ctx context.Context, format string, args ...interface{}) {
c.errorsMu.Lock()
@@ -217,3 +238,37 @@ func (c *RequestContext) RegisterExtension(key string, value interface{}) error
c.Extensions[key] = value
return nil
}
+
+// ChainFieldMiddleware add chain by FieldMiddleware
+func ChainFieldMiddleware(handleFunc ...FieldMiddleware) FieldMiddleware {
+ n := len(handleFunc)
+
+ if n > 1 {
+ lastI := n - 1
+ return func(ctx context.Context, next Resolver) (interface{}, error) {
+ var (
+ chainHandler Resolver
+ curI int
+ )
+ chainHandler = func(currentCtx context.Context) (interface{}, error) {
+ if curI == lastI {
+ return next(currentCtx)
+ }
+ curI++
+ res, err := handleFunc[curI](currentCtx, chainHandler)
+ curI--
+ return res, err
+
+ }
+ return handleFunc[0](ctx, chainHandler)
+ }
+ }
+
+ if n == 1 {
+ return handleFunc[0]
+ }
+
+ return func(ctx context.Context, next Resolver) (interface{}, error) {
+ return next(ctx)
+ }
+}