package handler import ( "context" "crypto/sha256" "encoding/hex" "encoding/json" "errors" "fmt" "io" "io/ioutil" "mime" "net/http" "os" "strconv" "strings" "time" "github.com/99designs/gqlgen/complexity" "github.com/99designs/gqlgen/graphql" "github.com/gorilla/websocket" lru "github.com/hashicorp/golang-lru" "github.com/vektah/gqlparser/ast" "github.com/vektah/gqlparser/gqlerror" "github.com/vektah/gqlparser/parser" "github.com/vektah/gqlparser/validator" ) type params struct { Query string `json:"query"` OperationName string `json:"operationName"` Variables map[string]interface{} `json:"variables"` Extensions *extensions `json:"extensions"` } type extensions struct { PersistedQuery *persistedQuery `json:"persistedQuery"` } type persistedQuery struct { Sha256 string `json:"sha256Hash"` Version int64 `json:"version"` } const ( errPersistedQueryNotSupported = "PersistedQueryNotSupported" errPersistedQueryNotFound = "PersistedQueryNotFound" ) type PersistedQueryCache interface { Add(ctx context.Context, hash string, query string) Get(ctx context.Context, hash string) (string, bool) } type websocketInitFunc func(ctx context.Context, initPayload InitPayload) error type Config struct { cacheSize int upgrader websocket.Upgrader recover graphql.RecoverFunc errorPresenter graphql.ErrorPresenterFunc resolverHook graphql.FieldMiddleware requestHook graphql.RequestMiddleware tracer graphql.Tracer complexityLimit int complexityLimitFunc graphql.ComplexityLimitFunc websocketInitFunc websocketInitFunc disableIntrospection bool connectionKeepAlivePingInterval time.Duration uploadMaxMemory int64 uploadMaxSize int64 apqCache PersistedQueryCache } func (c *Config) newRequestContext(es graphql.ExecutableSchema, doc *ast.QueryDocument, op *ast.OperationDefinition, query string, variables map[string]interface{}) *graphql.RequestContext { reqCtx := graphql.NewRequestContext(doc, query, variables) reqCtx.DisableIntrospection = c.disableIntrospection if hook := c.recover; hook != nil { reqCtx.Recover = hook } if hook := c.errorPresenter; hook != nil { reqCtx.ErrorPresenter = hook } if hook := c.resolverHook; hook != nil { reqCtx.ResolverMiddleware = hook } if hook := c.requestHook; hook != nil { reqCtx.RequestMiddleware = hook } if hook := c.tracer; hook != nil { reqCtx.Tracer = hook } if c.complexityLimit > 0 || c.complexityLimitFunc != nil { reqCtx.ComplexityLimit = c.complexityLimit operationComplexity := complexity.Calculate(es, op, variables) reqCtx.OperationComplexity = operationComplexity } return reqCtx } type Option func(cfg *Config) func WebsocketUpgrader(upgrader websocket.Upgrader) Option { return func(cfg *Config) { cfg.upgrader = upgrader } } func RecoverFunc(recover graphql.RecoverFunc) Option { return func(cfg *Config) { cfg.recover = recover } } // ErrorPresenter transforms errors found while resolving into errors that will be returned to the user. It provides // a good place to add any extra fields, like error.type, that might be desired by your frontend. Check the default // implementation in graphql.DefaultErrorPresenter for an example. func ErrorPresenter(f graphql.ErrorPresenterFunc) Option { return func(cfg *Config) { cfg.errorPresenter = f } } // IntrospectionEnabled = false will forbid clients from calling introspection endpoints. Can be useful in prod when you dont // want clients introspecting the full schema. func IntrospectionEnabled(enabled bool) Option { return func(cfg *Config) { cfg.disableIntrospection = !enabled } } // ComplexityLimit sets a maximum query complexity that is allowed to be executed. // If a query is submitted that exceeds the limit, a 422 status code will be returned. func ComplexityLimit(limit int) Option { return func(cfg *Config) { cfg.complexityLimit = limit } } // ComplexityLimitFunc allows you to define a function to dynamically set the maximum query complexity that is allowed // to be executed. // If a query is submitted that exceeds the limit, a 422 status code will be returned. func ComplexityLimitFunc(complexityLimitFunc graphql.ComplexityLimitFunc) Option { return func(cfg *Config) { cfg.complexityLimitFunc = complexityLimitFunc } } // ResolverMiddleware allows you to define a function that will be called around every resolver, // useful for logging. func ResolverMiddleware(middleware graphql.FieldMiddleware) Option { return func(cfg *Config) { if cfg.resolverHook == nil { cfg.resolverHook = middleware return } lastResolve := cfg.resolverHook cfg.resolverHook = func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { return lastResolve(ctx, func(ctx context.Context) (res interface{}, err error) { return middleware(ctx, next) }) } } } // RequestMiddleware allows you to define a function that will be called around the root request, // after the query has been parsed. This is useful for logging func RequestMiddleware(middleware graphql.RequestMiddleware) Option { return func(cfg *Config) { if cfg.requestHook == nil { cfg.requestHook = middleware return } lastResolve := cfg.requestHook cfg.requestHook = func(ctx context.Context, next func(ctx context.Context) []byte) []byte { return lastResolve(ctx, func(ctx context.Context) []byte { return middleware(ctx, next) }) } } } // Tracer allows you to add a request/resolver tracer that will be called around the root request, // calling resolver. This is useful for tracing func Tracer(tracer graphql.Tracer) Option { return func(cfg *Config) { if cfg.tracer == nil { cfg.tracer = tracer } else { lastResolve := cfg.tracer cfg.tracer = &tracerWrapper{ tracer1: lastResolve, tracer2: tracer, } } opt := RequestMiddleware(func(ctx context.Context, next func(ctx context.Context) []byte) []byte { ctx = tracer.StartOperationExecution(ctx) resp := next(ctx) tracer.EndOperationExecution(ctx) return resp }) opt(cfg) } } type tracerWrapper struct { tracer1 graphql.Tracer tracer2 graphql.Tracer } func (tw *tracerWrapper) StartOperationParsing(ctx context.Context) context.Context { ctx = tw.tracer1.StartOperationParsing(ctx) ctx = tw.tracer2.StartOperationParsing(ctx) return ctx } func (tw *tracerWrapper) EndOperationParsing(ctx context.Context) { tw.tracer2.EndOperationParsing(ctx) tw.tracer1.EndOperationParsing(ctx) } func (tw *tracerWrapper) StartOperationValidation(ctx context.Context) context.Context { ctx = tw.tracer1.StartOperationValidation(ctx) ctx = tw.tracer2.StartOperationValidation(ctx) return ctx } func (tw *tracerWrapper) EndOperationValidation(ctx context.Context) { tw.tracer2.EndOperationValidation(ctx) tw.tracer1.EndOperationValidation(ctx) } func (tw *tracerWrapper) StartOperationExecution(ctx context.Context) context.Context { ctx = tw.tracer1.StartOperationExecution(ctx) ctx = tw.tracer2.StartOperationExecution(ctx) return ctx } func (tw *tracerWrapper) StartFieldExecution(ctx context.Context, field graphql.CollectedField) context.Context { ctx = tw.tracer1.StartFieldExecution(ctx, field) ctx = tw.tracer2.StartFieldExecution(ctx, field) return ctx } func (tw *tracerWrapper) StartFieldResolverExecution(ctx context.Context, rc *graphql.ResolverContext) context.Context { ctx = tw.tracer1.StartFieldResolverExecution(ctx, rc) ctx = tw.tracer2.StartFieldResolverExecution(ctx, rc) return ctx } func (tw *tracerWrapper) StartFieldChildExecution(ctx context.Context) context.Context { ctx = tw.tracer1.StartFieldChildExecution(ctx) ctx = tw.tracer2.StartFieldChildExecution(ctx) return ctx } func (tw *tracerWrapper) EndFieldExecution(ctx context.Context) { tw.tracer2.EndFieldExecution(ctx) tw.tracer1.EndFieldExecution(ctx) } func (tw *tracerWrapper) EndOperationExecution(ctx context.Context) { tw.tracer2.EndOperationExecution(ctx) tw.tracer1.EndOperationExecution(ctx) } // WebsocketInitFunc is called when the server receives connection init message from the client. // This can be used to check initial payload to see whether to accept the websocket connection. func WebsocketInitFunc(websocketInitFunc func(ctx context.Context, initPayload InitPayload) error) Option { return func(cfg *Config) { cfg.websocketInitFunc = websocketInitFunc } } // CacheSize sets the maximum size of the query cache. // If size is less than or equal to 0, the cache is disabled. func CacheSize(size int) Option { return func(cfg *Config) { cfg.cacheSize = size } } // UploadMaxSize sets the maximum number of bytes used to parse a request body // as multipart/form-data. func UploadMaxSize(size int64) Option { return func(cfg *Config) { cfg.uploadMaxSize = size } } // UploadMaxMemory sets the maximum number of bytes used to parse a request body // as multipart/form-data in memory, with the remainder stored on disk in // temporary files. func UploadMaxMemory(size int64) Option { return func(cfg *Config) { cfg.uploadMaxMemory = size } } // WebsocketKeepAliveDuration allows you to reconfigure the keepalive behavior. // By default, keepalive is enabled with a DefaultConnectionKeepAlivePingInterval // duration. Set handler.connectionKeepAlivePingInterval = 0 to disable keepalive // altogether. func WebsocketKeepAliveDuration(duration time.Duration) Option { return func(cfg *Config) { cfg.connectionKeepAlivePingInterval = duration } } // Add cache that will hold queries for automatic persisted queries (APQ) func EnablePersistedQueryCache(cache PersistedQueryCache) Option { return func(cfg *Config) { cfg.apqCache = cache } } const DefaultCacheSize = 1000 const DefaultConnectionKeepAlivePingInterval = 25 * time.Second // DefaultUploadMaxMemory is the maximum number of bytes used to parse a request body // as multipart/form-data in memory, with the remainder stored on disk in // temporary files. const DefaultUploadMaxMemory = 32 << 20 // DefaultUploadMaxSize is maximum number of bytes used to parse a request body // as multipart/form-data. const DefaultUploadMaxSize = 32 << 20 func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc { cfg := &Config{ cacheSize: DefaultCacheSize, uploadMaxMemory: DefaultUploadMaxMemory, uploadMaxSize: DefaultUploadMaxSize, connectionKeepAlivePingInterval: DefaultConnectionKeepAlivePingInterval, upgrader: websocket.Upgrader{ ReadBufferSize: 1024, WriteBufferSize: 1024, }, } for _, option := range options { option(cfg) } var cache *lru.Cache if cfg.cacheSize > 0 { var err error cache, err = lru.New(cfg.cacheSize) if err != nil { // An error is only returned for non-positive cache size // and we already checked for that. panic("unexpected error creating cache: " + err.Error()) } } if cfg.tracer == nil { cfg.tracer = &graphql.NopTracer{} } handler := &graphqlHandler{ cfg: cfg, cache: cache, exec: exec, } return handler.ServeHTTP } var _ http.Handler = (*graphqlHandler)(nil) type graphqlHandler struct { cfg *Config cache *lru.Cache exec graphql.ExecutableSchema } func computeQueryHash(query string) string { b := sha256.Sum256([]byte(query)) return hex.EncodeToString(b[:]) } func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodOptions { w.Header().Set("Allow", "OPTIONS, GET, POST") w.WriteHeader(http.StatusOK) return } if strings.Contains(r.Header.Get("Upgrade"), "websocket") { connectWs(gh.exec, w, r, gh.cfg, gh.cache) return } w.Header().Set("Content-Type", "application/json") var reqParams params switch r.Method { case http.MethodGet: reqParams.Query = r.URL.Query().Get("query") reqParams.OperationName = r.URL.Query().Get("operationName") if variables := r.URL.Query().Get("variables"); variables != "" { if err := jsonDecode(strings.NewReader(variables), &reqParams.Variables); err != nil { sendErrorf(w, http.StatusBadRequest, "variables could not be decoded") return } } if extensions := r.URL.Query().Get("extensions"); extensions != "" { if err := jsonDecode(strings.NewReader(extensions), &reqParams.Extensions); err != nil { sendErrorf(w, http.StatusBadRequest, "extensions could not be decoded") return } } case http.MethodPost: mediaType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type")) if err != nil { sendErrorf(w, http.StatusBadRequest, "error parsing request Content-Type") return } switch mediaType { case "application/json": if err := jsonDecode(r.Body, &reqParams); err != nil { sendErrorf(w, http.StatusBadRequest, "json body could not be decoded: "+err.Error()) return } case "multipart/form-data": var closers []io.Closer var tmpFiles []string defer func() { for i := len(closers) - 1; 0 <= i; i-- { _ = closers[i].Close() } for _, tmpFile := range tmpFiles { _ = os.Remove(tmpFile) } }() if err := processMultipart(w, r, &reqParams, &closers, &tmpFiles, gh.cfg.uploadMaxSize, gh.cfg.uploadMaxMemory); err != nil { sendErrorf(w, http.StatusBadRequest, "multipart body could not be decoded: "+err.Error()) return } default: sendErrorf(w, http.StatusBadRequest, "unsupported Content-Type: "+mediaType) return } default: w.WriteHeader(http.StatusMethodNotAllowed) return } ctx := r.Context() var queryHash string apqRegister := false apq := reqParams.Extensions != nil && reqParams.Extensions.PersistedQuery != nil if apq { // client has enabled apq queryHash = reqParams.Extensions.PersistedQuery.Sha256 if gh.cfg.apqCache == nil { // server has disabled apq sendErrorf(w, http.StatusOK, errPersistedQueryNotSupported) return } if reqParams.Extensions.PersistedQuery.Version != 1 { sendErrorf(w, http.StatusOK, "Unsupported persisted query version") return } if reqParams.Query == "" { // client sent optimistic query hash without query string query, ok := gh.cfg.apqCache.Get(ctx, queryHash) if !ok { sendErrorf(w, http.StatusOK, errPersistedQueryNotFound) return } reqParams.Query = query } else { if computeQueryHash(reqParams.Query) != queryHash { sendErrorf(w, http.StatusOK, "provided sha does not match query") return } apqRegister = true } } else if reqParams.Query == "" { sendErrorf(w, http.StatusUnprocessableEntity, "Must provide query string") return } var doc *ast.QueryDocument var cacheHit bool if gh.cache != nil { val, ok := gh.cache.Get(reqParams.Query) if ok { doc = val.(*ast.QueryDocument) cacheHit = true } } ctx, doc, gqlErr := gh.parseOperation(ctx, &parseOperationArgs{ Query: reqParams.Query, CachedDoc: doc, }) if gqlErr != nil { sendError(w, http.StatusUnprocessableEntity, gqlErr) return } ctx, op, vars, listErr := gh.validateOperation(ctx, &validateOperationArgs{ Doc: doc, OperationName: reqParams.OperationName, CacheHit: cacheHit, R: r, Variables: reqParams.Variables, }) if len(listErr) != 0 { sendError(w, http.StatusUnprocessableEntity, listErr...) return } if gh.cache != nil && !cacheHit { gh.cache.Add(reqParams.Query, doc) } reqCtx := gh.cfg.newRequestContext(gh.exec, doc, op, reqParams.Query, vars) ctx = graphql.WithRequestContext(ctx, reqCtx) defer func() { if err := recover(); err != nil { userErr := reqCtx.Recover(ctx, err) sendErrorf(w, http.StatusUnprocessableEntity, userErr.Error()) } }() if gh.cfg.complexityLimitFunc != nil { reqCtx.ComplexityLimit = gh.cfg.complexityLimitFunc(ctx) } if reqCtx.ComplexityLimit > 0 && reqCtx.OperationComplexity > reqCtx.ComplexityLimit { sendErrorf(w, http.StatusUnprocessableEntity, "operation has complexity %d, which exceeds the limit of %d", reqCtx.OperationComplexity, reqCtx.ComplexityLimit) return } if apqRegister && gh.cfg.apqCache != nil { // Add to persisted query cache gh.cfg.apqCache.Add(ctx, queryHash, reqParams.Query) } switch op.Operation { case ast.Query: b, err := json.Marshal(gh.exec.Query(ctx, op)) if err != nil { panic(err) } w.Write(b) case ast.Mutation: b, err := json.Marshal(gh.exec.Mutation(ctx, op)) if err != nil { panic(err) } w.Write(b) default: sendErrorf(w, http.StatusBadRequest, "unsupported operation type") } } type parseOperationArgs struct { Query string CachedDoc *ast.QueryDocument } func (gh *graphqlHandler) parseOperation(ctx context.Context, args *parseOperationArgs) (context.Context, *ast.QueryDocument, *gqlerror.Error) { ctx = gh.cfg.tracer.StartOperationParsing(ctx) defer func() { gh.cfg.tracer.EndOperationParsing(ctx) }() if args.CachedDoc != nil { return ctx, args.CachedDoc, nil } doc, gqlErr := parser.ParseQuery(&ast.Source{Input: args.Query}) if gqlErr != nil { return ctx, nil, gqlErr } return ctx, doc, nil } type validateOperationArgs struct { Doc *ast.QueryDocument OperationName string CacheHit bool R *http.Request Variables map[string]interface{} } func (gh *graphqlHandler) validateOperation(ctx context.Context, args *validateOperationArgs) (context.Context, *ast.OperationDefinition, map[string]interface{}, gqlerror.List) { ctx = gh.cfg.tracer.StartOperationValidation(ctx) defer func() { gh.cfg.tracer.EndOperationValidation(ctx) }() if !args.CacheHit { listErr := validator.Validate(gh.exec.Schema(), args.Doc) if len(listErr) != 0 { return ctx, nil, nil, listErr } } op := args.Doc.Operations.ForName(args.OperationName) if op == nil { return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("operation %s not found", args.OperationName)} } if op.Operation != ast.Query && args.R.Method == http.MethodGet { return ctx, nil, nil, gqlerror.List{gqlerror.Errorf("GET requests only allow query operations")} } vars, err := validator.VariableValues(gh.exec.Schema(), op, args.Variables) if err != nil { return ctx, nil, nil, gqlerror.List{err} } return ctx, op, vars, nil } func jsonDecode(r io.Reader, val interface{}) error { dec := json.NewDecoder(r) dec.UseNumber() return dec.Decode(val) } func sendError(w http.ResponseWriter, code int, errors ...*gqlerror.Error) { w.WriteHeader(code) b, err := json.Marshal(&graphql.Response{Errors: errors}) if err != nil { panic(err) } w.Write(b) } func sendErrorf(w http.ResponseWriter, code int, format string, args ...interface{}) { sendError(w, code, &gqlerror.Error{Message: fmt.Sprintf(format, args...)}) } type bytesReader struct { s *[]byte i int64 // current reading index prevRune int // index of previous rune; or < 0 } func (r *bytesReader) Read(b []byte) (n int, err error) { if r.s == nil { return 0, errors.New("byte slice pointer is nil") } if r.i >= int64(len(*r.s)) { return 0, io.EOF } r.prevRune = -1 n = copy(b, (*r.s)[r.i:]) r.i += int64(n) return } func processMultipart(w http.ResponseWriter, r *http.Request, request *params, closers *[]io.Closer, tmpFiles *[]string, uploadMaxSize, uploadMaxMemory int64) error { var err error if r.ContentLength > uploadMaxSize { return errors.New("failed to parse multipart form, request body too large") } r.Body = http.MaxBytesReader(w, r.Body, uploadMaxSize) if err = r.ParseMultipartForm(uploadMaxMemory); err != nil { if strings.Contains(err.Error(), "request body too large") { return errors.New("failed to parse multipart form, request body too large") } return errors.New("failed to parse multipart form") } *closers = append(*closers, r.Body) if err = jsonDecode(strings.NewReader(r.Form.Get("operations")), &request); err != nil { return errors.New("operations form field could not be decoded") } var uploadsMap = map[string][]string{} if err = json.Unmarshal([]byte(r.Form.Get("map")), &uploadsMap); err != nil { return errors.New("map form field could not be decoded") } var upload graphql.Upload for key, paths := range uploadsMap { if len(paths) == 0 { return fmt.Errorf("invalid empty operations paths list for key %s", key) } file, header, err := r.FormFile(key) if err != nil { return fmt.Errorf("failed to get key %s from form", key) } *closers = append(*closers, file) if len(paths) == 1 { upload = graphql.Upload{ File: file, Size: header.Size, Filename: header.Filename, } err = addUploadToOperations(request, upload, key, paths[0]) if err != nil { return err } } else { if r.ContentLength < uploadMaxMemory { fileBytes, err := ioutil.ReadAll(file) if err != nil { return fmt.Errorf("failed to read file for key %s", key) } for _, path := range paths { upload = graphql.Upload{ File: &bytesReader{s: &fileBytes, i: 0, prevRune: -1}, Size: header.Size, Filename: header.Filename, } err = addUploadToOperations(request, upload, key, path) if err != nil { return err } } } else { tmpFile, err := ioutil.TempFile(os.TempDir(), "gqlgen-") if err != nil { return fmt.Errorf("failed to create temp file for key %s", key) } tmpName := tmpFile.Name() *tmpFiles = append(*tmpFiles, tmpName) _, err = io.Copy(tmpFile, file) if err != nil { if err := tmpFile.Close(); err != nil { return fmt.Errorf("failed to copy to temp file and close temp file for key %s", key) } return fmt.Errorf("failed to copy to temp file for key %s", key) } if err := tmpFile.Close(); err != nil { return fmt.Errorf("failed to close temp file for key %s", key) } for _, path := range paths { pathTmpFile, err := os.Open(tmpName) if err != nil { return fmt.Errorf("failed to open temp file for key %s", key) } *closers = append(*closers, pathTmpFile) upload = graphql.Upload{ File: pathTmpFile, Size: header.Size, Filename: header.Filename, } err = addUploadToOperations(request, upload, key, path) if err != nil { return err } } } } } return nil } func addUploadToOperations(request *params, upload graphql.Upload, key, path string) error { if !strings.HasPrefix(path, "variables.") { return fmt.Errorf("invalid operations paths for key %s", key) } var ptr interface{} = request.Variables parts := strings.Split(path, ".") // skip the first part (variables) because we started there for i, p := range parts[1:] { last := i == len(parts)-2 if ptr == nil { return fmt.Errorf("path is missing \"variables.\" prefix, key: %s, path: %s", key, path) } if index, parseNbrErr := strconv.Atoi(p); parseNbrErr == nil { if last { ptr.([]interface{})[index] = upload } else { ptr = ptr.([]interface{})[index] } } else { if last { ptr.(map[string]interface{})[p] = upload } else { ptr = ptr.(map[string]interface{})[p] } } } return nil }