From e381d5554a1b2b6e3a750206a853e090ec8183ab Mon Sep 17 00:00:00 2001 From: Amine Hilaly Date: Sun, 7 Jul 2019 13:37:03 +0200 Subject: Update gqlgen vendors --- .../github.com/99designs/gqlgen/handler/graphql.go | 93 ++++++++++++++++++++++ .../99designs/gqlgen/handler/websocket.go | 10 ++- 2 files changed, 102 insertions(+), 1 deletion(-) (limited to 'vendor/github.com/99designs/gqlgen/handler') diff --git a/vendor/github.com/99designs/gqlgen/handler/graphql.go b/vendor/github.com/99designs/gqlgen/handler/graphql.go index a2254222..289901f0 100644 --- a/vendor/github.com/99designs/gqlgen/handler/graphql.go +++ b/vendor/github.com/99designs/gqlgen/handler/graphql.go @@ -2,6 +2,8 @@ package handler import ( "context" + "crypto/sha256" + "encoding/hex" "encoding/json" "errors" "fmt" @@ -28,8 +30,30 @@ type params struct { Query string `json:"query"` OperationName string `json:"operationName"` Variables map[string]interface{} `json:"variables"` + Extensions *extensions `json:"extensions"` } +type extensions struct { + PersistedQuery *persistedQuery `json:"persistedQuery"` +} + +type persistedQuery struct { + Sha256 string `json:"sha256Hash"` + Version int64 `json:"version"` +} + +const ( + errPersistedQueryNotSupported = "PersistedQueryNotSupported" + errPersistedQueryNotFound = "PersistedQueryNotFound" +) + +type PersistedQueryCache interface { + Add(ctx context.Context, hash string, query string) + Get(ctx context.Context, hash string) (string, bool) +} + +type websocketInitFunc func(ctx context.Context, initPayload InitPayload) error + type Config struct { cacheSize int upgrader websocket.Upgrader @@ -40,10 +64,12 @@ type Config struct { tracer graphql.Tracer complexityLimit int complexityLimitFunc graphql.ComplexityLimitFunc + websocketInitFunc websocketInitFunc disableIntrospection bool connectionKeepAlivePingInterval time.Duration uploadMaxMemory int64 uploadMaxSize int64 + apqCache PersistedQueryCache } func (c *Config) newRequestContext(es graphql.ExecutableSchema, doc *ast.QueryDocument, op *ast.OperationDefinition, query string, variables map[string]interface{}) *graphql.RequestContext { @@ -250,6 +276,14 @@ func (tw *tracerWrapper) EndOperationExecution(ctx context.Context) { tw.tracer1.EndOperationExecution(ctx) } +// WebsocketInitFunc is called when the server receives connection init message from the client. +// This can be used to check initial payload to see whether to accept the websocket connection. +func WebsocketInitFunc(websocketInitFunc func(ctx context.Context, initPayload InitPayload) error) Option { + return func(cfg *Config) { + cfg.websocketInitFunc = websocketInitFunc + } +} + // CacheSize sets the maximum size of the query cache. // If size is less than or equal to 0, the cache is disabled. func CacheSize(size int) Option { @@ -285,6 +319,13 @@ func WebsocketKeepAliveDuration(duration time.Duration) Option { } } +// Add cache that will hold queries for automatic persisted queries (APQ) +func EnablePersistedQueryCache(cache PersistedQueryCache) Option { + return func(cfg *Config) { + cfg.apqCache = cache + } +} + const DefaultCacheSize = 1000 const DefaultConnectionKeepAlivePingInterval = 25 * time.Second @@ -344,6 +385,11 @@ type graphqlHandler struct { exec graphql.ExecutableSchema } +func computeQueryHash(query string) string { + b := sha256.Sum256([]byte(query)) + return hex.EncodeToString(b[:]) +} + func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodOptions { w.Header().Set("Allow", "OPTIONS, GET, POST") @@ -369,6 +415,13 @@ func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } } + + if extensions := r.URL.Query().Get("extensions"); extensions != "" { + if err := jsonDecode(strings.NewReader(extensions), &reqParams.Extensions); err != nil { + sendErrorf(w, http.StatusBadRequest, "extensions could not be decoded") + return + } + } case http.MethodPost: mediaType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type")) if err != nil { @@ -409,6 +462,41 @@ func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + var queryHash string + apqRegister := false + apq := reqParams.Extensions != nil && reqParams.Extensions.PersistedQuery != nil + if apq { + // client has enabled apq + queryHash = reqParams.Extensions.PersistedQuery.Sha256 + if gh.cfg.apqCache == nil { + // server has disabled apq + sendErrorf(w, http.StatusOK, errPersistedQueryNotSupported) + return + } + if reqParams.Extensions.PersistedQuery.Version != 1 { + sendErrorf(w, http.StatusOK, "Unsupported persisted query version") + return + } + if reqParams.Query == "" { + // client sent optimistic query hash without query string + query, ok := gh.cfg.apqCache.Get(ctx, queryHash) + if !ok { + sendErrorf(w, http.StatusOK, errPersistedQueryNotFound) + return + } + reqParams.Query = query + } else { + if computeQueryHash(reqParams.Query) != queryHash { + sendErrorf(w, http.StatusOK, "provided sha does not match query") + return + } + apqRegister = true + } + } else if reqParams.Query == "" { + sendErrorf(w, http.StatusUnprocessableEntity, "Must provide query string") + return + } + var doc *ast.QueryDocument var cacheHit bool if gh.cache != nil { @@ -463,6 +551,11 @@ func (gh *graphqlHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } + if apqRegister && gh.cfg.apqCache != nil { + // Add to persisted query cache + gh.cfg.apqCache.Add(ctx, queryHash, reqParams.Query) + } + switch op.Operation { case ast.Query: b, err := json.Marshal(gh.exec.Query(ctx, op)) diff --git a/vendor/github.com/99designs/gqlgen/handler/websocket.go b/vendor/github.com/99designs/gqlgen/handler/websocket.go index 58f38e5d..07a1a8c2 100644 --- a/vendor/github.com/99designs/gqlgen/handler/websocket.go +++ b/vendor/github.com/99designs/gqlgen/handler/websocket.go @@ -12,7 +12,7 @@ import ( "github.com/99designs/gqlgen/graphql" "github.com/gorilla/websocket" - "github.com/hashicorp/golang-lru" + lru "github.com/hashicorp/golang-lru" "github.com/vektah/gqlparser" "github.com/vektah/gqlparser/ast" "github.com/vektah/gqlparser/gqlerror" @@ -94,6 +94,14 @@ func (c *wsConnection) init() bool { } } + if c.cfg.websocketInitFunc != nil { + if err := c.cfg.websocketInitFunc(c.ctx, c.initPayload); err != nil { + c.sendConnectionError(err.Error()) + c.close(websocket.CloseNormalClosure, "terminated") + return false + } + } + c.write(&operationMessage{Type: connectionAckMsg}) case connectionTerminateMsg: c.close(websocket.CloseNormalClosure, "terminated") -- cgit