github.com/animeshon/gqlgen@v0.13.1-0.20210304133704-3a770431bb6d/client/websocket.go (about) 1 package client 2 3 import ( 4 "encoding/json" 5 "fmt" 6 "io/ioutil" 7 "net/http/httptest" 8 "strings" 9 10 "github.com/gorilla/websocket" 11 ) 12 13 const ( 14 connectionInitMsg = "connection_init" // Client -> Server 15 startMsg = "start" // Client -> Server 16 connectionAckMsg = "connection_ack" // Server -> Client 17 connectionKaMsg = "ka" // Server -> Client 18 dataMsg = "data" // Server -> Client 19 errorMsg = "error" // Server -> Client 20 ) 21 22 type operationMessage struct { 23 Payload json.RawMessage `json:"payload,omitempty"` 24 ID string `json:"id,omitempty"` 25 Type string `json:"type"` 26 } 27 28 type Subscription struct { 29 Close func() error 30 Next func(response interface{}) error 31 } 32 33 func errorSubscription(err error) *Subscription { 34 return &Subscription{ 35 Close: func() error { return nil }, 36 Next: func(response interface{}) error { 37 return err 38 }, 39 } 40 } 41 42 func (p *Client) Websocket(query string, options ...Option) *Subscription { 43 return p.WebsocketWithPayload(query, nil, options...) 44 } 45 46 // Grab a single response from a websocket based query 47 func (p *Client) WebsocketOnce(query string, resp interface{}, options ...Option) error { 48 sock := p.Websocket(query) 49 defer sock.Close() 50 return sock.Next(&resp) 51 } 52 53 func (p *Client) WebsocketWithPayload(query string, initPayload map[string]interface{}, options ...Option) *Subscription { 54 r, err := p.newRequest(query, options...) 55 if err != nil { 56 return errorSubscription(fmt.Errorf("request: %s", err.Error())) 57 } 58 59 requestBody, err := ioutil.ReadAll(r.Body) 60 if err != nil { 61 return errorSubscription(fmt.Errorf("parse body: %s", err.Error())) 62 } 63 64 srv := httptest.NewServer(p.h) 65 host := strings.Replace(srv.URL, "http://", "ws://", -1) 66 c, _, err := websocket.DefaultDialer.Dial(host+r.URL.Path, r.Header) 67 68 if err != nil { 69 return errorSubscription(fmt.Errorf("dial: %s", err.Error())) 70 } 71 72 initMessage := operationMessage{Type: connectionInitMsg} 73 if initPayload != nil { 74 initMessage.Payload, err = json.Marshal(initPayload) 75 if err != nil { 76 return errorSubscription(fmt.Errorf("parse payload: %s", err.Error())) 77 } 78 } 79 80 if err = c.WriteJSON(initMessage); err != nil { 81 return errorSubscription(fmt.Errorf("init: %s", err.Error())) 82 } 83 84 var ack operationMessage 85 if err = c.ReadJSON(&ack); err != nil { 86 return errorSubscription(fmt.Errorf("ack: %s", err.Error())) 87 } 88 89 if ack.Type != connectionAckMsg { 90 return errorSubscription(fmt.Errorf("expected ack message, got %#v", ack)) 91 } 92 93 var ka operationMessage 94 if err = c.ReadJSON(&ka); err != nil { 95 return errorSubscription(fmt.Errorf("ack: %s", err.Error())) 96 } 97 98 if ka.Type != connectionKaMsg { 99 return errorSubscription(fmt.Errorf("expected ack message, got %#v", ack)) 100 } 101 102 if err = c.WriteJSON(operationMessage{Type: startMsg, ID: "1", Payload: requestBody}); err != nil { 103 return errorSubscription(fmt.Errorf("start: %s", err.Error())) 104 } 105 106 return &Subscription{ 107 Close: func() error { 108 srv.Close() 109 return c.Close() 110 }, 111 Next: func(response interface{}) error { 112 var op operationMessage 113 err := c.ReadJSON(&op) 114 if err != nil { 115 return err 116 } 117 if op.Type != dataMsg { 118 if op.Type == errorMsg { 119 return fmt.Errorf(string(op.Payload)) 120 } else { 121 return fmt.Errorf("expected data message, got %#v", op) 122 } 123 } 124 125 var respDataRaw Response 126 err = json.Unmarshal(op.Payload, &respDataRaw) 127 if err != nil { 128 return fmt.Errorf("decode: %s", err.Error()) 129 } 130 131 // we want to unpack even if there is an error, so we can see partial responses 132 unpackErr := unpack(respDataRaw.Data, response) 133 134 if respDataRaw.Errors != nil { 135 return RawJsonError{respDataRaw.Errors} 136 } 137 return unpackErr 138 }, 139 } 140 }