github.com/philippseith/signalr@v0.6.3/httpconnection.go (about)

     1  package signalr
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io"
     8  	"net/http"
     9  	"net/url"
    10  	"path"
    11  
    12  	"nhooyr.io/websocket"
    13  )
    14  
    15  // Doer is the *http.Client interface
    16  type Doer interface {
    17  	Do(req *http.Request) (*http.Response, error)
    18  }
    19  
    20  type httpConnection struct {
    21  	client     Doer
    22  	headers    func() http.Header
    23  	transports []TransportType
    24  }
    25  
    26  // WithHTTPClient sets the http client used to connect to the signalR server.
    27  // The client is only used for http requests. It is not used for the websocket connection.
    28  func WithHTTPClient(client Doer) func(*httpConnection) error {
    29  	return func(c *httpConnection) error {
    30  		c.client = client
    31  		return nil
    32  	}
    33  }
    34  
    35  // WithHTTPHeaders sets the function for providing request headers for HTTP and websocket requests
    36  func WithHTTPHeaders(headers func() http.Header) func(*httpConnection) error {
    37  	return func(c *httpConnection) error {
    38  		c.headers = headers
    39  		return nil
    40  	}
    41  }
    42  
    43  func WithTransports(transports ...TransportType) func(*httpConnection) error {
    44  	return func(c *httpConnection) error {
    45  		for _, transport := range transports {
    46  			switch transport {
    47  			case TransportWebSockets, TransportServerSentEvents:
    48  				// Supported
    49  			default:
    50  				return fmt.Errorf("unsupported transport %s", transport)
    51  			}
    52  		}
    53  		c.transports = transports
    54  		return nil
    55  	}
    56  }
    57  
    58  // NewHTTPConnection creates a signalR HTTP Connection for usage with a Client.
    59  // ctx can be used to cancel the SignalR negotiation during the creation of the Connection
    60  // but not the Connection itself.
    61  func NewHTTPConnection(ctx context.Context, address string, options ...func(*httpConnection) error) (Connection, error) {
    62  	httpConn := &httpConnection{}
    63  
    64  	for _, option := range options {
    65  		if option != nil {
    66  			if err := option(httpConn); err != nil {
    67  				return nil, err
    68  			}
    69  		}
    70  	}
    71  
    72  	if httpConn.client == nil {
    73  		httpConn.client = http.DefaultClient
    74  	}
    75  	if len(httpConn.transports) == 0 {
    76  		httpConn.transports = []TransportType{TransportWebSockets, TransportServerSentEvents}
    77  	}
    78  
    79  	reqURL, err := url.Parse(address)
    80  	if err != nil {
    81  		return nil, err
    82  	}
    83  
    84  	negotiateURL := *reqURL
    85  	negotiateURL.Path = path.Join(negotiateURL.Path, "negotiate")
    86  	req, err := http.NewRequestWithContext(ctx, "POST", negotiateURL.String(), nil)
    87  	if err != nil {
    88  		return nil, err
    89  	}
    90  
    91  	if httpConn.headers != nil {
    92  		req.Header = httpConn.headers()
    93  	}
    94  
    95  	resp, err := httpConn.client.Do(req)
    96  	if err != nil {
    97  		return nil, err
    98  	}
    99  	defer func() { closeResponseBody(resp.Body) }()
   100  
   101  	if resp.StatusCode != 200 {
   102  		return nil, fmt.Errorf("%v %v -> %v", req.Method, req.URL.String(), resp.Status)
   103  	}
   104  
   105  	body, err := io.ReadAll(resp.Body)
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  
   110  	negotiateResponse := negotiateResponse{}
   111  	if err := json.Unmarshal(body, &negotiateResponse); err != nil {
   112  		return nil, err
   113  	}
   114  
   115  	q := reqURL.Query()
   116  	q.Set("id", negotiateResponse.ConnectionID)
   117  	reqURL.RawQuery = q.Encode()
   118  
   119  	// Select the best connection
   120  	var conn Connection
   121  	switch {
   122  	case negotiateResponse.hasTransport("WebTransports"):
   123  		// TODO
   124  
   125  	case httpConn.hasTransport(TransportWebSockets) && negotiateResponse.hasTransport(TransportWebSockets):
   126  		wsURL := reqURL
   127  
   128  		// switch to wss for secure connection
   129  		if reqURL.Scheme == "https" {
   130  			wsURL.Scheme = "wss"
   131  		} else {
   132  			wsURL.Scheme = "ws"
   133  		}
   134  
   135  		opts := &websocket.DialOptions{}
   136  
   137  		if httpConn.headers != nil {
   138  			opts.HTTPHeader = httpConn.headers()
   139  		} else {
   140  			opts.HTTPHeader = http.Header{}
   141  		}
   142  
   143  		for _, cookie := range resp.Cookies() {
   144  			opts.HTTPHeader.Add("Cookie", cookie.String())
   145  		}
   146  
   147  		ws, _, err := websocket.Dial(ctx, wsURL.String(), opts)
   148  		if err != nil {
   149  			return nil, err
   150  		}
   151  
   152  		// TODO think about if the API should give the possibility to cancel this connection
   153  		conn = newWebSocketConnection(context.Background(), negotiateResponse.ConnectionID, ws)
   154  
   155  	case httpConn.hasTransport(TransportServerSentEvents) && negotiateResponse.hasTransport(TransportServerSentEvents):
   156  		req, err := http.NewRequest("GET", reqURL.String(), nil)
   157  		if err != nil {
   158  			return nil, err
   159  		}
   160  
   161  		if httpConn.headers != nil {
   162  			req.Header = httpConn.headers()
   163  		}
   164  		req.Header.Set("Accept", "text/event-stream")
   165  
   166  		resp, err := httpConn.client.Do(req)
   167  		if err != nil {
   168  			return nil, err
   169  		}
   170  
   171  		conn, err = newClientSSEConnection(address, negotiateResponse.ConnectionID, resp.Body)
   172  		if err != nil {
   173  			return nil, err
   174  		}
   175  	}
   176  
   177  	return conn, nil
   178  }
   179  
   180  // closeResponseBody reads a http response body to the end and closes it
   181  // See https://blog.cubieserver.de/2022/http-connection-reuse-in-go-clients/
   182  // The body needs to be fully read and closed, otherwise the connection will not be reused
   183  func closeResponseBody(body io.ReadCloser) {
   184  	_, _ = io.Copy(io.Discard, body)
   185  	_ = body.Close()
   186  }
   187  
   188  func (h *httpConnection) hasTransport(transport TransportType) bool {
   189  	for _, t := range h.transports {
   190  		if transport == t {
   191  			return true
   192  		}
   193  	}
   194  	return false
   195  }