aboutsummaryrefslogblamecommitdiffstats
path: root/vendor/github.com/vektah/gqlgen/graphql/context.go
blob: 8f544100162fa51767c4abf1e5af7a3814bd900e (plain) (tree)
















































































































































                                                                                                                       
package graphql

import (
	"context"
	"fmt"
	"sync"

	"github.com/vektah/gqlgen/neelance/query"
)

type (
	// Resolver is a function that, given a context, yields a result value
	// and an error.
	Resolver func(ctx context.Context) (res interface{}, err error)

	// ResolverMiddleware receives the context together with the Resolver
	// that would run next, and returns a result value and an error. It
	// decides whether and how next is invoked.
	ResolverMiddleware func(ctx context.Context, next Resolver) (res interface{}, err error)

	// RequestMiddleware receives the context together with the function
	// that would run next, and returns the response bytes. It decides
	// whether and how next is invoked.
	RequestMiddleware func(ctx context.Context, next func(ctx context.Context) []byte) []byte
)

// RequestContext holds the state shared by one request: the query text, its
// variables and document, the configured hooks, and the errors collected so
// far. It is attached to a context.Context with WithRequestContext and read
// back with GetRequestContext.
//
// It embeds a sync.Mutex by value, so it must be passed by pointer and never
// copied.
type RequestContext struct {
	// RawQuery, Variables and Doc are stored exactly as they were passed to
	// NewRequestContext.
	RawQuery  string
	Variables map[string]interface{}
	Doc       *query.Document
	// ErrorPresenter will be used to generate the error
	// message from errors given to Error().
	ErrorPresenter     ErrorPresenterFunc
	Recover            RecoverFunc
	ResolverMiddleware ResolverMiddleware
	RequestMiddleware  RequestMiddleware

	// errorsMu is held by Error and Errorf while they append to Errors.
	errorsMu sync.Mutex
	Errors   []*Error
}

// DefaultResolverMiddleware is the pass-through ResolverMiddleware: it calls
// next with the same context and returns whatever next returns.
func DefaultResolverMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) {
	res, err = next(ctx)
	return res, err
}

// DefaultRequestMiddleware is the pass-through RequestMiddleware: it calls
// next with the same context and returns the bytes next produced.
func DefaultRequestMiddleware(ctx context.Context, next func(ctx context.Context) []byte) []byte {
	response := next(ctx)
	return response
}

// NewRequestContext returns a RequestContext populated with the given
// document, raw query text and variables, and with the package defaults
// installed for both middlewares, the recover hook and the error presenter.
func NewRequestContext(doc *query.Document, query string, variables map[string]interface{}) *RequestContext {
	rc := new(RequestContext)

	// Caller-supplied request data.
	rc.Doc = doc
	rc.RawQuery = query
	rc.Variables = variables

	// Package defaults; callers may overwrite these fields afterwards.
	rc.ResolverMiddleware = DefaultResolverMiddleware
	rc.RequestMiddleware = DefaultRequestMiddleware
	rc.Recover = DefaultRecover
	rc.ErrorPresenter = DefaultErrorPresenter

	return rc
}

// key is the unexported type of the context keys used by this package, so
// values stored under it cannot collide with plain string keys.
type key string

// request is the context key under which WithRequestContext stores the
// *RequestContext.
const request key = "request_context"

// resolver is the context key under which WithResolverContext stores the
// *ResolverContext.
const resolver key = "resolver_context"

// GetRequestContext returns the *RequestContext stored on ctx by
// WithRequestContext, or nil when ctx carries none.
func GetRequestContext(ctx context.Context) *RequestContext {
	if stored := ctx.Value(request); stored != nil {
		return stored.(*RequestContext)
	}
	return nil
}

// WithRequestContext returns a context derived from ctx that carries rc under
// the package-private request key; GetRequestContext reads it back.
func WithRequestContext(ctx context.Context, rc *RequestContext) context.Context {
	derived := context.WithValue(ctx, request, rc)
	return derived
}

// ResolverContext carries the per-field state made available to a resolver:
// the owning object type, the processed arguments, the collected field and
// the path leading to it. It is attached to a context.Context with
// WithResolverContext and read back with GetResolverContext.
type ResolverContext struct {
	// The name of the type this field belongs to
	Object string
	// These are the args after processing, they can be mutated in middleware to change what the resolver will get.
	Args map[string]interface{}
	// The raw field
	Field CollectedField
	// The path of fields to get to this resolver. Entries are appended by
	// PushField (string aliases) and PushIndex (int indexes) and removed by
	// Pop; WithResolverContext rebuilds it from the parent context's path.
	Path []interface{}
}

// PushField appends the field alias to the end of Path as a string segment.
func (r *ResolverContext) PushField(alias string) {
	var segment interface{} = alias
	r.Path = append(r.Path, segment)
}

// PushIndex appends the list index to the end of Path as an int segment.
func (r *ResolverContext) PushIndex(index int) {
	var segment interface{} = index
	r.Path = append(r.Path, segment)
}

// Pop removes the most recently pushed segment from Path.
//
// Calling Pop when Path is empty (or nil) is a no-op. Without the guard the
// slice expression would be evaluated with an upper bound of -1 and panic.
func (r *ResolverContext) Pop() {
	n := len(r.Path)
	if n == 0 {
		return
	}
	r.Path = r.Path[:n-1]
}

// GetResolverContext returns the *ResolverContext stored on ctx by
// WithResolverContext, or nil when ctx carries none.
func GetResolverContext(ctx context.Context) *ResolverContext {
	if stored := ctx.Value(resolver); stored != nil {
		return stored.(*ResolverContext)
	}
	return nil
}

// WithResolverContext returns a context derived from ctx that carries rc.
//
// rc.Path is overwritten: whatever it held is discarded, the path of the
// ResolverContext already on ctx (if any) is copied in, and rc.Field.Alias is
// appended when it is non-empty. Note that rc itself is mutated.
func WithResolverContext(ctx context.Context, rc *ResolverContext) context.Context {
	inherited := GetResolverContext(ctx)

	// Reset before copying: if rc is itself the context's current resolver
	// context, its old path must not be carried over.
	rc.Path = nil
	if inherited != nil {
		// Copy rather than alias, so later pushes on rc never write into
		// the parent's backing array.
		rc.Path = append([]interface{}(nil), inherited.Path...)
	}

	if alias := rc.Field.Alias; alias != "" {
		rc.Path = append(rc.Path, alias)
	}

	return context.WithValue(ctx, resolver, rc)
}

// CollectFieldsCtx is a convenience wrapper around CollectFields. It takes
// the document and variables from the RequestContext on ctx and the
// selections from the ResolverContext's field on ctx.
//
// ctx is expected to carry both a RequestContext and a ResolverContext; no
// nil check is performed on either.
func CollectFieldsCtx(ctx context.Context, satisfies []string) []CollectedField {
	req, res := GetRequestContext(ctx), GetResolverContext(ctx)

	doc := req.Doc
	selections := res.Field.Selections
	vars := req.Variables

	return CollectFields(doc, selections, satisfies, vars)
}

// Errorf sends an error string to the client, passing it through the
// formatter.
//
// The message is built with fmt.Errorf(format, args...) and then recorded via
// Error, so locking, presentation and appending live in one place. Formatting
// happens before the errors lock is taken, which keeps the critical section
// short.
func (c *RequestContext) Errorf(ctx context.Context, format string, args ...interface{}) {
	c.Error(ctx, fmt.Errorf(format, args...))
}

// Error sends an error to the client, passing it through the formatter.
//
// err is converted with c.ErrorPresenter and the result is appended to
// c.Errors, all while holding the errors lock. Because the presenter runs
// with the lock held and sync.Mutex is not reentrant, a presenter must not
// call Error or Errorf on the same RequestContext.
func (c *RequestContext) Error(ctx context.Context, err error) {
	c.errorsMu.Lock()
	defer c.errorsMu.Unlock()

	presented := c.ErrorPresenter(ctx, err)
	c.Errors = append(c.Errors, presented)
}

// AddError is a convenience function that records err on the RequestContext
// carried by ctx, by calling its Error method.
//
// ctx is expected to carry a RequestContext (see WithRequestContext); no nil
// check is performed.
func AddError(ctx context.Context, err error) {
	rc := GetRequestContext(ctx)
	rc.Error(ctx, err)
}

// AddErrorf is a convenience function that records a formatted error on the
// RequestContext carried by ctx, by calling its Errorf method.
//
// ctx is expected to carry a RequestContext (see WithRequestContext); no nil
// check is performed.
func AddErrorf(ctx context.Context, format string, args ...interface{}) {
	rc := GetRequestContext(ctx)
	rc.Errorf(ctx, format, args...)
}