github.com/animeshon/gqlgen@v0.13.1-0.20210304133704-3a770431bb6d/client/websocket.go (about)

     1  package client
     2  
import (
	"encoding/json"
	"fmt"
	"io/ioutil"
	"net/http/httptest"
	"reflect"
	"strings"

	"github.com/gorilla/websocket"
)
    12  
    13  const (
    14  	connectionInitMsg = "connection_init" // Client -> Server
    15  	startMsg          = "start"           // Client -> Server
    16  	connectionAckMsg  = "connection_ack"  // Server -> Client
    17  	connectionKaMsg   = "ka"              // Server -> Client
    18  	dataMsg           = "data"            // Server -> Client
    19  	errorMsg          = "error"           // Server -> Client
    20  )
    21  
    22  type operationMessage struct {
    23  	Payload json.RawMessage `json:"payload,omitempty"`
    24  	ID      string          `json:"id,omitempty"`
    25  	Type    string          `json:"type"`
    26  }
    27  
    28  type Subscription struct {
    29  	Close func() error
    30  	Next  func(response interface{}) error
    31  }
    32  
    33  func errorSubscription(err error) *Subscription {
    34  	return &Subscription{
    35  		Close: func() error { return nil },
    36  		Next: func(response interface{}) error {
    37  			return err
    38  		},
    39  	}
    40  }
    41  
    42  func (p *Client) Websocket(query string, options ...Option) *Subscription {
    43  	return p.WebsocketWithPayload(query, nil, options...)
    44  }
    45  
    46  // Grab a single response from a websocket based query
    47  func (p *Client) WebsocketOnce(query string, resp interface{}, options ...Option) error {
    48  	sock := p.Websocket(query)
    49  	defer sock.Close()
    50  	return sock.Next(&resp)
    51  }
    52  
    53  func (p *Client) WebsocketWithPayload(query string, initPayload map[string]interface{}, options ...Option) *Subscription {
    54  	r, err := p.newRequest(query, options...)
    55  	if err != nil {
    56  		return errorSubscription(fmt.Errorf("request: %s", err.Error()))
    57  	}
    58  
    59  	requestBody, err := ioutil.ReadAll(r.Body)
    60  	if err != nil {
    61  		return errorSubscription(fmt.Errorf("parse body: %s", err.Error()))
    62  	}
    63  
    64  	srv := httptest.NewServer(p.h)
    65  	host := strings.Replace(srv.URL, "http://", "ws://", -1)
    66  	c, _, err := websocket.DefaultDialer.Dial(host+r.URL.Path, r.Header)
    67  
    68  	if err != nil {
    69  		return errorSubscription(fmt.Errorf("dial: %s", err.Error()))
    70  	}
    71  
    72  	initMessage := operationMessage{Type: connectionInitMsg}
    73  	if initPayload != nil {
    74  		initMessage.Payload, err = json.Marshal(initPayload)
    75  		if err != nil {
    76  			return errorSubscription(fmt.Errorf("parse payload: %s", err.Error()))
    77  		}
    78  	}
    79  
    80  	if err = c.WriteJSON(initMessage); err != nil {
    81  		return errorSubscription(fmt.Errorf("init: %s", err.Error()))
    82  	}
    83  
    84  	var ack operationMessage
    85  	if err = c.ReadJSON(&ack); err != nil {
    86  		return errorSubscription(fmt.Errorf("ack: %s", err.Error()))
    87  	}
    88  
    89  	if ack.Type != connectionAckMsg {
    90  		return errorSubscription(fmt.Errorf("expected ack message, got %#v", ack))
    91  	}
    92  
    93  	var ka operationMessage
    94  	if err = c.ReadJSON(&ka); err != nil {
    95  		return errorSubscription(fmt.Errorf("ack: %s", err.Error()))
    96  	}
    97  
    98  	if ka.Type != connectionKaMsg {
    99  		return errorSubscription(fmt.Errorf("expected ack message, got %#v", ack))
   100  	}
   101  
   102  	if err = c.WriteJSON(operationMessage{Type: startMsg, ID: "1", Payload: requestBody}); err != nil {
   103  		return errorSubscription(fmt.Errorf("start: %s", err.Error()))
   104  	}
   105  
   106  	return &Subscription{
   107  		Close: func() error {
   108  			srv.Close()
   109  			return c.Close()
   110  		},
   111  		Next: func(response interface{}) error {
   112  			var op operationMessage
   113  			err := c.ReadJSON(&op)
   114  			if err != nil {
   115  				return err
   116  			}
   117  			if op.Type != dataMsg {
   118  				if op.Type == errorMsg {
   119  					return fmt.Errorf(string(op.Payload))
   120  				} else {
   121  					return fmt.Errorf("expected data message, got %#v", op)
   122  				}
   123  			}
   124  
   125  			var respDataRaw Response
   126  			err = json.Unmarshal(op.Payload, &respDataRaw)
   127  			if err != nil {
   128  				return fmt.Errorf("decode: %s", err.Error())
   129  			}
   130  
   131  			// we want to unpack even if there is an error, so we can see partial responses
   132  			unpackErr := unpack(respDataRaw.Data, response)
   133  
   134  			if respDataRaw.Errors != nil {
   135  				return RawJsonError{respDataRaw.Errors}
   136  			}
   137  			return unpackErr
   138  		},
   139  	}
   140  }