github.com/ezoic/ws@v1.0.4-0.20220713205711-5c1d69e074c5/dialer.go (about) 1 package ws 2 3 import ( 4 "bufio" 5 "bytes" 6 "context" 7 "crypto/tls" 8 "fmt" 9 "io" 10 "net" 11 "net/url" 12 "strconv" 13 "strings" 14 "time" 15 16 "github.com/ezoic/httphead" 17 "github.com/ezoic/pool/pbufio" 18 ) 19 20 // Constants used by Dialer. 21 const ( 22 DefaultClientReadBufferSize = 4096 23 DefaultClientWriteBufferSize = 4096 24 ) 25 26 // Handshake represents handshake result. 27 type Handshake struct { 28 // Protocol is the subprotocol selected during handshake. 29 Protocol string 30 31 // Extensions is the list of negotiated extensions. 32 Extensions []httphead.Option 33 } 34 35 // Errors used by the websocket client. 36 var ( 37 ErrHandshakeBadStatus = fmt.Errorf("unexpected http status") 38 ErrHandshakeBadSubProtocol = fmt.Errorf("unexpected protocol in %q header", headerSecProtocol) 39 ErrHandshakeBadExtensions = fmt.Errorf("unexpected extensions in %q header", headerSecProtocol) 40 ) 41 42 // DefaultDialer is dialer that holds no options and is used by Dial function. 43 var DefaultDialer Dialer 44 45 // Dial is like Dialer{}.Dial(). 46 func Dial(ctx context.Context, urlstr string) (net.Conn, *bufio.Reader, Handshake, error) { 47 return DefaultDialer.Dial(ctx, urlstr) 48 } 49 50 // Dialer contains options for establishing websocket connection to an url. 51 type Dialer struct { 52 // ReadBufferSize and WriteBufferSize is an I/O buffer sizes. 53 // They used to read and write http data while upgrading to WebSocket. 54 // Allocated buffers are pooled with sync.Pool to avoid extra allocations. 55 // 56 // If a size is zero then default value is used. 57 ReadBufferSize, WriteBufferSize int 58 59 // Timeout is the maximum amount of time a Dial() will wait for a connect 60 // and an handshake to complete. 61 // 62 // The default is no timeout. 63 Timeout time.Duration 64 65 // Protocols is the list of subprotocols that the client wants to speak, 66 // ordered by preference. 67 // 68 // See https://tools.ietf.org/html/rfc6455#section-4.1 69 Protocols []string 70 71 // Extensions is the list of extensions that client wants to speak. 72 // 73 // Note that if server decides to use some of this extensions, Dial() will 74 // return Handshake struct containing a slice of items, which are the 75 // shallow copies of the items from this list. That is, internals of 76 // Extensions items are shared during Dial(). 77 // 78 // See https://tools.ietf.org/html/rfc6455#section-4.1 79 // See https://tools.ietf.org/html/rfc6455#section-9.1 80 Extensions []httphead.Option 81 82 // Header is an optional HandshakeHeader instance that could be used to 83 // write additional headers to the handshake request. 84 // 85 // It used instead of any key-value mappings to avoid allocations in user 86 // land. 87 Header HandshakeHeader 88 89 // OnStatusError is the callback that will be called after receiving non 90 // "101 Continue" HTTP response status. It receives an io.Reader object 91 // representing server response bytes. That is, it gives ability to parse 92 // HTTP response somehow (probably with http.ReadResponse call) and make a 93 // decision of further logic. 94 // 95 // The arguments are only valid until the callback returns. 96 OnStatusError func(status int, reason []byte, resp io.Reader) 97 98 // OnHeader is the callback that will be called after successful parsing of 99 // header, that is not used during WebSocket handshake procedure. That is, 100 // it will be called with non-websocket headers, which could be relevant 101 // for application-level logic. 102 // 103 // The arguments are only valid until the callback returns. 104 // 105 // Returned value could be used to prevent processing response. 106 OnHeader func(key, value []byte) (err error) 107 108 // NetDial is the function that is used to get plain tcp connection. 109 // If it is not nil, then it is used instead of net.Dialer. 110 NetDial func(ctx context.Context, network, addr string) (net.Conn, error) 111 112 // TLSClient is the callback that will be called after successful dial with 113 // received connection and its remote host name. If it is nil, then the 114 // default tls.Client() will be used. 115 // If it is not nil, then TLSConfig field is ignored. 116 TLSClient func(conn net.Conn, hostname string) net.Conn 117 118 // TLSConfig is passed to tls.Client() to start TLS over established 119 // connection. If TLSClient is not nil, then it is ignored. If TLSConfig is 120 // non-nil and its ServerName is empty, then for every Dial() it will be 121 // cloned and appropriate ServerName will be set. 122 TLSConfig *tls.Config 123 124 // WrapConn is the optional callback that will be called when connection is 125 // ready for an i/o. That is, it will be called after successful dial and 126 // TLS initialization (for "wss" schemes). It may be helpful for different 127 // user land purposes such as end to end encryption. 128 // 129 // Note that for debugging purposes of an http handshake (e.g. sent request 130 // and received response), there is an wsutil.DebugDialer struct. 131 WrapConn func(conn net.Conn) net.Conn 132 } 133 134 // Dial connects to the url host and upgrades connection to WebSocket. 135 // 136 // If server has sent frames right after successful handshake then returned 137 // buffer will be non-nil. In other cases buffer is always nil. For better 138 // memory efficiency received non-nil bufio.Reader should be returned to the 139 // inner pool with PutReader() function after use. 140 // 141 // Note that Dialer does not implement IDNA (RFC5895) logic as net/http does. 142 // If you want to dial non-ascii host name, take care of its name serialization 143 // avoiding bad request issues. For more info see net/http Request.Write() 144 // implementation, especially cleanHost() function. 145 func (d Dialer) Dial(ctx context.Context, urlstr string) (conn net.Conn, br *bufio.Reader, hs Handshake, err error) { 146 u, err := url.ParseRequestURI(urlstr) 147 if err != nil { 148 return 149 } 150 151 // Prepare context to dial with. Initially it is the same as original, but 152 // if d.Timeout is non-zero and points to time that is before ctx.Deadline, 153 // we use more shorter context for dial. 154 dialctx := ctx 155 156 var deadline time.Time 157 if t := d.Timeout; t != 0 { 158 deadline = time.Now().Add(t) 159 if d, ok := ctx.Deadline(); !ok || deadline.Before(d) { 160 var cancel context.CancelFunc 161 dialctx, cancel = context.WithDeadline(ctx, deadline) 162 defer cancel() 163 } 164 } 165 if conn, err = d.dial(dialctx, u); err != nil { 166 return 167 } 168 defer func() { 169 if err != nil { 170 conn.Close() 171 } 172 }() 173 if ctx == context.Background() { 174 // No need to start I/O interrupter goroutine which is not zero-cost. 175 conn.SetDeadline(deadline) 176 defer conn.SetDeadline(noDeadline) 177 } else { 178 // Context could be canceled or its deadline could be exceeded. 179 // Start the interrupter goroutine to handle context cancelation. 180 done := setupContextDeadliner(ctx, conn) 181 defer func() { 182 // Map Upgrade() error to a possible context expiration error. That 183 // is, even if Upgrade() err is nil, context could be already 184 // expired and connection be "poisoned" by SetDeadline() call. 185 // In that case we must not return ctx.Err() error. 186 done(&err) 187 }() 188 } 189 190 br, hs, err = d.Upgrade(conn, u) 191 192 return 193 } 194 195 var ( 196 // netEmptyDialer is a net.Dialer without options, used in Dialer.dial() if 197 // Dialer.NetDial is not provided. 198 netEmptyDialer net.Dialer 199 // tlsEmptyConfig is an empty tls.Config used as default one. 200 tlsEmptyConfig tls.Config 201 ) 202 203 func tlsDefaultConfig() *tls.Config { 204 return &tlsEmptyConfig 205 } 206 207 func hostport(host string, defaultPort string) (hostname, addr string) { 208 var ( 209 colon = strings.LastIndexByte(host, ':') 210 bracket = strings.IndexByte(host, ']') 211 ) 212 if colon > bracket { 213 return host[:colon], host 214 } 215 return host, host + defaultPort 216 } 217 218 func (d Dialer) dial(ctx context.Context, u *url.URL) (conn net.Conn, err error) { 219 dial := d.NetDial 220 if dial == nil { 221 dial = netEmptyDialer.DialContext 222 } 223 switch u.Scheme { 224 case "ws": 225 _, addr := hostport(u.Host, ":80") 226 conn, err = dial(ctx, "tcp", addr) 227 case "wss": 228 hostname, addr := hostport(u.Host, ":443") 229 conn, err = dial(ctx, "tcp", addr) 230 if err != nil { 231 return 232 } 233 tlsClient := d.TLSClient 234 if tlsClient == nil { 235 tlsClient = d.tlsClient 236 } 237 conn = tlsClient(conn, hostname) 238 default: 239 return nil, fmt.Errorf("unexpected websocket scheme: %q", u.Scheme) 240 } 241 if wrap := d.WrapConn; wrap != nil { 242 conn = wrap(conn) 243 } 244 return 245 } 246 247 func (d Dialer) tlsClient(conn net.Conn, hostname string) net.Conn { 248 config := d.TLSConfig 249 if config == nil { 250 config = tlsDefaultConfig() 251 } 252 if config.ServerName == "" { 253 config = tlsCloneConfig(config) 254 config.ServerName = hostname 255 } 256 // Do not make conn.Handshake() here because downstairs we will prepare 257 // i/o on this conn with proper context's timeout handling. 258 return tls.Client(conn, config) 259 } 260 261 var ( 262 // This variables are set like in net/net.go. 263 // noDeadline is just zero value for readability. 264 noDeadline = time.Time{} 265 // aLongTimeAgo is a non-zero time, far in the past, used for immediate 266 // cancelation of dials. 267 aLongTimeAgo = time.Unix(42, 0) 268 ) 269 270 // Upgrade writes an upgrade request to the given io.ReadWriter conn at given 271 // url u and reads a response from it. 272 // 273 // It is a caller responsibility to manage I/O deadlines on conn. 274 // 275 // It returns handshake info and some bytes which could be written by the peer 276 // right after response and be caught by us during buffered read. 277 func (d Dialer) Upgrade(conn io.ReadWriter, u *url.URL) (br *bufio.Reader, hs Handshake, err error) { 278 // headerSeen constants helps to report whether or not some header was seen 279 // during reading request bytes. 280 const ( 281 headerSeenUpgrade = 1 << iota 282 headerSeenConnection 283 headerSeenSecAccept 284 285 // headerSeenAll is the value that we expect to receive at the end of 286 // headers read/parse loop. 287 headerSeenAll = 0 | 288 headerSeenUpgrade | 289 headerSeenConnection | 290 headerSeenSecAccept 291 ) 292 293 br = pbufio.GetReader(conn, 294 nonZero(d.ReadBufferSize, DefaultClientReadBufferSize), 295 ) 296 bw := pbufio.GetWriter(conn, 297 nonZero(d.WriteBufferSize, DefaultClientWriteBufferSize), 298 ) 299 defer func() { 300 pbufio.PutWriter(bw) 301 if br.Buffered() == 0 || err != nil { 302 // Server does not wrote additional bytes to the connection or 303 // error occurred. That is, no reason to return buffer. 304 pbufio.PutReader(br) 305 br = nil 306 } 307 }() 308 309 nonce := make([]byte, nonceSize) 310 initNonce(nonce) 311 312 httpWriteUpgradeRequest(bw, u, nonce, d.Protocols, d.Extensions, d.Header) 313 if err = bw.Flush(); err != nil { 314 return 315 } 316 317 // Read HTTP status line like "HTTP/1.1 101 Switching Protocols". 318 sl, err := readLine(br) 319 if err != nil { 320 return 321 } 322 // Begin validation of the response. 323 // See https://tools.ietf.org/html/rfc6455#section-4.2.2 324 // Parse request line data like HTTP version, uri and method. 325 resp, err := httpParseResponseLine(sl) 326 if err != nil { 327 return 328 } 329 // Even if RFC says "1.1 or higher" without mentioning the part of the 330 // version, we apply it only to minor part. 331 if resp.major != 1 || resp.minor < 1 { 332 err = ErrHandshakeBadProtocol 333 return 334 } 335 if resp.status != 101 { 336 err = StatusError(resp.status) 337 if onStatusError := d.OnStatusError; onStatusError != nil { 338 // Invoke callback with multireader of status-line bytes br. 339 onStatusError(resp.status, resp.reason, 340 io.MultiReader( 341 bytes.NewReader(sl), 342 strings.NewReader(crlf), 343 br, 344 ), 345 ) 346 } 347 return 348 } 349 // If response status is 101 then we expect all technical headers to be 350 // valid. If not, then we stop processing response without giving user 351 // ability to read non-technical headers. That is, we do not distinguish 352 // technical errors (such as parsing error) and protocol errors. 353 var headerSeen byte 354 for { 355 line, e := readLine(br) 356 if e != nil { 357 err = e 358 return 359 } 360 if len(line) == 0 { 361 // Blank line, no more lines to read. 362 break 363 } 364 365 k, v, ok := httpParseHeaderLine(line) 366 if !ok { 367 err = ErrMalformedResponse 368 return 369 } 370 371 switch btsToString(k) { 372 case headerUpgradeCanonical: 373 headerSeen |= headerSeenUpgrade 374 if !bytes.Equal(v, specHeaderValueUpgrade) && !bytes.EqualFold(v, specHeaderValueUpgrade) { 375 err = ErrHandshakeBadUpgrade 376 return 377 } 378 379 case headerConnectionCanonical: 380 headerSeen |= headerSeenConnection 381 // Note that as RFC6455 says: 382 // > A |Connection| header field with value "Upgrade". 383 // That is, in server side, "Connection" header could contain 384 // multiple token. But in response it must contains exactly one. 385 if !bytes.Equal(v, specHeaderValueConnection) && !bytes.EqualFold(v, specHeaderValueConnection) { 386 err = ErrHandshakeBadConnection 387 return 388 } 389 390 case headerSecAcceptCanonical: 391 headerSeen |= headerSeenSecAccept 392 if !checkAcceptFromNonce(v, nonce) { 393 err = ErrHandshakeBadSecAccept 394 return 395 } 396 397 case headerSecProtocolCanonical: 398 // RFC6455 1.3: 399 // "The server selects one or none of the acceptable protocols 400 // and echoes that value in its handshake to indicate that it has 401 // selected that protocol." 402 for _, want := range d.Protocols { 403 if string(v) == want { 404 hs.Protocol = want 405 break 406 } 407 } 408 if hs.Protocol == "" { 409 // Server echoed subprotocol that is not present in client 410 // requested protocols. 411 err = ErrHandshakeBadSubProtocol 412 return 413 } 414 415 case headerSecExtensionsCanonical: 416 hs.Extensions, err = matchSelectedExtensions(v, d.Extensions, hs.Extensions) 417 if err != nil { 418 return 419 } 420 421 default: 422 if onHeader := d.OnHeader; onHeader != nil { 423 if e := onHeader(k, v); e != nil { 424 err = e 425 return 426 } 427 } 428 } 429 } 430 if err == nil && headerSeen != headerSeenAll { 431 switch { 432 case headerSeen&headerSeenUpgrade == 0: 433 err = ErrHandshakeBadUpgrade 434 case headerSeen&headerSeenConnection == 0: 435 err = ErrHandshakeBadConnection 436 case headerSeen&headerSeenSecAccept == 0: 437 err = ErrHandshakeBadSecAccept 438 default: 439 panic("unknown headers state") 440 } 441 } 442 return 443 } 444 445 // PutReader returns bufio.Reader instance to the inner reuse pool. 446 // It is useful in rare cases, when Dialer.Dial() returns non-nil buffer which 447 // contains unprocessed buffered data, that was sent by the server quickly 448 // right after handshake. 449 func PutReader(br *bufio.Reader) { 450 pbufio.PutReader(br) 451 } 452 453 // StatusError contains an unexpected status-line code from the server. 454 type StatusError int 455 456 func (s StatusError) Error() string { 457 return "unexpected HTTP response status: " + strconv.Itoa(int(s)) 458 } 459 460 func isTimeoutError(err error) bool { 461 t, ok := err.(net.Error) 462 return ok && t.Timeout() 463 } 464 465 func matchSelectedExtensions(selected []byte, wanted, received []httphead.Option) ([]httphead.Option, error) { 466 if len(selected) == 0 { 467 return received, nil 468 } 469 var ( 470 index int 471 option httphead.Option 472 err error 473 ) 474 index = -1 475 match := func() (ok bool) { 476 for _, want := range wanted { 477 if option.Equal(want) { 478 // Check parsed extension to be present in client 479 // requested extensions. We move matched extension 480 // from client list to avoid allocation. 481 received = append(received, want) 482 return true 483 } 484 } 485 return false 486 } 487 ok := httphead.ScanOptions(selected, func(i int, name, attr, val []byte) httphead.Control { 488 if i != index { 489 // Met next option. 490 index = i 491 if i != 0 && !match() { 492 // Server returned non-requested extension. 493 err = ErrHandshakeBadExtensions 494 return httphead.ControlBreak 495 } 496 option = httphead.Option{Name: name} 497 } 498 if attr != nil { 499 option.Parameters.Set(attr, val) 500 } 501 return httphead.ControlContinue 502 }) 503 if !ok { 504 err = ErrMalformedResponse 505 return received, err 506 } 507 if !match() { 508 return received, ErrHandshakeBadExtensions 509 } 510 return received, err 511 } 512 513 // setupContextDeadliner is a helper function that starts connection I/O 514 // interrupter goroutine. 515 // 516 // Started goroutine calls SetDeadline() with long time ago value when context 517 // become expired to make any I/O operations failed. It returns done function 518 // that stops started goroutine and maps error received from conn I/O methods 519 // to possible context expiration error. 520 // 521 // In concern with possible SetDeadline() call inside interrupter goroutine, 522 // caller passes pointer to its I/O error (even if it is nil) to done(&err). 523 // That is, even if I/O error is nil, context could be already expired and 524 // connection "poisoned" by SetDeadline() call. In that case done(&err) will 525 // store at *err ctx.Err() result. If err is caused not by timeout, it will 526 // leaved untouched. 527 func setupContextDeadliner(ctx context.Context, conn net.Conn) (done func(*error)) { 528 var ( 529 quit = make(chan struct{}) 530 interrupt = make(chan error, 1) 531 ) 532 go func() { 533 select { 534 case <-quit: 535 interrupt <- nil 536 case <-ctx.Done(): 537 // Cancel i/o immediately. 538 conn.SetDeadline(aLongTimeAgo) 539 interrupt <- ctx.Err() 540 } 541 }() 542 return func(err *error) { 543 close(quit) 544 // If ctx.Err() is non-nil and the original err is net.Error with 545 // Timeout() == true, then it means that I/O was canceled by us by 546 // SetDeadline(aLongTimeAgo) call, or by somebody else previously 547 // by conn.SetDeadline(x). 548 // 549 // Even on race condition when both deadlines are expired 550 // (SetDeadline() made not by us and context's), we prefer ctx.Err() to 551 // be returned. 552 if ctxErr := <-interrupt; ctxErr != nil && (*err == nil || isTimeoutError(*err)) { 553 *err = ctxErr 554 } 555 } 556 }