package handler
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"github.com/gorilla/websocket"
"github.com/vektah/gqlgen/graphql"
"github.com/vektah/gqlgen/neelance/errors"
"github.com/vektah/gqlgen/neelance/query"
"github.com/vektah/gqlgen/neelance/validation"
)
// params holds the parameters of a single GraphQL HTTP request. POST bodies are
// decoded into it as JSON using the field tags; for GET requests the handler
// fills it from the "query", "operationName" and "variables" URL query
// parameters.
type params struct {
// Query is the GraphQL document text; it is handed to query.Parse.
Query string `json:"query"`
// OperationName selects the operation to run via doc.GetOperation.
OperationName string `json:"operationName"`
// Variables holds the operation's variable values. For GET requests it is
// decoded from the JSON-encoded "variables" query parameter.
Variables map[string]interface{} `json:"variables"`
}
// Config holds the handler settings that GraphQL assembles by applying each
// Option in order. All fields are unexported; set them through the Option
// constructors (WebsocketUpgrader, RecoverFunc, ErrorPresenter,
// ResolverMiddleware, RequestMiddleware).
type Config struct {
// upgrader is set by WebsocketUpgrader and defaulted in GraphQL; presumably
// used by connectWs to upgrade websocket requests — confirm in connectWs.
upgrader websocket.Upgrader
// recover, when non-nil, replaces RequestContext.Recover (see newRequestContext).
recover graphql.RecoverFunc
// errorPresenter, when non-nil, replaces RequestContext.ErrorPresenter.
errorPresenter graphql.ErrorPresenterFunc
// resolverHook is the composed chain built by ResolverMiddleware options;
// when non-nil it replaces RequestContext.ResolverMiddleware.
resolverHook graphql.ResolverMiddleware
// requestHook is the composed chain built by RequestMiddleware options;
// when non-nil it replaces RequestContext.RequestMiddleware.
requestHook graphql.RequestMiddleware
}
// newRequestContext creates the graphql.RequestContext for one operation from
// the parsed document, the raw query text and its variables. Every hook set on
// c (non-nil) overrides the matching field of the new context; for hooks left
// nil, the value assigned by graphql.NewRequestContext is kept.
func (c *Config) newRequestContext(doc *query.Document, query string, variables map[string]interface{}) *graphql.RequestContext {
	rc := graphql.NewRequestContext(doc, query, variables)

	if c.recover != nil {
		rc.Recover = c.recover
	}
	if c.errorPresenter != nil {
		rc.ErrorPresenter = c.errorPresenter
	}
	if c.resolverHook != nil {
		rc.ResolverMiddleware = c.resolverHook
	}
	if c.requestHook != nil {
		rc.RequestMiddleware = c.requestHook
	}

	return rc
}
// Option configures the handler built by GraphQL. Options are applied in the
// order they are passed, each mutating the same Config.
type Option func(*Config)
// WebsocketUpgrader returns an Option that replaces the websocket.Upgrader
// stored in Config. Without it, GraphQL uses an upgrader with 1024-byte read
// and write buffers.
func WebsocketUpgrader(upgrader websocket.Upgrader) Option {
	apply := func(c *Config) { c.upgrader = upgrader }
	return apply
}
// RecoverFunc returns an Option that sets the function GraphQL calls to turn a
// panic recovered during execution into the error reported to the client.
// When it is not set, the request context keeps whatever Recover function
// graphql.NewRequestContext assigned.
func RecoverFunc(fn graphql.RecoverFunc) Option {
	// The parameter is deliberately not named "recover": that would shadow the
	// builtin of the same name inside this function.
	return func(cfg *Config) {
		cfg.recover = fn
	}
}
// ErrorPresenter returns an Option that installs f as the function used to
// turn errors raised while resolving into the errors returned to the user.
// It is a convenient hook for adding extra fields (an error type, for
// example) that a frontend may want; graphql.DefaultErrorPresenter is a
// reference implementation.
func ErrorPresenter(f graphql.ErrorPresenterFunc) Option {
	setPresenter := func(c *Config) { c.errorPresenter = f }
	return setPresenter
}
// ResolverMiddleware returns an Option that wraps every resolver call with
// middleware, which is useful for tracing and logging. Only user-defined
// resolvers are wrapped; direct bindings to model fields are assumed to cost
// nothing and are not.
//
// The option may be given several times. Middleware registered earlier runs
// outermost: each newly added middleware becomes the next step inside the
// chain built so far, just before the real resolver.
func ResolverMiddleware(middleware graphql.ResolverMiddleware) Option {
	return func(cfg *Config) {
		outer := cfg.resolverHook
		if outer == nil {
			cfg.resolverHook = middleware
			return
		}
		cfg.resolverHook = func(ctx context.Context, next graphql.Resolver) (interface{}, error) {
			inner := func(ctx context.Context) (interface{}, error) {
				return middleware(ctx, next)
			}
			return outer(ctx, inner)
		}
	}
}
// RequestMiddleware returns an Option that wraps execution of the root request
// with middleware. It runs after the query has been parsed, which makes it a
// good place for logging and tracing.
//
// The option may be given several times. Middleware registered earlier runs
// outermost: each newly added middleware becomes the next step inside the
// chain built so far.
func RequestMiddleware(middleware graphql.RequestMiddleware) Option {
	return func(cfg *Config) {
		outer := cfg.requestHook
		if outer == nil {
			cfg.requestHook = middleware
			return
		}
		cfg.requestHook = func(ctx context.Context, next func(ctx context.Context) []byte) []byte {
			inner := func(ctx context.Context) []byte {
				return middleware(ctx, next)
			}
			return outer(ctx, inner)
		}
	}
}
// GraphQL returns an http.HandlerFunc that serves exec over HTTP.
//
// Request handling:
//   - OPTIONS is answered with an Allow header and 200 OK.
//   - Requests whose Upgrade header mentions "websocket" are handed to connectWs.
//   - GET reads "query", "operationName" and the JSON-encoded "variables"
//     from the URL query string; POST decodes a JSON body into params.
//   - Any other method gets 405 Method Not Allowed with an Allow header.
//
// The document is parsed, validated against exec.Schema(), and the selected
// query or mutation operation is executed; its result is written as JSON.
// Malformed input yields 400. Parse, validation and operation-selection
// failures yield 422. A panic during execution is turned into a 422 error
// response via the request context's Recover function.
//
// options are applied in order to a Config whose websocket upgrader defaults
// to 1024-byte read and write buffers.
func GraphQL(exec graphql.ExecutableSchema, options ...Option) http.HandlerFunc {
	const allowedMethods = "OPTIONS, GET, POST"

	cfg := Config{
		upgrader: websocket.Upgrader{
			ReadBufferSize:  1024,
			WriteBufferSize: 1024,
		},
	}
	for _, option := range options {
		option(&cfg)
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method == http.MethodOptions {
			w.Header().Set("Allow", allowedMethods)
			w.WriteHeader(http.StatusOK)
			return
		}
		if strings.Contains(r.Header.Get("Upgrade"), "websocket") {
			connectWs(exec, w, r, &cfg)
			return
		}
		if r.Method != http.MethodGet && r.Method != http.MethodPost {
			// A 405 response must list the methods the resource supports.
			w.Header().Set("Allow", allowedMethods)
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}

		// Every remaining response has a JSON body, including the decode
		// errors below, so the content type is set before any of them.
		w.Header().Set("Content-Type", "application/json")

		var reqParams params
		switch r.Method {
		case http.MethodGet:
			q := r.URL.Query() // parse the query string once
			reqParams.Query = q.Get("query")
			reqParams.OperationName = q.Get("operationName")
			if variables := q.Get("variables"); variables != "" {
				if err := json.Unmarshal([]byte(variables), &reqParams.Variables); err != nil {
					sendErrorf(w, http.StatusBadRequest, "variables could not be decoded: %s", err)
					return
				}
			}
		case http.MethodPost:
			if err := json.NewDecoder(r.Body).Decode(&reqParams); err != nil {
				// Error text is passed as an argument, never as the format
				// string: decoder errors can echo client input containing '%'.
				sendErrorf(w, http.StatusBadRequest, "json body could not be decoded: %s", err)
				return
			}
		}

		doc, qErr := query.Parse(reqParams.Query)
		if qErr != nil {
			sendError(w, http.StatusUnprocessableEntity, qErr)
			return
		}
		if errs := validation.Validate(exec.Schema(), doc); len(errs) != 0 {
			sendError(w, http.StatusUnprocessableEntity, errs...)
			return
		}
		op, err := doc.GetOperation(reqParams.OperationName)
		if err != nil {
			sendErrorf(w, http.StatusUnprocessableEntity, "%s", err.Error())
			return
		}

		reqCtx := cfg.newRequestContext(doc, reqParams.Query, reqParams.Variables)
		ctx := graphql.WithRequestContext(r.Context(), reqCtx)

		defer func() {
			if rec := recover(); rec != nil {
				userErr := reqCtx.Recover(ctx, rec)
				sendErrorf(w, http.StatusUnprocessableEntity, "%s", userErr.Error())
			}
		}()

		// NOTE(review): mutations are executed for GET requests as well, which
		// makes them reachable from plain links (CSRF); consider restricting GET
		// to query operations.
		var result interface{}
		switch op.Type {
		case query.Query:
			result = exec.Query(ctx, op)
		case query.Mutation:
			result = exec.Mutation(ctx, op)
		default:
			sendErrorf(w, http.StatusBadRequest, "unsupported operation type")
			return
		}

		b, err := json.Marshal(result)
		if err != nil {
			// Reported through the deferred recover above, like any other
			// panic raised while executing the operation.
			panic(err)
		}
		w.Write(b)
	})
}
// sendError writes a JSON graphql.Response with the HTTP status code. The
// response's Errors list mirrors queryErrs, copying each error's message,
// path and source locations.
//
// The body is marshalled before anything is written to w. A marshalling
// failure therefore panics without leaving a half-sent response behind.
func sendError(w http.ResponseWriter, code int, queryErrs ...*errors.QueryError) {
	var errs []*graphql.Error
	for _, qe := range queryErrs {
		var locations []graphql.ErrorLocation
		for _, l := range qe.Locations {
			locations = append(locations, graphql.ErrorLocation{
				Line:   l.Line,
				Column: l.Column,
			})
		}
		errs = append(errs, &graphql.Error{
			Message:   qe.Message,
			Path:      qe.Path,
			Locations: locations,
		})
	}

	b, err := json.Marshal(&graphql.Response{Errors: errs})
	if err != nil {
		panic(err)
	}

	// Header values only take effect if set before WriteHeader; the body is
	// always JSON.
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)
	// A failed write means the client went away; there is no one left to tell.
	w.Write(b)
}
// sendErrorf formats a message with fmt.Sprintf semantics and sends it, via
// sendError, as a single GraphQL error with the given HTTP status code.
// format must be a trusted constant; pass untrusted text through args.
func sendErrorf(w http.ResponseWriter, code int, format string, args ...interface{}) {
	msg := fmt.Sprintf(format, args...)
	sendError(w, code, &errors.QueryError{Message: msg})
}