github.com/niko0xdev/gqlgen@v0.17.55-0.20240120102243-2ecff98c3e37/client/websocket.go (about)

     1  package client
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"io"
     7  	"net/http/httptest"
     8  	"reflect"
     9  	"strings"
    10  
    11  	"github.com/gorilla/websocket"
    12  )
    13  
    14  const (
    15  	connectionInitMsg = "connection_init" // Client -> Server
    16  	startMsg          = "start"           // Client -> Server
    17  	connectionAckMsg  = "connection_ack"  // Server -> Client
    18  	connectionKaMsg   = "ka"              // Server -> Client
    19  	dataMsg           = "data"            // Server -> Client
    20  	errorMsg          = "error"           // Server -> Client
    21  )
    22  
// operationMessage is the envelope for every frame this client reads from or
// writes to the websocket.
//
// Type names the kind of message and is always encoded. ID identifies the
// operation a message belongs to (the start message sent by this client uses
// "1"). Payload holds the message body as raw JSON, left undecoded until the
// receiver knows what shape to expect. Empty ID and Payload are omitted from
// the encoding.
//
// The field order determines the key order of the JSON encoding.
type operationMessage struct {
	Payload json.RawMessage `json:"payload,omitempty"`
	ID      string          `json:"id,omitempty"`
	Type    string          `json:"type"`
}
    28  
    29  type Subscription struct {
    30  	Close func() error
    31  	Next  func(response interface{}) error
    32  }
    33  
    34  func errorSubscription(err error) *Subscription {
    35  	return &Subscription{
    36  		Close: func() error { return nil },
    37  		Next: func(response interface{}) error {
    38  			return err
    39  		},
    40  	}
    41  }
    42  
    43  func (p *Client) Websocket(query string, options ...Option) *Subscription {
    44  	return p.WebsocketWithPayload(query, nil, options...)
    45  }
    46  
    47  // Grab a single response from a websocket based query
    48  func (p *Client) WebsocketOnce(query string, resp interface{}, options ...Option) error {
    49  	sock := p.Websocket(query, options...)
    50  	defer sock.Close()
    51  	if reflect.ValueOf(resp).Kind() == reflect.Ptr {
    52  		return sock.Next(resp)
    53  	}
    54  	// TODO: verify this is never called and remove it
    55  	return sock.Next(&resp)
    56  }
    57  
    58  func (p *Client) WebsocketWithPayload(query string, initPayload map[string]interface{}, options ...Option) *Subscription {
    59  	r, err := p.newRequest(query, options...)
    60  	if err != nil {
    61  		return errorSubscription(fmt.Errorf("request: %w", err))
    62  	}
    63  
    64  	requestBody, err := io.ReadAll(r.Body)
    65  	if err != nil {
    66  		return errorSubscription(fmt.Errorf("parse body: %w", err))
    67  	}
    68  
    69  	srv := httptest.NewServer(p.h)
    70  	host := strings.ReplaceAll(srv.URL, "http://", "ws://")
    71  	c, resp, err := websocket.DefaultDialer.Dial(host+r.URL.Path, r.Header)
    72  	if err != nil {
    73  		return errorSubscription(fmt.Errorf("dial: %w", err))
    74  	}
    75  	defer resp.Body.Close()
    76  
    77  	initMessage := operationMessage{Type: connectionInitMsg}
    78  	if initPayload != nil {
    79  		initMessage.Payload, err = json.Marshal(initPayload)
    80  		if err != nil {
    81  			return errorSubscription(fmt.Errorf("parse payload: %w", err))
    82  		}
    83  	}
    84  
    85  	if err = c.WriteJSON(initMessage); err != nil {
    86  		return errorSubscription(fmt.Errorf("init: %w", err))
    87  	}
    88  
    89  	var ack operationMessage
    90  	if err = c.ReadJSON(&ack); err != nil {
    91  		return errorSubscription(fmt.Errorf("ack: %w", err))
    92  	}
    93  
    94  	if ack.Type != connectionAckMsg {
    95  		return errorSubscription(fmt.Errorf("expected ack message, got %#v", ack))
    96  	}
    97  
    98  	var ka operationMessage
    99  	if err = c.ReadJSON(&ka); err != nil {
   100  		return errorSubscription(fmt.Errorf("ack: %w", err))
   101  	}
   102  
   103  	if ka.Type != connectionKaMsg {
   104  		return errorSubscription(fmt.Errorf("expected ack message, got %#v", ack))
   105  	}
   106  
   107  	if err = c.WriteJSON(operationMessage{Type: startMsg, ID: "1", Payload: requestBody}); err != nil {
   108  		return errorSubscription(fmt.Errorf("start: %w", err))
   109  	}
   110  
   111  	return &Subscription{
   112  		Close: func() error {
   113  			srv.Close()
   114  			return c.Close()
   115  		},
   116  		Next: func(response interface{}) error {
   117  			for {
   118  				var op operationMessage
   119  				err := c.ReadJSON(&op)
   120  				if err != nil {
   121  					return err
   122  				}
   123  
   124  				switch op.Type {
   125  				case dataMsg:
   126  					break
   127  				case connectionKaMsg:
   128  					continue
   129  				case errorMsg:
   130  					return fmt.Errorf(string(op.Payload))
   131  				default:
   132  					return fmt.Errorf("expected data message, got %#v", op)
   133  				}
   134  
   135  				var respDataRaw Response
   136  				err = json.Unmarshal(op.Payload, &respDataRaw)
   137  				if err != nil {
   138  					return fmt.Errorf("decode: %w", err)
   139  				}
   140  
   141  				// we want to unpack even if there is an error, so we can see partial responses
   142  				unpackErr := unpack(respDataRaw.Data, response, p.dc)
   143  
   144  				if respDataRaw.Errors != nil {
   145  					return RawJsonError{respDataRaw.Errors}
   146  				}
   147  				return unpackErr
   148  			}
   149  		},
   150  	}
   151  }