golang.org/x/net@v0.25.1-0.20240516223405-c87a5b62e243/websocket/client.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package websocket
     6  
     7  import (
     8  	"bufio"
     9  	"context"
    10  	"io"
    11  	"net"
    12  	"net/http"
    13  	"net/url"
    14  	"time"
    15  )
    16  
    17  // DialError is an error that occurs while dialling a websocket server.
    18  type DialError struct {
    19  	*Config
    20  	Err error
    21  }
    22  
    23  func (e *DialError) Error() string {
    24  	return "websocket.Dial " + e.Config.Location.String() + ": " + e.Err.Error()
    25  }
    26  
    27  // NewConfig creates a new WebSocket config for client connection.
    28  func NewConfig(server, origin string) (config *Config, err error) {
    29  	config = new(Config)
    30  	config.Version = ProtocolVersionHybi13
    31  	config.Location, err = url.ParseRequestURI(server)
    32  	if err != nil {
    33  		return
    34  	}
    35  	config.Origin, err = url.ParseRequestURI(origin)
    36  	if err != nil {
    37  		return
    38  	}
    39  	config.Header = http.Header(make(map[string][]string))
    40  	return
    41  }
    42  
    43  // NewClient creates a new WebSocket client connection over rwc.
    44  func NewClient(config *Config, rwc io.ReadWriteCloser) (ws *Conn, err error) {
    45  	br := bufio.NewReader(rwc)
    46  	bw := bufio.NewWriter(rwc)
    47  	err = hybiClientHandshake(config, br, bw)
    48  	if err != nil {
    49  		return
    50  	}
    51  	buf := bufio.NewReadWriter(br, bw)
    52  	ws = newHybiClientConn(config, buf, rwc)
    53  	return
    54  }
    55  
    56  // Dial opens a new client connection to a WebSocket.
    57  func Dial(url_, protocol, origin string) (ws *Conn, err error) {
    58  	config, err := NewConfig(url_, origin)
    59  	if err != nil {
    60  		return nil, err
    61  	}
    62  	if protocol != "" {
    63  		config.Protocol = []string{protocol}
    64  	}
    65  	return DialConfig(config)
    66  }
    67  
    68  var portMap = map[string]string{
    69  	"ws":  "80",
    70  	"wss": "443",
    71  }
    72  
    73  func parseAuthority(location *url.URL) string {
    74  	if _, ok := portMap[location.Scheme]; ok {
    75  		if _, _, err := net.SplitHostPort(location.Host); err != nil {
    76  			return net.JoinHostPort(location.Host, portMap[location.Scheme])
    77  		}
    78  	}
    79  	return location.Host
    80  }
    81  
    82  // DialConfig opens a new client connection to a WebSocket with a config.
    83  func DialConfig(config *Config) (ws *Conn, err error) {
    84  	return config.DialContext(context.Background())
    85  }
    86  
    87  // DialContext opens a new client connection to a WebSocket, with context support for timeouts/cancellation.
    88  func (config *Config) DialContext(ctx context.Context) (*Conn, error) {
    89  	if config.Location == nil {
    90  		return nil, &DialError{config, ErrBadWebSocketLocation}
    91  	}
    92  	if config.Origin == nil {
    93  		return nil, &DialError{config, ErrBadWebSocketOrigin}
    94  	}
    95  
    96  	dialer := config.Dialer
    97  	if dialer == nil {
    98  		dialer = &net.Dialer{}
    99  	}
   100  
   101  	client, err := dialWithDialer(ctx, dialer, config)
   102  	if err != nil {
   103  		return nil, &DialError{config, err}
   104  	}
   105  
   106  	// Cleanup the connection if we fail to create the websocket successfully
   107  	success := false
   108  	defer func() {
   109  		if !success {
   110  			_ = client.Close()
   111  		}
   112  	}()
   113  
   114  	var ws *Conn
   115  	var wsErr error
   116  	doneConnecting := make(chan struct{})
   117  	go func() {
   118  		defer close(doneConnecting)
   119  		ws, err = NewClient(config, client)
   120  		if err != nil {
   121  			wsErr = &DialError{config, err}
   122  		}
   123  	}()
   124  
   125  	// The websocket.NewClient() function can block indefinitely, make sure that we
   126  	// respect the deadlines specified by the context.
   127  	select {
   128  	case <-ctx.Done():
   129  		// Force the pending operations to fail, terminating the pending connection attempt
   130  		_ = client.SetDeadline(time.Now())
   131  		<-doneConnecting // Wait for the goroutine that tries to establish the connection to finish
   132  		return nil, &DialError{config, ctx.Err()}
   133  	case <-doneConnecting:
   134  		if wsErr == nil {
   135  			success = true // Disarm the deferred connection cleanup
   136  		}
   137  		return ws, wsErr
   138  	}
   139  }