package handler
import (
"bytes"
"context"
"encoding/json"
"fmt"
"log"
"net/http"
"sync"
"time"
"github.com/99designs/gqlgen/graphql"
"github.com/gorilla/websocket"
lru "github.com/hashicorp/golang-lru"
"github.com/vektah/gqlparser"
"github.com/vektah/gqlparser/ast"
"github.com/vektah/gqlparser/gqlerror"
"github.com/vektah/gqlparser/validator"
)
// Message type identifiers exchanged over the websocket, grouped by the
// direction in which they travel.

// Message types sent by the client to the server.
const (
	connectionInitMsg      = "connection_init"
	connectionTerminateMsg = "connection_terminate"
	startMsg               = "start"
	stopMsg                = "stop"
)

// Message types sent by the server to the client.
const (
	connectionAckMsg       = "connection_ack"
	connectionErrorMsg     = "connection_error"
	dataMsg                = "data"
	errorMsg               = "error"
	completeMsg            = "complete"
	connectionKeepAliveMsg = "ka"
)
// operationMessage is the JSON envelope used for every message that is read
// from or written to the websocket connection.
type operationMessage struct {
// Payload is the message body kept as raw JSON; it is decoded separately by
// whichever handler deals with the message Type. Omitted when empty.
Payload json.RawMessage `json:"payload,omitempty"`
// ID names the operation the message belongs to. It is the key under which a
// running subscription is tracked and is echoed on data, error and complete
// messages. Omitted when empty.
ID string `json:"id,omitempty"`
// Type is one of the message type constants declared in this file.
Type string `json:"type"`
}
// wsConnection holds the state of a single upgraded websocket connection.
type wsConnection struct {
// ctx is the context of the HTTP request that was upgraded; every operation
// context is derived from it.
ctx context.Context
// conn is the upgraded websocket connection.
conn *websocket.Conn
// exec is the schema used to execute queries, mutations and subscriptions.
exec graphql.ExecutableSchema
// active maps the message ID of each running subscription to the function
// that cancels it. Access is guarded by mu.
active map[string]context.CancelFunc
// mu serializes writes to conn and guards active.
mu sync.Mutex
// cfg is the handler configuration (init hook, keep-alive interval, request
// context factory).
cfg *Config
// cache optionally stores parsed query documents keyed by query text; a nil
// cache disables caching.
cache *lru.Cache
// keepAliveTicker drives keep-alive messages; it is only set by run when the
// configured ping interval is non-zero.
keepAliveTicker *time.Ticker
// initPayload is the decoded payload of the connection_init message; it stays
// nil when the client sent no payload.
initPayload InitPayload
}
// connectWs upgrades the HTTP request to a websocket that speaks the
// "graphql-ws" sub-protocol and then serves that connection until it ends.
//
// exec is the schema operations are executed against, cfg supplies the
// upgrader and connection options, and cache (which may be nil) caches parsed
// query documents. The function blocks for the lifetime of the connection.
func connectWs(exec graphql.ExecutableSchema, w http.ResponseWriter, r *http.Request, cfg *Config, cache *lru.Cache) {
	ws, err := cfg.upgrader.Upgrade(w, r, http.Header{
		"Sec-Websocket-Protocol": []string{"graphql-ws"},
	})
	if err != nil {
		log.Printf("unable to upgrade %T to websocket: %s", w, err.Error())
		// NOTE(review): the upgrader may already have replied to the client on
		// failure, in which case this second reply is redundant — confirm
		// before removing; it is kept here for compatibility.
		sendErrorf(w, http.StatusBadRequest, "unable to upgrade")
		return
	}
	// Once upgraded, the connection belongs to this handler. Several exit
	// paths below (a read error in run, a failed start message) return without
	// closing it, so always release it here. Closing a connection that was
	// already closed only yields an error, which is deliberately ignored.
	defer func() { _ = ws.Close() }()

	conn := wsConnection{
		active: map[string]context.CancelFunc{},
		exec:   exec,
		conn:   ws,
		ctx:    r.Context(),
		cfg:    cfg,
		cache:  cache,
	}

	if !conn.init() {
		return
	}

	conn.run()
}
// init performs the connection handshake: it reads the first message from the
// client and expects it to be a connection_init message.
//
// On connection_init the optional payload is decoded into c.initPayload, the
// configured websocketInitFunc (if any) is consulted, and a connection_ack is
// sent. It returns true only when the handshake succeeded; on every failure
// path the connection is closed before false is returned.
func (c *wsConnection) init() bool {
	message := c.readOp()
	if message == nil {
		c.close(websocket.CloseProtocolError, "decoding error")
		return false
	}

	switch message.Type {
	case connectionInitMsg:
		if len(message.Payload) > 0 {
			c.initPayload = make(InitPayload)
			if err := json.Unmarshal(message.Payload, &c.initPayload); err != nil {
				// Tell the client why the handshake failed and release the
				// socket instead of silently abandoning the connection.
				c.sendConnectionError("invalid connection_init payload")
				c.close(websocket.CloseProtocolError, "decoding error")
				return false
			}
		}

		if c.cfg.websocketInitFunc != nil {
			if err := c.cfg.websocketInitFunc(c.ctx, c.initPayload); err != nil {
				// The error text is data, not a format string: pass it through
				// "%s" so any '%' it contains is sent verbatim.
				c.sendConnectionError("%s", err.Error())
				c.close(websocket.CloseNormalClosure, "terminated")
				return false
			}
		}

		c.write(&operationMessage{Type: connectionAckMsg})
	case connectionTerminateMsg:
		c.close(websocket.CloseNormalClosure, "terminated")
		return false
	default:
		c.sendConnectionError("unexpected message %s", message.Type)
		c.close(websocket.CloseProtocolError, "unexpected message")
		return false
	}

	return true
}
// write sends msg to the client as JSON. Writes are serialized through c.mu
// because several goroutines (the read loop, the keep-alive loop and each
// subscription) write to the same connection.
func (c *wsConnection) write(msg *operationMessage) {
	c.mu.Lock()
	// Unlock via defer so the mutex is released even if the write panics;
	// the subscription goroutine recovers from panics and then writes again,
	// which would deadlock on a mutex left locked.
	defer c.mu.Unlock()

	// Best effort: a failed write means the peer is gone, which the read loop
	// notices on its own, so the error is intentionally dropped.
	_ = c.conn.WriteJSON(msg)
}
// run is the read loop of an initialized connection. It optionally starts the
// keep-alive goroutine, then dispatches each incoming message until the
// client terminates, a message cannot be read, or a protocol violation occurs.
func (c *wsConnection) run() {
	// Cancelling this context on return is what shuts the keep-alive
	// goroutine down.
	keepAliveCtx, stopKeepAlive := context.WithCancel(c.ctx)
	defer stopKeepAlive()

	// A zero interval disables keep-alive messages entirely.
	if c.cfg.connectionKeepAlivePingInterval != 0 {
		ticker := time.NewTicker(c.cfg.connectionKeepAlivePingInterval)
		c.mu.Lock()
		c.keepAliveTicker = ticker
		c.mu.Unlock()
		go c.keepAlive(keepAliveCtx)
	}

	for {
		op := c.readOp()
		if op == nil {
			return
		}

		switch op.Type {
		case startMsg:
			if ok := c.subscribe(op); !ok {
				return
			}

		case stopMsg:
			c.mu.Lock()
			stop := c.active[op.ID]
			c.mu.Unlock()

			if stop != nil {
				stop()
				continue
			}
			c.sendError(op.ID, gqlerror.Errorf("%s is not running, cannot stop", op.ID))

		case connectionTerminateMsg:
			c.close(websocket.CloseNormalClosure, "terminated")
			return

		default:
			c.sendConnectionError("unexpected message %s", op.Type)
			c.close(websocket.CloseProtocolError, "unexpected message")
			return
		}
	}
}
// keepAlive sends a keep-alive message to the client on every tick of
// c.keepAliveTicker. It stops the ticker and returns once ctx is done.
func (c *wsConnection) keepAlive(ctx context.Context) {
	ticker := c.keepAliveTicker
	defer ticker.Stop()

	for {
		select {
		case <-ticker.C:
			c.write(&operationMessage{Type: connectionKeepAliveMsg})
		case <-ctx.Done():
			return
		}
	}
}
// subscribe handles a start message: it parses and validates the requested
// operation and executes it.
//
// Queries and mutations are executed synchronously and answered with one data
// message followed by a complete message. Subscriptions run in their own
// goroutine, are registered in c.active under the message ID so a stop
// message can cancel them, and stream one data message per result.
//
// It returns false only when the start payload cannot be decoded, which makes
// the caller stop serving the connection. Errors that concern a single
// operation (parse, validation, unknown operation name) are reported to the
// client with an error message for that ID and subscribe returns true.
func (c *wsConnection) subscribe(message *operationMessage) bool {
	var reqParams params
	if err := jsonDecode(bytes.NewReader(message.Payload), &reqParams); err != nil {
		c.sendConnectionError("invalid json")
		return false
	}

	var (
		doc      *ast.QueryDocument
		cacheHit bool
	)
	if c.cache != nil {
		val, ok := c.cache.Get(reqParams.Query)
		if ok {
			doc = val.(*ast.QueryDocument)
			cacheHit = true
		}
	}
	if !cacheHit {
		var qErr gqlerror.List
		doc, qErr = gqlparser.LoadQuery(c.exec.Schema(), reqParams.Query)
		if qErr != nil {
			c.sendError(message.ID, qErr...)
			return true
		}
		// Only successfully loaded documents are cached.
		if c.cache != nil {
			c.cache.Add(reqParams.Query, doc)
		}
	}

	op := doc.Operations.ForName(reqParams.OperationName)
	if op == nil {
		c.sendError(message.ID, gqlerror.Errorf("operation %s not found", reqParams.OperationName))
		return true
	}

	vars, err := validator.VariableValues(c.exec.Schema(), op, reqParams.Variables)
	if err != nil {
		c.sendError(message.ID, err)
		return true
	}

	reqCtx := c.cfg.newRequestContext(c.exec, doc, op, reqParams.Query, vars)
	ctx := graphql.WithRequestContext(c.ctx, reqCtx)
	if c.initPayload != nil {
		ctx = withInitPayload(ctx, c.initPayload)
	}

	// Queries and mutations produce exactly one response, so they are run
	// inline on the read loop rather than in a goroutine.
	if op.Operation != ast.Subscription {
		var result *graphql.Response
		if op.Operation == ast.Query {
			result = c.exec.Query(ctx, op)
		} else {
			result = c.exec.Mutation(ctx, op)
		}

		c.sendData(message.ID, result)
		c.write(&operationMessage{ID: message.ID, Type: completeMsg})
		return true
	}

	ctx, cancel := context.WithCancel(ctx)
	c.mu.Lock()
	c.active[message.ID] = cancel
	c.mu.Unlock()

	go func() {
		// Registered first so it runs last, after the recover handler below:
		// whether the subscription finished or panicked, drop it from the
		// active set and release its context.
		defer func() {
			c.mu.Lock()
			delete(c.active, message.ID)
			c.mu.Unlock()
			cancel()
		}()
		// Turn a panic raised while running the subscription into an error
		// message for the client instead of crashing the process.
		defer func() {
			if r := recover(); r != nil {
				userErr := reqCtx.Recover(ctx, r)
				c.sendError(message.ID, &gqlerror.Error{Message: userErr.Error()})
			}
		}()

		next := c.exec.Subscription(ctx, op)
		// A nil result marks the end of the stream.
		for result := next(); result != nil; result = next() {
			c.sendData(message.ID, result)
		}

		c.write(&operationMessage{ID: message.ID, Type: completeMsg})
	}()

	return true
}
// sendData encodes response as JSON and sends it to the client as a data
// message for operation id. If the response cannot be encoded, an error
// message for the same id is sent instead.
func (c *wsConnection) sendData(id string, response *graphql.Response) {
	payload, encodeErr := json.Marshal(response)
	if encodeErr != nil {
		c.sendError(id, gqlerror.Errorf("unable to encode json response: %s", encodeErr.Error()))
		return
	}

	out := operationMessage{Type: dataMsg, ID: id, Payload: payload}
	c.write(&out)
}
// sendError sends the given errors to the client as a single error message
// for operation id. It panics if the errors cannot be encoded as JSON.
func (c *wsConnection) sendError(id string, errors ...*gqlerror.Error) {
	// Deliberately left nil rather than pre-sized: with no errors the payload
	// must keep encoding as JSON null.
	var list []error
	for i := range errors {
		list = append(list, errors[i])
	}

	payload, encodeErr := json.Marshal(list)
	if encodeErr != nil {
		panic(encodeErr)
	}

	out := operationMessage{Type: errorMsg, ID: id, Payload: payload}
	c.write(&out)
}
// sendConnectionError formats a message from format and args and sends it to
// the client as a connection_error message, which carries no operation ID.
// It panics if the error cannot be encoded as JSON.
func (c *wsConnection) sendConnectionError(format string, args ...interface{}) {
	connErr := gqlerror.Error{Message: fmt.Sprintf(format, args...)}

	payload, encodeErr := json.Marshal(&connErr)
	if encodeErr != nil {
		panic(encodeErr)
	}

	out := operationMessage{Type: connectionErrorMsg, Payload: payload}
	c.write(&out)
}
// readOp reads the next message from the connection and decodes it into an
// operationMessage. When the message cannot be read or decoded, it sends an
// "invalid json" connection error to the client and returns nil.
func (c *wsConnection) readOp() *operationMessage {
	_, reader, err := c.conn.NextReader()
	if err == nil {
		op := new(operationMessage)
		if err = jsonDecode(reader, op); err == nil {
			return op
		}
	}

	// Both failure modes (read and decode) are reported the same way.
	c.sendConnectionError("invalid json")
	return nil
}
// close sends a websocket close frame carrying closeCode and message, then
// closes the underlying connection. Errors from either step are ignored.
func (c *wsConnection) close(closeCode int, message string) {
	frame := websocket.FormatCloseMessage(closeCode, message)

	// The close frame is a write like any other and must hold the write lock.
	c.mu.Lock()
	_ = c.conn.WriteMessage(websocket.CloseMessage, frame)
	c.mu.Unlock()

	_ = c.conn.Close()
}