github.com/99designs/gqlgen@v0.17.45/client/websocket.go (about) 1 package client 2 3 import ( 4 "encoding/json" 5 "fmt" 6 "io" 7 "net/http/httptest" 8 "reflect" 9 "strings" 10 11 "github.com/gorilla/websocket" 12 ) 13 14 const ( 15 connectionInitMsg = "connection_init" // Client -> Server 16 startMsg = "start" // Client -> Server 17 connectionAckMsg = "connection_ack" // Server -> Client 18 connectionKaMsg = "ka" // Server -> Client 19 dataMsg = "data" // Server -> Client 20 errorMsg = "error" // Server -> Client 21 ) 22 23 type operationMessage struct { 24 Payload json.RawMessage `json:"payload,omitempty"` 25 ID string `json:"id,omitempty"` 26 Type string `json:"type"` 27 } 28 29 type Subscription struct { 30 Close func() error 31 Next func(response interface{}) error 32 } 33 34 func errorSubscription(err error) *Subscription { 35 return &Subscription{ 36 Close: func() error { return nil }, 37 Next: func(response interface{}) error { 38 return err 39 }, 40 } 41 } 42 43 func (p *Client) Websocket(query string, options ...Option) *Subscription { 44 return p.WebsocketWithPayload(query, nil, options...) 45 } 46 47 // Grab a single response from a websocket based query 48 func (p *Client) WebsocketOnce(query string, resp interface{}, options ...Option) error { 49 sock := p.Websocket(query, options...) 50 defer sock.Close() 51 if reflect.ValueOf(resp).Kind() == reflect.Ptr { 52 return sock.Next(resp) 53 } 54 // TODO: verify this is never called and remove it 55 return sock.Next(&resp) 56 } 57 58 func (p *Client) WebsocketWithPayload(query string, initPayload map[string]interface{}, options ...Option) *Subscription { 59 r, err := p.newRequest(query, options...) 60 if err != nil { 61 return errorSubscription(fmt.Errorf("request: %w", err)) 62 } 63 64 requestBody, err := io.ReadAll(r.Body) 65 if err != nil { 66 return errorSubscription(fmt.Errorf("parse body: %w", err)) 67 } 68 69 srv := httptest.NewServer(p.h) 70 host := strings.ReplaceAll(srv.URL, "http://", "ws://") 71 c, resp, err := websocket.DefaultDialer.Dial(host+r.URL.Path, r.Header) 72 if err != nil { 73 return errorSubscription(fmt.Errorf("dial: %w", err)) 74 } 75 defer resp.Body.Close() 76 77 initMessage := operationMessage{Type: connectionInitMsg} 78 if initPayload != nil { 79 initMessage.Payload, err = json.Marshal(initPayload) 80 if err != nil { 81 return errorSubscription(fmt.Errorf("parse payload: %w", err)) 82 } 83 } 84 85 if err = c.WriteJSON(initMessage); err != nil { 86 return errorSubscription(fmt.Errorf("init: %w", err)) 87 } 88 89 var ack operationMessage 90 if err = c.ReadJSON(&ack); err != nil { 91 return errorSubscription(fmt.Errorf("ack: %w", err)) 92 } 93 94 if ack.Type != connectionAckMsg { 95 return errorSubscription(fmt.Errorf("expected ack message, got %#v", ack)) 96 } 97 98 var ka operationMessage 99 if err = c.ReadJSON(&ka); err != nil { 100 return errorSubscription(fmt.Errorf("ack: %w", err)) 101 } 102 103 if ka.Type != connectionKaMsg { 104 return errorSubscription(fmt.Errorf("expected ack message, got %#v", ack)) 105 } 106 107 if err = c.WriteJSON(operationMessage{Type: startMsg, ID: "1", Payload: requestBody}); err != nil { 108 return errorSubscription(fmt.Errorf("start: %w", err)) 109 } 110 111 return &Subscription{ 112 Close: func() error { 113 srv.Close() 114 return c.Close() 115 }, 116 Next: func(response interface{}) error { 117 for { 118 var op operationMessage 119 err := c.ReadJSON(&op) 120 if err != nil { 121 return err 122 } 123 124 switch op.Type { 125 case dataMsg: 126 break 127 case connectionKaMsg: 128 continue 129 case errorMsg: 130 return fmt.Errorf(string(op.Payload)) 131 default: 132 return fmt.Errorf("expected data message, got %#v", op) 133 } 134 135 var respDataRaw Response 136 err = json.Unmarshal(op.Payload, &respDataRaw) 137 if err != nil { 138 return fmt.Errorf("decode: %w", err) 139 } 140 141 // we want to unpack even if there is an error, so we can see partial responses 142 unpackErr := unpack(respDataRaw.Data, response, p.dc) 143 144 if respDataRaw.Errors != nil { 145 return RawJsonError{respDataRaw.Errors} 146 } 147 return unpackErr 148 } 149 }, 150 } 151 }