github.com/shippio/gqlgen@v0.0.0-20220912092219-633ea699ef07/client/websocket.go (about)

     1  package client
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"io"
     7  	"net/http/httptest"
     8  	"strings"
     9  
    10  	"github.com/gorilla/websocket"
    11  )
    12  
    13  const (
    14  	connectionInitMsg = "connection_init" // Client -> Server
    15  	startMsg          = "start"           // Client -> Server
    16  	connectionAckMsg  = "connection_ack"  // Server -> Client
    17  	connectionKaMsg   = "ka"              // Server -> Client
    18  	dataMsg           = "data"            // Server -> Client
    19  	errorMsg          = "error"           // Server -> Client
    20  )
    21  
    22  type operationMessage struct {
    23  	Payload json.RawMessage `json:"payload,omitempty"`
    24  	ID      string          `json:"id,omitempty"`
    25  	Type    string          `json:"type"`
    26  }
    27  
    28  type Subscription struct {
    29  	Close func() error
    30  	Next  func(response interface{}) error
    31  }
    32  
    33  func errorSubscription(err error) *Subscription {
    34  	return &Subscription{
    35  		Close: func() error { return nil },
    36  		Next: func(response interface{}) error {
    37  			return err
    38  		},
    39  	}
    40  }
    41  
    42  func (p *Client) Websocket(query string, options ...Option) *Subscription {
    43  	return p.WebsocketWithPayload(query, nil, options...)
    44  }
    45  
    46  // Grab a single response from a websocket based query
    47  func (p *Client) WebsocketOnce(query string, resp interface{}, options ...Option) error {
    48  	sock := p.Websocket(query, options...)
    49  	defer sock.Close()
    50  	return sock.Next(&resp)
    51  }
    52  
    53  func (p *Client) WebsocketWithPayload(query string, initPayload map[string]interface{}, options ...Option) *Subscription {
    54  	r, err := p.newRequest(query, options...)
    55  	if err != nil {
    56  		return errorSubscription(fmt.Errorf("request: %w", err))
    57  	}
    58  
    59  	requestBody, err := io.ReadAll(r.Body)
    60  	if err != nil {
    61  		return errorSubscription(fmt.Errorf("parse body: %w", err))
    62  	}
    63  
    64  	srv := httptest.NewServer(p.h)
    65  	host := strings.ReplaceAll(srv.URL, "http://", "ws://")
    66  	c, resp, err := websocket.DefaultDialer.Dial(host+r.URL.Path, r.Header)
    67  	if err != nil {
    68  		return errorSubscription(fmt.Errorf("dial: %w", err))
    69  	}
    70  	defer resp.Body.Close()
    71  
    72  	initMessage := operationMessage{Type: connectionInitMsg}
    73  	if initPayload != nil {
    74  		initMessage.Payload, err = json.Marshal(initPayload)
    75  		if err != nil {
    76  			return errorSubscription(fmt.Errorf("parse payload: %w", err))
    77  		}
    78  	}
    79  
    80  	if err = c.WriteJSON(initMessage); err != nil {
    81  		return errorSubscription(fmt.Errorf("init: %w", err))
    82  	}
    83  
    84  	var ack operationMessage
    85  	if err = c.ReadJSON(&ack); err != nil {
    86  		return errorSubscription(fmt.Errorf("ack: %w", err))
    87  	}
    88  
    89  	if ack.Type != connectionAckMsg {
    90  		return errorSubscription(fmt.Errorf("expected ack message, got %#v", ack))
    91  	}
    92  
    93  	var ka operationMessage
    94  	if err = c.ReadJSON(&ka); err != nil {
    95  		return errorSubscription(fmt.Errorf("ack: %w", err))
    96  	}
    97  
    98  	if ka.Type != connectionKaMsg {
    99  		return errorSubscription(fmt.Errorf("expected ack message, got %#v", ack))
   100  	}
   101  
   102  	if err = c.WriteJSON(operationMessage{Type: startMsg, ID: "1", Payload: requestBody}); err != nil {
   103  		return errorSubscription(fmt.Errorf("start: %w", err))
   104  	}
   105  
   106  	return &Subscription{
   107  		Close: func() error {
   108  			srv.Close()
   109  			return c.Close()
   110  		},
   111  		Next: func(response interface{}) error {
   112  			for {
   113  				var op operationMessage
   114  				err := c.ReadJSON(&op)
   115  				if err != nil {
   116  					return err
   117  				}
   118  
   119  				switch op.Type {
   120  				case dataMsg:
   121  					break
   122  				case connectionKaMsg:
   123  					continue
   124  				case errorMsg:
   125  					return fmt.Errorf(string(op.Payload))
   126  				default:
   127  					return fmt.Errorf("expected data message, got %#v", op)
   128  				}
   129  
   130  				var respDataRaw Response
   131  				err = json.Unmarshal(op.Payload, &respDataRaw)
   132  				if err != nil {
   133  					return fmt.Errorf("decode: %w", err)
   134  				}
   135  
   136  				// we want to unpack even if there is an error, so we can see partial responses
   137  				unpackErr := unpack(respDataRaw.Data, response)
   138  
   139  				if respDataRaw.Errors != nil {
   140  					return RawJsonError{respDataRaw.Errors}
   141  				}
   142  				return unpackErr
   143  			}
   144  		},
   145  	}
   146  }