diff options
Diffstat (limited to 'vendor/github.com/99designs/gqlgen/handler')
4 files changed, 359 insertions, 102 deletions
diff --git a/vendor/github.com/99designs/gqlgen/handler/context.go b/vendor/github.com/99designs/gqlgen/handler/context.go new file mode 100644 index 00000000..2992aa3d --- /dev/null +++ b/vendor/github.com/99designs/gqlgen/handler/context.go @@ -0,0 +1,57 @@ +package handler + +import "context" + +type key string + +const ( + initpayload key = "ws_initpayload_context" +) + +// InitPayload is a structure that is parsed from the websocket init message payload. TO use +// request headers for non-websocket, instead wrap the graphql handler in a middleware. +type InitPayload map[string]interface{} + +// GetString safely gets a string value from the payload. It returns an empty string if the +// payload is nil or the value isn't set. +func (payload InitPayload) GetString(key string) string { + if payload == nil { + return "" + } + + if value, ok := payload[key]; ok { + res, _ := value.(string) + return res + } + + return "" +} + +// Authorization is a short hand for getting the Authorization header from the +// payload. +func (payload InitPayload) Authorization() string { + if value := payload.GetString("Authorization"); value != "" { + return value + } + + if value := payload.GetString("authorization"); value != "" { + return value + } + + return "" +} + +func withInitPayload(ctx context.Context, payload InitPayload) context.Context { + return context.WithValue(ctx, initpayload, payload) +} + +// GetInitPayload gets a map of the data sent with the connection_init message, which is used by +// graphql clients as a stand-in for HTTP headers. +func GetInitPayload(ctx context.Context) InitPayload { + payload, ok := ctx.Value(initpayload).(InitPayload) + if !ok { + return nil + } + + return payload +} diff --git a/vendor/github.com/99designs/gqlgen/handler/graphql.go b/vendor/github.com/99designs/gqlgen/handler/graphql.go index 9d222826..eb8880de 100644 --- a/vendor/github.com/99designs/gqlgen/handler/graphql.go +++ b/vendor/github.com/99designs/gqlgen/handler/graphql.go @@ -12,9 +12,9 @@ import ( "github.com/99designs/gqlgen/graphql" "github.com/gorilla/websocket" "github.com/hashicorp/golang-lru" - "github.com/vektah/gqlparser" "github.com/vektah/gqlparser/ast" "github.com/vektah/gqlparser/gqlerror" + "github.com/vektah/gqlparser/parser" "github.com/vektah/gqlparser/validator" ) @@ -25,17 +25,21 @@ type params struct { } type Config struct { - cacheSize int - upgrader websocket.Upgrader - recover graphql.RecoverFunc - errorPresenter graphql.ErrorPresenterFunc - resolverHook graphql.FieldMiddleware - requestHook graphql.RequestMiddleware - complexityLimit int + cacheSize int + upgrader websocket.Upgrader + recover graphql.RecoverFunc + errorPresenter graphql.ErrorPresenterFunc + resolverHook graphql.FieldMiddleware + requestHook graphql.RequestMiddleware + tracer graphql.Tracer + complexityLimit int + disableIntrospection bool } -func (c *Config) newRequestContext(doc *ast.QueryDocument, query string, variables map[string]interface{}) *graphql.RequestContext { +func (c *Config) newRequestContext(es graphql.ExecutableSchema, doc *ast.QueryDocument, op *ast.OperationDefinition, query string, variables map[string]interface{}) *graphql.RequestContext { reqCtx := graphql.NewRequestContext(doc, query, variables) + reqCtx.DisableIntrospection = c.disableIntrospection + if hook := c.recover; hook != nil { reqCtx.Recover = hook } @@ -52,6 +56,18 @@ func (c *Config) newRequestContext(doc *ast.QueryDocument, query string, variabl reqCtx.RequestMiddleware = hook } + if hook := c.tracer; hook != nil { + reqCtx.Tracer = hook + } else { + reqCtx.Tracer = &graphql.NopTracer{} + } + + if c.complexityLimit > 0 { + reqCtx.ComplexityLimit = c.complexityLimit + operationComplexity := complexity.Calculate(es, op, variables) + reqCtx.OperationComplexity = operationComplexity + } + return reqCtx } @@ -78,6 +94,14 @@ func ErrorPresenter(f graphql.ErrorPresenterFunc) Option { } } +// IntrospectionEnabled = false will forbid clients from calling introspection endpoints. Can be useful in prod when you dont +// want clients introspecting the full schema. +func IntrospectionEnabled(enabled bool) Option { + return func(cfg *Config) { + cfg.disableIntrospection = !enabled + } +} + // ComplexityLimit sets a maximum query complexity that is allowed to be executed. // If a query is submitted that exceeds the limit, a 422 status code will be returned. func ComplexityLimit(limit int) Option { @@ -87,7 +111,7 @@ func ComplexityLimit(limit int) Option { } // ResolverMiddleware allows you to define a function that will be called around every resolver, -// useful for tracing and logging. +// useful for logging. func ResolverMiddleware(middleware graphql.FieldMiddleware) Option { return func(cfg *Config) { if cfg.resolverHook == nil { @@ -105,7 +129,7 @@ func ResolverMiddleware(middleware graphql.FieldMiddleware) Option { } // RequestMiddleware allows you to define a function that will be called around the root request, -// after the query has been parsed. This is useful for logging and tracing +// after the query has been parsed. This is useful for logging func RequestMiddleware(middleware graphql.RequestMiddleware) Option { return func(cfg *Config) { if cfg.requestHook == nil { @@ -122,6 +146,93 @@ func RequestMiddleware(middleware graphql.RequestMiddleware) Option { } } +// Tracer allows you to add a request/resolver tracer that will be called around the root request, +// calling resolver. This is useful for tracing +func Tracer(tracer graphql.Tracer) Option { + return func(cfg *Config) { + if cfg.tracer == nil { + cfg.tracer = tracer + + } else { + lastResolve := cfg.tracer + cfg.tracer = &tracerWrapper{ + tracer1: lastResolve, + tracer2: tracer, + } + } + + opt := RequestMiddleware(func(ctx context.Context, next func(ctx context.Context) []byte) []byte { + ctx = tracer.StartOperationExecution(ctx) + resp := next(ctx) + tracer.EndOperationExecution(ctx) + + return resp + }) + opt(cfg) + } +} + +type tracerWrapper struct { + tracer1 graphql.Tracer + tracer2 graphql.Tracer +} + +func (tw *tracerWrapper) StartOperationParsing(ctx context.Context) context.Context { + ctx = tw.tracer1.StartOperationParsing(ctx) + ctx = tw.tracer2.StartOperationParsing(ctx) + return ctx +} + +func (tw *tracerWrapper) EndOperationParsing(ctx context.Context) { + tw.tracer2.EndOperationParsing(ctx) + tw.tracer1.EndOperationParsing(ctx) +} + +func (tw *tracerWrapper) StartOperationValidation(ctx context.Context) context.Context { + ctx = tw.tracer1.StartOperationValidation(ctx) + ctx = tw.tracer2.StartOperationValidation(ctx) + return ctx +} + +func (tw *tracerWrapper) EndOperationValidation(ctx context.Context) { + tw.tracer2.EndOperationValidation(ctx) + tw.tracer1.EndOperationValidation(ctx) +} + +func (tw *tracerWrapper) StartOperationExecution(ctx context.Context) context.Context { + ctx = tw.tracer1.StartOperationExecution(ctx) + ctx = tw.tracer2.StartOperationExecution(ctx) + return ctx +} + +func (tw *tracerWrapper) StartFieldExecution(ctx context.Context, field graphql.CollectedField) context.Context { + ctx = tw.tracer1.StartFieldExecution(ctx, field) + ctx = tw.tracer2.StartFieldExecution(ctx, field) + return ctx +} + +func (tw *tracerWrapper) StartFieldResolverExecution(ctx context.Context, rc *graphql.ResolverContext) context.Context { + ctx = tw.tracer1.StartFieldResolverExecution(ctx, rc) + ctx = tw.tracer2.StartFieldResolverExecution(ctx, rc) + return ctx +} + +func (tw *tracerWrapper) StartFieldChildExecution(ctx context.Context) context.Context { + ctx = tw.tracer1.StartFieldChildExecution(ctx) + ctx = tw.tracer2.StartFieldChildExecution(ctx) + return ctx +} + +func (tw *tracerWrapper) EndFieldExecution(ctx context.Context) { + tw.tracer2.EndFieldExecution(ctx) + tw.tracer1.EndFieldExecution(ctx) +} + +func (tw *tracerWrapper) EndOperationExecution(ctx context.Context) { + tw.tracer2.EndOperationExecution(ctx) + tw.tracer1.EndOperationExecution(ctx) +} + // CacheSize sets the maximum size of the query cache. // If size is less than or equal to 0, the cache is disabled. func CacheSize(size int) Option { @@ -133,7 +244,7 @@ func CacheSize(size int) Option { const DefaultCacheSize = 1000 func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc { - cfg := Config{ + cfg := &Config{ cacheSize: DefaultCacheSize, upgrader: websocket.Upgrader{ ReadBufferSize: 1024, @@ -142,7 +253,7 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc } for _, option := range options { - option(&cfg) + option(cfg) } var cache *lru.Cache @@ -155,112 +266,187 @@ func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc panic("unexpected error creating cache: " + err.Error()) } } + if cfg.tracer == nil { + cfg.tracer = &graphql.NopTracer{} + } - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodOptions { - w.Header().Set("Allow", "OPTIONS, GET, POST") - w.WriteHeader(http.StatusOK) - return - } + handler := &graphqlHandler{ + cfg: cfg, + cache: cache, + exec: exec, + } - if strings.Contains(r.Header.Get("Upgrade"), "websocket") { - connectWs(exec, w, r, &cfg) - return - } + return handler.ServeHTTP +} - var reqParams params - switch r.Method { - case http.MethodGet: - reqParams.Query = r.URL.Query().Get("query") - reqParams.OperationName = r.URL.Query().Get("operationName") - - if variables := r.URL.Query().Get("variables"); variables != "" { - if err := jsonDecode(strings.NewReader(variables), &reqParams.Variables); err != nil { - sendErrorf(w, http.StatusBadRequest, "variables could not be decoded") - return - } - } - case http.MethodPost: - if err := jsonDecode(r.Body, &reqParams); err != nil { - sendErrorf(w, http.StatusBadRequest, "json body could not be decoded: "+err.Error()) +var _ http.Handler = (*graphqlHandler)(nil) + +type graphqlHandler struct { + cfg *Config + cache *lru.Cache + exec graphql.ExecutableSchema +} + +func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodOptions { + w.Header().Set("Allow", "OPTIONS, GET, POST") + w.WriteHeader(http.StatusOK) + return + } + + if strings.Contains(r.Header.Get("Upgrade"), "websocket") { + connectWs(gh.exec, w, r, gh.cfg) + return + } + + var reqParams params + switch r.Method { + case http.MethodGet: + reqParams.Query = r.URL.Query().Get("query") + reqParams.OperationName = r.URL.Query().Get("operationName") + + if variables := r.URL.Query().Get("variables"); variables != "" { + if err := jsonDecode(strings.NewReader(variables), &reqParams.Variables); err != nil { + sendErrorf(w, http.StatusBadRequest, "variables could not be decoded") return } - default: - w.WriteHeader(http.StatusMethodNotAllowed) + } + case http.MethodPost: + if err := jsonDecode(r.Body, &reqParams); err != nil { + sendErrorf(w, http.StatusBadRequest, "json body could not be decoded: "+err.Error()) return } - w.Header().Set("Content-Type", "application/json") + default: + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + w.Header().Set("Content-Type", "application/json") - var doc *ast.QueryDocument - if cache != nil { - val, ok := cache.Get(reqParams.Query) - if ok { - doc = val.(*ast.QueryDocument) - } - } - if doc == nil { - var qErr gqlerror.List - doc, qErr = gqlparser.LoadQuery(exec.Schema(), reqParams.Query) - if len(qErr) > 0 { - sendError(w, http.StatusUnprocessableEntity, qErr...) - return - } - if cache != nil { - cache.Add(reqParams.Query, doc) - } - } + ctx := r.Context() - op := doc.Operations.ForName(reqParams.OperationName) - if op == nil { - sendErrorf(w, http.StatusUnprocessableEntity, "operation %s not found", reqParams.OperationName) - return + var doc *ast.QueryDocument + var cacheHit bool + if gh.cache != nil { + val, ok := gh.cache.Get(reqParams.Query) + if ok { + doc = val.(*ast.QueryDocument) + cacheHit = true } + } - if op.Operation != ast.Query && r.Method == http.MethodGet { - sendErrorf(w, http.StatusUnprocessableEntity, "GET requests only allow query operations") - return + ctx, doc, gqlErr := gh.parseOperation(ctx, &parseOperationArgs{ + Query: reqParams.Query, + CachedDoc: doc, + }) + if gqlErr != nil { + sendError(w, http.StatusUnprocessableEntity, gqlErr) + return + } + + ctx, op, vars, listErr := gh.validateOperation(ctx, &validateOperationArgs{ + Doc: doc, + OperationName: reqParams.OperationName, + CacheHit: cacheHit, + R: r, + Variables: reqParams.Variables, + }) + if len(listErr) != 0 { + sendError(w, http.StatusUnprocessableEntity, listErr...) + return + } + + if gh.cache != nil && !cacheHit { + gh.cache.Add(reqParams.Query, doc) + } + + reqCtx := gh.cfg.newRequestContext(gh.exec, doc, op, reqParams.Query, vars) + ctx = graphql.WithRequestContext(ctx, reqCtx) + + defer func() { + if err := recover(); err != nil { + userErr := reqCtx.Recover(ctx, err) + sendErrorf(w, http.StatusUnprocessableEntity, userErr.Error()) } + }() - vars, err := validator.VariableValues(exec.Schema(), op, reqParams.Variables) + if reqCtx.ComplexityLimit > 0 && reqCtx.OperationComplexity > reqCtx.ComplexityLimit { + sendErrorf(w, http.StatusUnprocessableEntity, "operation has complexity %d, which exceeds the limit of %d", reqCtx.OperationComplexity, reqCtx.ComplexityLimit) + return + } + + switch op.Operation { + case ast.Query: + b, err := json.Marshal(gh.exec.Query(ctx, op)) if err != nil { - sendError(w, http.StatusUnprocessableEntity, err) - return + panic(err) } - reqCtx := cfg.newRequestContext(doc, reqParams.Query, vars) - ctx := graphql.WithRequestContext(r.Context(), reqCtx) + w.Write(b) + case ast.Mutation: + b, err := json.Marshal(gh.exec.Mutation(ctx, op)) + if err != nil { + panic(err) + } + w.Write(b) + default: + sendErrorf(w, http.StatusBadRequest, "unsupported operation type") + } +} - defer func() { - if err := recover(); err != nil { - userErr := reqCtx.Recover(ctx, err) - sendErrorf(w, http.StatusUnprocessableEntity, userErr.Error()) - } - }() +type parseOperationArgs struct { + Query string + CachedDoc *ast.QueryDocument +} - if cfg.complexityLimit > 0 { - queryComplexity := complexity.Calculate(exec, op, vars) - if queryComplexity > cfg.complexityLimit { - sendErrorf(w, http.StatusUnprocessableEntity, "query has complexity %d, which exceeds the limit of %d", queryComplexity, cfg.complexityLimit) - return - } - } +func (gh *graphqlHandler) parseOperation(ctx context.Context, args *parseOperationArgs) (context.Context, *ast.QueryDocument, *gqlerror.Error) { + ctx = gh.cfg.tracer.StartOperationParsing(ctx) + defer func() { gh.cfg.tracer.EndOperationParsing(ctx) }() - switch op.Operation { - case ast.Query: - b, err := json.Marshal(exec.Query(ctx, op)) - if err != nil { - panic(err) - } - w.Write(b) - case ast.Mutation: - b, err := json.Marshal(exec.Mutation(ctx, op)) - if err != nil { - panic(err) - } - w.Write(b) - default: - sendErrorf(w, http.StatusBadRequest, "unsupported operation type") + if args.CachedDoc != nil { + return ctx, args.CachedDoc, nil + } + + doc, gqlErr := parser.ParseQuery(&ast.Source{Input: args.Query}) + if gqlErr != nil { + return ctx, nil, gqlErr + } + + return ctx, doc, nil +} + +type validateOperationArgs struct { + Doc *ast.QueryDocument + OperationName string + CacheHit bool + R *http.Request + Variables map[string]interface{} +} + +func (gh *graphqlHandler) validateOperation(ctx context.Context, args *validateOperationArgs) (context.Context, *ast.OperationDefinition, map[string]interface{}, gqlerror.List) { + ctx = gh.cfg.tracer.StartOperationValidation(ctx) + defer func() { gh.cfg.tracer.EndOperationValidation(ctx) }() + + if !args.CacheHit { + listErr := validator.Validate(gh.exec.Schema(), args.Doc) + if len(listErr) != 0 { + return ctx, nil, nil, listErr } - }) + } + + op := args.Doc.Operations.ForName(args.OperationName) + if op == nil { + return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("operation %s not found", args.OperationName)} + } + + if op.Operation != ast.Query && args.R.Method == http.MethodGet { + return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("GET requests only allow query operations")} + } + + vars, err := validator.VariableValues(gh.exec.Schema(), op, args.Variables) + if err != nil { + return ctx, nil, nil, gqlerror.List{err} + } + + return ctx, op, vars, nil } func jsonDecode(r io.Reader, val interface{}) error { diff --git a/vendor/github.com/99designs/gqlgen/handler/playground.go b/vendor/github.com/99designs/gqlgen/handler/playground.go index d0ada8ca..f1687def 100644 --- a/vendor/github.com/99designs/gqlgen/handler/playground.go +++ b/vendor/github.com/99designs/gqlgen/handler/playground.go @@ -45,7 +45,7 @@ func Playground(title string, endpoint string) http.HandlerFunc { err := page.Execute(w, map[string]string{ "title": title, "endpoint": endpoint, - "version": "1.6.2", + "version": "1.7.8", }) if err != nil { panic(err) diff --git a/vendor/github.com/99designs/gqlgen/handler/websocket.go b/vendor/github.com/99designs/gqlgen/handler/websocket.go index 2be1e87f..dae262bd 100644 --- a/vendor/github.com/99designs/gqlgen/handler/websocket.go +++ b/vendor/github.com/99designs/gqlgen/handler/websocket.go @@ -43,6 +43,8 @@ type wsConnection struct { active map[string]context.CancelFunc mu sync.Mutex cfg *Config + + initPayload InitPayload } func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Request, cfg *Config) { @@ -79,6 +81,14 @@ func (c *wsConnection) init() bool { switch message.Type { case connectionInitMsg: + if len(message.Payload) > 0 { + c.initPayload = make(InitPayload) + err := json.Unmarshal(message.Payload, &c.initPayload) + if err != nil { + return false + } + } + c.write(&operationMessage{Type: connectionAckMsg}) case connectionTerminateMsg: c.close(websocket.CloseNormalClosure, "terminated") @@ -155,9 +165,13 @@ func (c *wsConnection) subscribe(message *operationMessage) bool { c.sendError(message.ID, err) return true } - reqCtx := c.cfg.newRequestContext(doc, reqParams.Query, vars) + reqCtx := c.cfg.newRequestContext(c.exec, doc, op, reqParams.Query, vars) ctx := graphql.WithRequestContext(c.ctx, reqCtx) + if c.initPayload != nil { + ctx = withInitPayload(ctx, c.initPayload) + } + if op.Operation != ast.Subscription { var result *graphql.Response if op.Operation == ast.Query { |