diff options
Diffstat (limited to 'vendor/github.com/vektah/gqlgen/handler/graphql.go')
-rw-r--r-- | vendor/github.com/vektah/gqlgen/handler/graphql.go | 235 |
1 files changed, 235 insertions, 0 deletions
diff --git a/vendor/github.com/vektah/gqlgen/handler/graphql.go b/vendor/github.com/vektah/gqlgen/handler/graphql.go new file mode 100644 index 00000000..4a5c61f5 --- /dev/null +++ b/vendor/github.com/vektah/gqlgen/handler/graphql.go @@ -0,0 +1,235 @@ +package handler + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + + "github.com/gorilla/websocket" + "github.com/vektah/gqlgen/graphql" + "github.com/vektah/gqlgen/neelance/errors" + "github.com/vektah/gqlgen/neelance/query" + "github.com/vektah/gqlgen/neelance/validation" +) + +type params struct { + Query string `json:"query"` + OperationName string `json:"operationName"` + Variables map[string]interface{} `json:"variables"` +} + +type Config struct { + upgrader websocket.Upgrader + recover graphql.RecoverFunc + errorPresenter graphql.ErrorPresenterFunc + resolverHook graphql.ResolverMiddleware + requestHook graphql.RequestMiddleware +} + +func (c *Config) newRequestContext(doc *query.Document, query string, variables map[string]interface{}) *graphql.RequestContext { + reqCtx := graphql.NewRequestContext(doc, query, variables) + if hook := c.recover; hook != nil { + reqCtx.Recover = hook + } + + if hook := c.errorPresenter; hook != nil { + reqCtx.ErrorPresenter = hook + } + + if hook := c.resolverHook; hook != nil { + reqCtx.ResolverMiddleware = hook + } + + if hook := c.requestHook; hook != nil { + reqCtx.RequestMiddleware = hook + } + + return reqCtx +} + +type Option func(cfg *Config) + +func WebsocketUpgrader(upgrader websocket.Upgrader) Option { + return func(cfg *Config) { + cfg.upgrader = upgrader + } +} + +func RecoverFunc(recover graphql.RecoverFunc) Option { + return func(cfg *Config) { + cfg.recover = recover + } +} + +// ErrorPresenter transforms errors found while resolving into errors that will be returned to the user. It provides +// a good place to add any extra fields, like error.type, that might be desired by your frontend. Check the default +// implementation in graphql.DefaultErrorPresenter for an example. +func ErrorPresenter(f graphql.ErrorPresenterFunc) Option { + return func(cfg *Config) { + cfg.errorPresenter = f + } +} + +// ResolverMiddleware allows you to define a function that will be called around every resolver, +// useful for tracing and logging. +// It will only be called for user defined resolvers, any direct binding to models is assumed +// to cost nothing. +func ResolverMiddleware(middleware graphql.ResolverMiddleware) Option { + return func(cfg *Config) { + if cfg.resolverHook == nil { + cfg.resolverHook = middleware + return + } + + lastResolve := cfg.resolverHook + cfg.resolverHook = func(ctx context.Context, next graphql.Resolver) (res interface{}, err error) { + return lastResolve(ctx, func(ctx context.Context) (res interface{}, err error) { + return middleware(ctx, next) + }) + } + } +} + +// RequestMiddleware allows you to define a function that will be called around the root request, +// after the query has been parsed. This is useful for logging and tracing +func RequestMiddleware(middleware graphql.RequestMiddleware) Option { + return func(cfg *Config) { + if cfg.requestHook == nil { + cfg.requestHook = middleware + return + } + + lastResolve := cfg.requestHook + cfg.requestHook = func(ctx context.Context, next func(ctx context.Context) []byte) []byte { + return lastResolve(ctx, func(ctx context.Context) []byte { + return middleware(ctx, next) + }) + } + } +} + +func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc { + cfg := Config{ + upgrader: websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + }, + } + + for _, option := range options { + option(&cfg) + } + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodOptions { + w.Header().Set("Allow", "OPTIONS, GET, POST") + w.WriteHeader(http.StatusOK) + return + } + + if strings.Contains(r.Header.Get("Upgrade"), "websocket") { + connectWs(exec, w, r, &cfg) + return + } + + var reqParams params + switch r.Method { + case http.MethodGet: + reqParams.Query = r.URL.Query().Get("query") + reqParams.OperationName = r.URL.Query().Get("operationName") + + if variables := r.URL.Query().Get("variables"); variables != "" { + if err := json.Unmarshal([]byte(variables), &reqParams.Variables); err != nil { + sendErrorf(w, http.StatusBadRequest, "variables could not be decoded") + return + } + } + case http.MethodPost: + if err := json.NewDecoder(r.Body).Decode(&reqParams); err != nil { + sendErrorf(w, http.StatusBadRequest, "json body could not be decoded: "+err.Error()) + return + } + default: + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + w.Header().Set("Content-Type", "application/json") + + doc, qErr := query.Parse(reqParams.Query) + if qErr != nil { + sendError(w, http.StatusUnprocessableEntity, qErr) + return + } + + errs := validation.Validate(exec.Schema(), doc) + if len(errs) != 0 { + sendError(w, http.StatusUnprocessableEntity, errs...) + return + } + + op, err := doc.GetOperation(reqParams.OperationName) + if err != nil { + sendErrorf(w, http.StatusUnprocessableEntity, err.Error()) + return + } + + reqCtx := cfg.newRequestContext(doc, reqParams.Query, reqParams.Variables) + ctx := graphql.WithRequestContext(r.Context(), reqCtx) + + defer func() { + if err := recover(); err != nil { + userErr := reqCtx.Recover(ctx, err) + sendErrorf(w, http.StatusUnprocessableEntity, userErr.Error()) + } + }() + + switch op.Type { + case query.Query: + b, err := json.Marshal(exec.Query(ctx, op)) + if err != nil { + panic(err) + } + w.Write(b) + case query.Mutation: + b, err := json.Marshal(exec.Mutation(ctx, op)) + if err != nil { + panic(err) + } + w.Write(b) + default: + sendErrorf(w, http.StatusBadRequest, "unsupported operation type") + } + }) +} + +func sendError(w http.ResponseWriter, code int, errors ...*errors.QueryError) { + w.WriteHeader(code) + var errs []*graphql.Error + for _, err := range errors { + var locations []graphql.ErrorLocation + for _, l := range err.Locations { + fmt.Println(graphql.ErrorLocation(l)) + locations = append(locations, graphql.ErrorLocation{ + Line: l.Line, + Column: l.Column, + }) + } + + errs = append(errs, &graphql.Error{ + Message: err.Message, + Path: err.Path, + Locations: locations, + }) + } + b, err := json.Marshal(&graphql.Response{Errors: errs}) + if err != nil { + panic(err) + } + w.Write(b) +} + +func sendErrorf(w http.ResponseWriter, code int, format string, args ...interface{}) { + sendError(w, code, &errors.QueryError{Message: fmt.Sprintf(format, args...)}) +} |