aboutsummaryrefslogtreecommitdiffstats
path: root/vendor/github.com/vektah/gqlgen/client/websocket.go
diff options
context:
space:
mode:
Diffstat (limited to 'vendor/github.com/vektah/gqlgen/client/websocket.go')
-rw-r--r--vendor/github.com/vektah/gqlgen/client/websocket.go103
1 files changed, 103 insertions, 0 deletions
diff --git a/vendor/github.com/vektah/gqlgen/client/websocket.go b/vendor/github.com/vektah/gqlgen/client/websocket.go
new file mode 100644
index 00000000..bd92e3c0
--- /dev/null
+++ b/vendor/github.com/vektah/gqlgen/client/websocket.go
@@ -0,0 +1,103 @@
+package client
+
+import (
+ "encoding/json"
+ "fmt"
+ "strings"
+
+ "github.com/gorilla/websocket"
+ "github.com/vektah/gqlparser/gqlerror"
+)
+
+const (
+ connectionInitMsg = "connection_init" // Client -> Server
+ startMsg = "start" // Client -> Server
+ connectionAckMsg = "connection_ack" // Server -> Client
+ dataMsg = "data" // Server -> Client
+ errorMsg = "error" // Server -> Client
+)
+
+type operationMessage struct {
+ Payload json.RawMessage `json:"payload,omitempty"`
+ ID string `json:"id,omitempty"`
+ Type string `json:"type"`
+}
+
+type Subscription struct {
+ Close func() error
+ Next func(response interface{}) error
+}
+
+func errorSubscription(err error) *Subscription {
+ return &Subscription{
+ Close: func() error { return nil },
+ Next: func(response interface{}) error {
+ return err
+ },
+ }
+}
+
+func (p *Client) Websocket(query string, options ...Option) *Subscription {
+ r := p.mkRequest(query, options...)
+ requestBody, err := json.Marshal(r)
+ if err != nil {
+ return errorSubscription(fmt.Errorf("encode: %s", err.Error()))
+ }
+
+ url := strings.Replace(p.url, "http://", "ws://", -1)
+ url = strings.Replace(url, "https://", "wss://", -1)
+
+ c, _, err := websocket.DefaultDialer.Dial(url, nil)
+ if err != nil {
+ return errorSubscription(fmt.Errorf("dial: %s", err.Error()))
+ }
+
+ if err = c.WriteJSON(operationMessage{Type: connectionInitMsg}); err != nil {
+ return errorSubscription(fmt.Errorf("init: %s", err.Error()))
+ }
+
+ var ack operationMessage
+ if err = c.ReadJSON(&ack); err != nil {
+ return errorSubscription(fmt.Errorf("ack: %s", err.Error()))
+ }
+ if ack.Type != connectionAckMsg {
+ return errorSubscription(fmt.Errorf("expected ack message, got %#v", ack))
+ }
+
+ if err = c.WriteJSON(operationMessage{Type: startMsg, ID: "1", Payload: requestBody}); err != nil {
+ return errorSubscription(fmt.Errorf("start: %s", err.Error()))
+ }
+
+ return &Subscription{
+ Close: c.Close,
+ Next: func(response interface{}) error {
+ var op operationMessage
+ c.ReadJSON(&op)
+ if op.Type != dataMsg {
+ if op.Type == errorMsg {
+ return fmt.Errorf(string(op.Payload))
+ } else {
+ return fmt.Errorf("expected data message, got %#v", op)
+ }
+ }
+
+ respDataRaw := map[string]interface{}{}
+ err = json.Unmarshal(op.Payload, &respDataRaw)
+ if err != nil {
+ return fmt.Errorf("decode: %s", err.Error())
+ }
+
+ if respDataRaw["errors"] != nil {
+ var errs []*gqlerror.Error
+ if err = unpack(respDataRaw["errors"], &errs); err != nil {
+ return err
+ }
+ if len(errs) > 0 {
+ return fmt.Errorf("errors: %s", errs)
+ }
+ }
+
+ return unpack(respDataRaw["data"], response)
+ },
+ }
+}