aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/99designs/gqlgen/handler/websocket.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/99designs/gqlgen/handler/websocket.go')
-rw-r--r--vendor/github.com/99designs/gqlgen/handler/websocket.go252
1 files changed, 252 insertions, 0 deletions
diff --git a/vendor/github.com/99designs/gqlgen/handler/websocket.go b/vendor/github.com/99designs/gqlgen/handler/websocket.go
new file mode 100644
index 00000000..2be1e87f
--- /dev/null
+++ b/vendor/github.com/99designs/gqlgen/handler/websocket.go
@@ -0,0 +1,252 @@
+package handler
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "fmt"
+ "log"
+ "net/http"
+ "sync"
+
+ "github.com/99designs/gqlgen/graphql"
+ "github.com/gorilla/websocket"
+ "github.com/vektah/gqlparser"
+ "github.com/vektah/gqlparser/ast"
+ "github.com/vektah/gqlparser/gqlerror"
+ "github.com/vektah/gqlparser/validator"
+)
+
+const (
+ connectionInitMsg = "connection_init" // Client -> Server
+ connectionTerminateMsg = "connection_terminate" // Client -> Server
+ startMsg = "start" // Client -> Server
+ stopMsg = "stop" // Client -> Server
+ connectionAckMsg = "connection_ack" // Server -> Client
+ connectionErrorMsg = "connection_error" // Server -> Client
+ dataMsg = "data" // Server -> Client
+ errorMsg = "error" // Server -> Client
+ completeMsg = "complete" // Server -> Client
+ //connectionKeepAliveMsg = "ka" // Server -> Client TODO: keepalives
+)
+
+type operationMessage struct {
+ Payload json.RawMessage `json:"payload,omitempty"`
+ ID string `json:"id,omitempty"`
+ Type string `json:"type"`
+}
+
+type wsConnection struct {
+ ctx context.Context
+ conn *websocket.Conn
+ exec graphql.ExecutableSchema
+ active map[string]context.CancelFunc
+ mu sync.Mutex
+ cfg *Config
+}
+
+func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Request, cfg *Config) {
+ ws, err := cfg.upgrader.Upgrade(w, r, http.Header{
+ "Sec-Websocket-Protocol": []string{"graphql-ws"},
+ })
+ if err != nil {
+ log.Printf("unable to upgrade %T to websocket %s: ", w, err.Error())
+ sendErrorf(w, http.StatusBadRequest, "unable to upgrade")
+ return
+ }
+
+ conn := wsConnection{
+ active: map[string]context.CancelFunc{},
+ exec: exec,
+ conn: ws,
+ ctx: r.Context(),
+ cfg: cfg,
+ }
+
+ if !conn.init() {
+ return
+ }
+
+ conn.run()
+}
+
+func (c *wsConnection) init() bool {
+ message := c.readOp()
+ if message == nil {
+ c.close(websocket.CloseProtocolError, "decoding error")
+ return false
+ }
+
+ switch message.Type {
+ case connectionInitMsg:
+ c.write(&operationMessage{Type: connectionAckMsg})
+ case connectionTerminateMsg:
+ c.close(websocket.CloseNormalClosure, "terminated")
+ return false
+ default:
+ c.sendConnectionError("unexpected message %s", message.Type)
+ c.close(websocket.CloseProtocolError, "unexpected message")
+ return false
+ }
+
+ return true
+}
+
+func (c *wsConnection) write(msg *operationMessage) {
+ c.mu.Lock()
+ c.conn.WriteJSON(msg)
+ c.mu.Unlock()
+}
+
+func (c *wsConnection) run() {
+ for {
+ message := c.readOp()
+ if message == nil {
+ return
+ }
+
+ switch message.Type {
+ case startMsg:
+ if !c.subscribe(message) {
+ return
+ }
+ case stopMsg:
+ c.mu.Lock()
+ closer := c.active[message.ID]
+ c.mu.Unlock()
+ if closer == nil {
+ c.sendError(message.ID, gqlerror.Errorf("%s is not running, cannot stop", message.ID))
+ continue
+ }
+
+ closer()
+ case connectionTerminateMsg:
+ c.close(websocket.CloseNormalClosure, "terminated")
+ return
+ default:
+ c.sendConnectionError("unexpected message %s", message.Type)
+ c.close(websocket.CloseProtocolError, "unexpected message")
+ return
+ }
+ }
+}
+
+func (c *wsConnection) subscribe(message *operationMessage) bool {
+ var reqParams params
+ if err := jsonDecode(bytes.NewReader(message.Payload), &reqParams); err != nil {
+ c.sendConnectionError("invalid json")
+ return false
+ }
+
+ doc, qErr := gqlparser.LoadQuery(c.exec.Schema(), reqParams.Query)
+ if qErr != nil {
+ c.sendError(message.ID, qErr...)
+ return true
+ }
+
+ op := doc.Operations.ForName(reqParams.OperationName)
+ if op == nil {
+ c.sendError(message.ID, gqlerror.Errorf("operation %s not found", reqParams.OperationName))
+ return true
+ }
+
+ vars, err := validator.VariableValues(c.exec.Schema(), op, reqParams.Variables)
+ if err != nil {
+ c.sendError(message.ID, err)
+ return true
+ }
+ reqCtx := c.cfg.newRequestContext(doc, reqParams.Query, vars)
+ ctx := graphql.WithRequestContext(c.ctx, reqCtx)
+
+ if op.Operation != ast.Subscription {
+ var result *graphql.Response
+ if op.Operation == ast.Query {
+ result = c.exec.Query(ctx, op)
+ } else {
+ result = c.exec.Mutation(ctx, op)
+ }
+
+ c.sendData(message.ID, result)
+ c.write(&operationMessage{ID: message.ID, Type: completeMsg})
+ return true
+ }
+
+ ctx, cancel := context.WithCancel(ctx)
+ c.mu.Lock()
+ c.active[message.ID] = cancel
+ c.mu.Unlock()
+ go func() {
+ defer func() {
+ if r := recover(); r != nil {
+ userErr := reqCtx.Recover(ctx, r)
+ c.sendError(message.ID, &gqlerror.Error{Message: userErr.Error()})
+ }
+ }()
+ next := c.exec.Subscription(ctx, op)
+ for result := next(); result != nil; result = next() {
+ c.sendData(message.ID, result)
+ }
+
+ c.write(&operationMessage{ID: message.ID, Type: completeMsg})
+
+ c.mu.Lock()
+ delete(c.active, message.ID)
+ c.mu.Unlock()
+ cancel()
+ }()
+
+ return true
+}
+
+func (c *wsConnection) sendData(id string, response *graphql.Response) {
+ b, err := json.Marshal(response)
+ if err != nil {
+ c.sendError(id, gqlerror.Errorf("unable to encode json response: %s", err.Error()))
+ return
+ }
+
+ c.write(&operationMessage{Type: dataMsg, ID: id, Payload: b})
+}
+
+func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) {
+ var errs []error
+ for _, err := range errors {
+ errs = append(errs, err)
+ }
+ b, err := json.Marshal(errs)
+ if err != nil {
+ panic(err)
+ }
+ c.write(&operationMessage{Type: errorMsg, ID: id, Payload: b})
+}
+
+func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {
+ b, err := json.Marshal(&gqlerror.Error{Message: fmt.Sprintf(format, args...)})
+ if err != nil {
+ panic(err)
+ }
+
+ c.write(&operationMessage{Type: connectionErrorMsg, Payload: b})
+}
+
+func (c *wsConnection) readOp() *operationMessage {
+ _, r, err := c.conn.NextReader()
+ if err != nil {
+ c.sendConnectionError("invalid json")
+ return nil
+ }
+ message := operationMessage{}
+ if err := jsonDecode(r, &message); err != nil {
+ c.sendConnectionError("invalid json")
+ return nil
+ }
+
+ return &message
+}
+
+func (c *wsConnection) close(closeCode int, message string) {
+ c.mu.Lock()
+ _ = c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(closeCode, message))
+ c.mu.Unlock()
+ _ = c.conn.Close()
+}