github.com/operandinc/gqlgen@v0.16.1/client/websocket.go (about)

     1  package client
     2  
import (
	"encoding/json"
	"fmt"
	"io/ioutil"
	"net/http/httptest"
	"reflect"
	"strings"

	"github.com/gorilla/websocket"
)
    12  
    13  const (
    14  	connectionInitMsg = "connection_init" // Client -> Server
    15  	startMsg          = "start"           // Client -> Server
    16  	connectionAckMsg  = "connection_ack"  // Server -> Client
    17  	connectionKaMsg   = "ka"              // Server -> Client
    18  	dataMsg           = "data"            // Server -> Client
    19  	errorMsg          = "error"           // Server -> Client
    20  )
    21  
    22  type operationMessage struct {
    23  	Payload json.RawMessage `json:"payload,omitempty"`
    24  	ID      string          `json:"id,omitempty"`
    25  	Type    string          `json:"type"`
    26  }
    27  
    28  type Subscription struct {
    29  	Close func() error
    30  	Next  func(response interface{}) error
    31  }
    32  
    33  func errorSubscription(err error) *Subscription {
    34  	return &Subscription{
    35  		Close: func() error { return nil },
    36  		Next: func(response interface{}) error {
    37  			return err
    38  		},
    39  	}
    40  }
    41  
    42  func (p *Client) Websocket(query string, options ...Option) *Subscription {
    43  	return p.WebsocketWithPayload(query, nil, options...)
    44  }
    45  
    46  // Grab a single response from a websocket based query
    47  func (p *Client) WebsocketOnce(query string, resp interface{}, options ...Option) error {
    48  	sock := p.Websocket(query)
    49  	defer sock.Close()
    50  	return sock.Next(&resp)
    51  }
    52  
    53  func (p *Client) WebsocketWithPayload(query string, initPayload map[string]interface{}, options ...Option) *Subscription {
    54  	r, err := p.newRequest(query, options...)
    55  	if err != nil {
    56  		return errorSubscription(fmt.Errorf("request: %w", err))
    57  	}
    58  
    59  	requestBody, err := ioutil.ReadAll(r.Body)
    60  	if err != nil {
    61  		return errorSubscription(fmt.Errorf("parse body: %w", err))
    62  	}
    63  
    64  	srv := httptest.NewServer(p.h)
    65  	host := strings.ReplaceAll(srv.URL, "http://", "ws://")
    66  	c, _, err := websocket.DefaultDialer.Dial(host+r.URL.Path, r.Header)
    67  	if err != nil {
    68  		return errorSubscription(fmt.Errorf("dial: %w", err))
    69  	}
    70  
    71  	initMessage := operationMessage{Type: connectionInitMsg}
    72  	if initPayload != nil {
    73  		initMessage.Payload, err = json.Marshal(initPayload)
    74  		if err != nil {
    75  			return errorSubscription(fmt.Errorf("parse payload: %w", err))
    76  		}
    77  	}
    78  
    79  	if err = c.WriteJSON(initMessage); err != nil {
    80  		return errorSubscription(fmt.Errorf("init: %w", err))
    81  	}
    82  
    83  	var ack operationMessage
    84  	if err = c.ReadJSON(&ack); err != nil {
    85  		return errorSubscription(fmt.Errorf("ack: %w", err))
    86  	}
    87  
    88  	if ack.Type != connectionAckMsg {
    89  		return errorSubscription(fmt.Errorf("expected ack message, got %#v", ack))
    90  	}
    91  
    92  	var ka operationMessage
    93  	if err = c.ReadJSON(&ka); err != nil {
    94  		return errorSubscription(fmt.Errorf("ack: %w", err))
    95  	}
    96  
    97  	if ka.Type != connectionKaMsg {
    98  		return errorSubscription(fmt.Errorf("expected ack message, got %#v", ack))
    99  	}
   100  
   101  	if err = c.WriteJSON(operationMessage{Type: startMsg, ID: "1", Payload: requestBody}); err != nil {
   102  		return errorSubscription(fmt.Errorf("start: %w", err))
   103  	}
   104  
   105  	return &Subscription{
   106  		Close: func() error {
   107  			srv.Close()
   108  			return c.Close()
   109  		},
   110  		Next: func(response interface{}) error {
   111  			var op operationMessage
   112  			err := c.ReadJSON(&op)
   113  			if err != nil {
   114  				return err
   115  			}
   116  			if op.Type != dataMsg {
   117  				if op.Type == errorMsg {
   118  					return fmt.Errorf(string(op.Payload))
   119  				} else {
   120  					return fmt.Errorf("expected data message, got %#v", op)
   121  				}
   122  			}
   123  
   124  			var respDataRaw Response
   125  			err = json.Unmarshal(op.Payload, &respDataRaw)
   126  			if err != nil {
   127  				return fmt.Errorf("decode: %w", err)
   128  			}
   129  
   130  			// we want to unpack even if there is an error, so we can see partial responses
   131  			unpackErr := unpack(respDataRaw.Data, response)
   132  
   133  			if respDataRaw.Errors != nil {
   134  				return RawJsonError{respDataRaw.Errors}
   135  			}
   136  			return unpackErr
   137  		},
   138  	}
   139  }