package graphql
import (
"context"
"fmt"
"sync"
"github.com/vektah/gqlgen/neelance/query"
)
type (
	// Resolver is a function that, given a context, produces a result value
	// or an error.
	Resolver func(ctx context.Context) (res interface{}, err error)

	// ResolverMiddleware wraps a Resolver invocation. It receives the
	// context and the next Resolver in the chain and decides whether and how
	// to call it; whatever it returns is used in place of next's result.
	ResolverMiddleware func(ctx context.Context, next Resolver) (res interface{}, err error)

	// RequestMiddleware wraps the handling of a whole request. It receives
	// the context and a next function and returns a byte slice.
	// NOTE(review): the []byte is presumably the encoded response body;
	// confirm against the caller that invokes this hook.
	RequestMiddleware func(ctx context.Context, next func(ctx context.Context) []byte) []byte
)
// RequestContext holds the state shared across one request: the raw query,
// its variables, the parsed document, the configurable hooks, and the errors
// collected so far. NewRequestContext builds one with the default hooks, and
// WithRequestContext / GetRequestContext store it in and read it from a
// context.Context.
type RequestContext struct {
// RawQuery is the query string passed to NewRequestContext.
RawQuery string
// Variables are the variable values passed to NewRequestContext.
Variables map[string]interface{}
// Doc is the document passed to NewRequestContext.
Doc *query.Document
// ErrorPresenter converts every error given to Error() or Errorf() into
// the *Error value that is appended to Errors.
ErrorPresenter ErrorPresenterFunc
// Recover is set to DefaultRecover by NewRequestContext.
// NOTE(review): presumably used to turn a recovered panic into an error;
// confirm against the definition of RecoverFunc.
Recover RecoverFunc
// ResolverMiddleware is set to DefaultResolverMiddleware by
// NewRequestContext, which simply calls the next resolver.
ResolverMiddleware ResolverMiddleware
// RequestMiddleware is set to DefaultRequestMiddleware by
// NewRequestContext, which simply calls the next handler.
RequestMiddleware RequestMiddleware
// errorsMu is held by Error and Errorf while they append to Errors.
errorsMu sync.Mutex
// Errors collects the presented errors recorded through Error and Errorf.
Errors []*Error
}
// DefaultResolverMiddleware is the pass-through ResolverMiddleware installed
// by NewRequestContext. It calls next with the unmodified context and returns
// next's result and error unchanged.
func DefaultResolverMiddleware(ctx context.Context, next Resolver) (res interface{}, err error) {
	res, err = next(ctx)
	return res, err
}
// DefaultRequestMiddleware is the pass-through RequestMiddleware installed by
// NewRequestContext. It calls next with the unmodified context and returns
// the bytes next produced, untouched.
func DefaultRequestMiddleware(ctx context.Context, next func(ctx context.Context) []byte) []byte {
	payload := next(ctx)
	return payload
}
// NewRequestContext returns a RequestContext populated with the given
// document, raw query text and variables, and with every hook set to its
// package default (DefaultResolverMiddleware, DefaultRequestMiddleware,
// DefaultRecover, DefaultErrorPresenter). Callers may overwrite the hooks on
// the returned value.
func NewRequestContext(doc *query.Document, query string, variables map[string]interface{}) *RequestContext {
	rc := new(RequestContext)

	// Request data supplied by the caller.
	rc.Doc = doc
	rc.RawQuery = query
	rc.Variables = variables

	// Default hooks.
	rc.ResolverMiddleware = DefaultResolverMiddleware
	rc.RequestMiddleware = DefaultRequestMiddleware
	rc.Recover = DefaultRecover
	rc.ErrorPresenter = DefaultErrorPresenter

	return rc
}
// key is the unexported type of the context keys used by this package, so
// values stored under them cannot collide with keys from other packages.
type key string

// Context keys under which the *RequestContext and *ResolverContext are
// stored.
const (
	request  = key("request_context")
	resolver = key("resolver_context")
)
// GetRequestContext returns the *RequestContext stored in ctx by
// WithRequestContext, or nil when ctx carries none.
func GetRequestContext(ctx context.Context) *RequestContext {
	if stored := ctx.Value(request); stored != nil {
		return stored.(*RequestContext)
	}
	return nil
}
// WithRequestContext returns a child of ctx that carries rc, retrievable
// later with GetRequestContext.
func WithRequestContext(ctx context.Context, rc *RequestContext) context.Context {
	child := context.WithValue(ctx, request, rc)
	return child
}
// ResolverContext describes the field currently being resolved. It is stored
// in a context.Context by WithResolverContext and read back with
// GetResolverContext.
type ResolverContext struct {
// The name of the type this field belongs to
Object string
// These are the args after processing, they can be mutated in middleware to change what the resolver will get.
Args map[string]interface{}
// The raw field
Field CollectedField
// The path of fields to get to this resolver. Entries are field aliases
// (strings, added by PushField) or list indexes (ints, added by PushIndex).
// WithResolverContext overwrites it with a copy of the parent's path.
Path []interface{}
}
// PushField appends the field alias to the end of the resolver path.
func (r *ResolverContext) PushField(alias string) {
	var segment interface{} = alias
	r.Path = append(r.Path, segment)
}
// PushIndex appends a list index to the end of the resolver path.
func (r *ResolverContext) PushIndex(index int) {
	var segment interface{} = index
	r.Path = append(r.Path, segment)
}
// Pop removes the last segment from the resolver path. It panics when the
// path is empty, so every Pop must be paired with an earlier PushField or
// PushIndex.
func (r *ResolverContext) Pop() {
	last := len(r.Path) - 1
	r.Path = r.Path[:last]
}
// GetResolverContext returns the *ResolverContext stored in ctx by
// WithResolverContext, or nil when ctx carries none.
func GetResolverContext(ctx context.Context) *ResolverContext {
	if stored := ctx.Value(resolver); stored != nil {
		return stored.(*ResolverContext)
	}
	return nil
}
// WithResolverContext returns a child of ctx that carries rc.
//
// rc.Path is overwritten: any value it held is discarded and replaced with a
// fresh copy of the path of the ResolverContext already present in ctx (if
// any), followed by rc.Field.Alias when that alias is non-empty. Because the
// path is copied, later pushes on rc do not touch the parent's path.
func WithResolverContext(ctx context.Context, rc *ResolverContext) context.Context {
	// Clear first, then copy: appending to a nil slice allocates a new
	// backing array instead of sharing the parent's.
	rc.Path = nil
	if parent := GetResolverContext(ctx); parent != nil {
		rc.Path = append(rc.Path, parent.Path...)
	}

	if alias := rc.Field.Alias; alias != "" {
		rc.PushField(alias)
	}

	return context.WithValue(ctx, resolver, rc)
}
// CollectFieldsCtx is a convenience wrapper around CollectFields. It takes
// the document and variables from the RequestContext in ctx and the
// selections from the current ResolverContext's field. Both contexts must be
// present in ctx; a missing one results in a nil-pointer panic.
func CollectFieldsCtx(ctx context.Context, satisfies []string) []CollectedField {
	reqCtx, resCtx := GetRequestContext(ctx), GetResolverContext(ctx)

	doc := reqCtx.Doc
	selections := resCtx.Field.Selections
	return CollectFields(doc, selections, satisfies, reqCtx.Variables)
}
// Errorf formats an error with fmt.Errorf and records it exactly as Error
// does: the error is passed through ErrorPresenter and the result is appended
// to Errors while errorsMu is held.
func (c *RequestContext) Errorf(ctx context.Context, format string, args ...interface{}) {
	c.Error(ctx, fmt.Errorf(format, args...))
}
// Error records err for the client: it is passed through ErrorPresenter and
// the presented value is appended to Errors.
//
// errorsMu is held for the whole call, including the ErrorPresenter
// invocation, so concurrent callers are serialized.
func (c *RequestContext) Error(ctx context.Context, err error) {
	c.errorsMu.Lock()
	defer c.errorsMu.Unlock()

	presented := c.ErrorPresenter(ctx, err)
	c.Errors = append(c.Errors, presented)
}
// AddError is a convenience function that records err on the RequestContext
// carried by ctx. ctx must carry a RequestContext; otherwise the call panics
// on the nil receiver.
func AddError(ctx context.Context, err error) {
	rc := GetRequestContext(ctx)
	rc.Error(ctx, err)
}
// AddErrorf is a convenience function that formats an error and records it on
// the RequestContext carried by ctx. ctx must carry a RequestContext;
// otherwise the call panics on the nil receiver.
func AddErrorf(ctx context.Context, format string, args ...interface{}) {
	rc := GetRequestContext(ctx)
	rc.Errorf(ctx, format, args...)
}