github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/net/websocket/hybi.go (about) 1 // Copyright 2011 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package websocket 6 7 // This file implements a protocol of hybi draft. 8 // http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17 9 10 import ( 11 "bufio" 12 "bytes" 13 "crypto/rand" 14 "crypto/sha1" 15 "encoding/base64" 16 "encoding/binary" 17 "fmt" 18 "io" 19 "io/ioutil" 20 "net/url" 21 "strings" 22 23 http "github.com/hxx258456/ccgo/gmhttp" 24 ) 25 26 const ( 27 websocketGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" 28 29 closeStatusNormal = 1000 30 closeStatusGoingAway = 1001 31 closeStatusProtocolError = 1002 32 closeStatusUnsupportedData = 1003 33 closeStatusFrameTooLarge = 1004 34 closeStatusNoStatusRcvd = 1005 35 closeStatusAbnormalClosure = 1006 36 closeStatusBadMessageData = 1007 37 closeStatusPolicyViolation = 1008 38 closeStatusTooBigData = 1009 39 closeStatusExtensionMismatch = 1010 40 41 maxControlFramePayloadLength = 125 42 ) 43 44 var ( 45 ErrBadMaskingKey = &ProtocolError{"bad masking key"} 46 ErrBadPongMessage = &ProtocolError{"bad pong message"} 47 ErrBadClosingStatus = &ProtocolError{"bad closing status"} 48 ErrUnsupportedExtensions = &ProtocolError{"unsupported extensions"} 49 ErrNotImplemented = &ProtocolError{"not implemented"} 50 51 handshakeHeader = map[string]bool{ 52 "Host": true, 53 "Upgrade": true, 54 "Connection": true, 55 "Sec-Websocket-Key": true, 56 "Sec-Websocket-Origin": true, 57 "Sec-Websocket-Version": true, 58 "Sec-Websocket-Protocol": true, 59 "Sec-Websocket-Accept": true, 60 } 61 ) 62 63 // A hybiFrameHeader is a frame header as defined in hybi draft. 64 type hybiFrameHeader struct { 65 Fin bool 66 Rsv [3]bool 67 OpCode byte 68 Length int64 69 MaskingKey []byte 70 71 data *bytes.Buffer 72 } 73 74 // A hybiFrameReader is a reader for hybi frame. 75 type hybiFrameReader struct { 76 reader io.Reader 77 78 header hybiFrameHeader 79 pos int64 80 length int 81 } 82 83 func (frame *hybiFrameReader) Read(msg []byte) (n int, err error) { 84 n, err = frame.reader.Read(msg) 85 if frame.header.MaskingKey != nil { 86 for i := 0; i < n; i++ { 87 msg[i] = msg[i] ^ frame.header.MaskingKey[frame.pos%4] 88 frame.pos++ 89 } 90 } 91 return n, err 92 } 93 94 func (frame *hybiFrameReader) PayloadType() byte { return frame.header.OpCode } 95 96 func (frame *hybiFrameReader) HeaderReader() io.Reader { 97 if frame.header.data == nil { 98 return nil 99 } 100 if frame.header.data.Len() == 0 { 101 return nil 102 } 103 return frame.header.data 104 } 105 106 func (frame *hybiFrameReader) TrailerReader() io.Reader { return nil } 107 108 func (frame *hybiFrameReader) Len() (n int) { return frame.length } 109 110 // A hybiFrameReaderFactory creates new frame reader based on its frame type. 111 type hybiFrameReaderFactory struct { 112 *bufio.Reader 113 } 114 115 // NewFrameReader reads a frame header from the connection, and creates new reader for the frame. 116 // See Section 5.2 Base Framing protocol for detail. 117 // http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17#section-5.2 118 func (buf hybiFrameReaderFactory) NewFrameReader() (frame frameReader, err error) { 119 hybiFrame := new(hybiFrameReader) 120 frame = hybiFrame 121 var header []byte 122 var b byte 123 // First byte. FIN/RSV1/RSV2/RSV3/OpCode(4bits) 124 b, err = buf.ReadByte() 125 if err != nil { 126 return 127 } 128 header = append(header, b) 129 hybiFrame.header.Fin = ((header[0] >> 7) & 1) != 0 130 for i := 0; i < 3; i++ { 131 j := uint(6 - i) 132 hybiFrame.header.Rsv[i] = ((header[0] >> j) & 1) != 0 133 } 134 hybiFrame.header.OpCode = header[0] & 0x0f 135 136 // Second byte. Mask/Payload len(7bits) 137 b, err = buf.ReadByte() 138 if err != nil { 139 return 140 } 141 header = append(header, b) 142 mask := (b & 0x80) != 0 143 b &= 0x7f 144 lengthFields := 0 145 switch { 146 case b <= 125: // Payload length 7bits. 147 hybiFrame.header.Length = int64(b) 148 case b == 126: // Payload length 7+16bits 149 lengthFields = 2 150 case b == 127: // Payload length 7+64bits 151 lengthFields = 8 152 } 153 for i := 0; i < lengthFields; i++ { 154 b, err = buf.ReadByte() 155 if err != nil { 156 return 157 } 158 if lengthFields == 8 && i == 0 { // MSB must be zero when 7+64 bits 159 b &= 0x7f 160 } 161 header = append(header, b) 162 hybiFrame.header.Length = hybiFrame.header.Length*256 + int64(b) 163 } 164 if mask { 165 // Masking key. 4 bytes. 166 for i := 0; i < 4; i++ { 167 b, err = buf.ReadByte() 168 if err != nil { 169 return 170 } 171 header = append(header, b) 172 hybiFrame.header.MaskingKey = append(hybiFrame.header.MaskingKey, b) 173 } 174 } 175 hybiFrame.reader = io.LimitReader(buf.Reader, hybiFrame.header.Length) 176 hybiFrame.header.data = bytes.NewBuffer(header) 177 hybiFrame.length = len(header) + int(hybiFrame.header.Length) 178 return 179 } 180 181 // A HybiFrameWriter is a writer for hybi frame. 182 type hybiFrameWriter struct { 183 writer *bufio.Writer 184 185 header *hybiFrameHeader 186 } 187 188 func (frame *hybiFrameWriter) Write(msg []byte) (n int, err error) { 189 var header []byte 190 var b byte 191 if frame.header.Fin { 192 b |= 0x80 193 } 194 for i := 0; i < 3; i++ { 195 if frame.header.Rsv[i] { 196 j := uint(6 - i) 197 b |= 1 << j 198 } 199 } 200 b |= frame.header.OpCode 201 header = append(header, b) 202 if frame.header.MaskingKey != nil { 203 b = 0x80 204 } else { 205 b = 0 206 } 207 lengthFields := 0 208 length := len(msg) 209 switch { 210 case length <= 125: 211 b |= byte(length) 212 case length < 65536: 213 b |= 126 214 lengthFields = 2 215 default: 216 b |= 127 217 lengthFields = 8 218 } 219 header = append(header, b) 220 for i := 0; i < lengthFields; i++ { 221 j := uint((lengthFields - i - 1) * 8) 222 b = byte((length >> j) & 0xff) 223 header = append(header, b) 224 } 225 if frame.header.MaskingKey != nil { 226 if len(frame.header.MaskingKey) != 4 { 227 return 0, ErrBadMaskingKey 228 } 229 header = append(header, frame.header.MaskingKey...) 230 frame.writer.Write(header) 231 data := make([]byte, length) 232 for i := range data { 233 data[i] = msg[i] ^ frame.header.MaskingKey[i%4] 234 } 235 frame.writer.Write(data) 236 err = frame.writer.Flush() 237 return length, err 238 } 239 frame.writer.Write(header) 240 frame.writer.Write(msg) 241 err = frame.writer.Flush() 242 return length, err 243 } 244 245 func (frame *hybiFrameWriter) Close() error { return nil } 246 247 type hybiFrameWriterFactory struct { 248 *bufio.Writer 249 needMaskingKey bool 250 } 251 252 func (buf hybiFrameWriterFactory) NewFrameWriter(payloadType byte) (frame frameWriter, err error) { 253 frameHeader := &hybiFrameHeader{Fin: true, OpCode: payloadType} 254 if buf.needMaskingKey { 255 frameHeader.MaskingKey, err = generateMaskingKey() 256 if err != nil { 257 return nil, err 258 } 259 } 260 return &hybiFrameWriter{writer: buf.Writer, header: frameHeader}, nil 261 } 262 263 type hybiFrameHandler struct { 264 conn *Conn 265 payloadType byte 266 } 267 268 func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (frameReader, error) { 269 if handler.conn.IsServerConn() { 270 // The client MUST mask all frames sent to the server. 271 if frame.(*hybiFrameReader).header.MaskingKey == nil { 272 handler.WriteClose(closeStatusProtocolError) 273 return nil, io.EOF 274 } 275 } else { 276 // The server MUST NOT mask all frames. 277 if frame.(*hybiFrameReader).header.MaskingKey != nil { 278 handler.WriteClose(closeStatusProtocolError) 279 return nil, io.EOF 280 } 281 } 282 if header := frame.HeaderReader(); header != nil { 283 io.Copy(ioutil.Discard, header) 284 } 285 switch frame.PayloadType() { 286 case ContinuationFrame: 287 frame.(*hybiFrameReader).header.OpCode = handler.payloadType 288 case TextFrame, BinaryFrame: 289 handler.payloadType = frame.PayloadType() 290 case CloseFrame: 291 return nil, io.EOF 292 case PingFrame, PongFrame: 293 b := make([]byte, maxControlFramePayloadLength) 294 n, err := io.ReadFull(frame, b) 295 if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { 296 return nil, err 297 } 298 io.Copy(ioutil.Discard, frame) 299 if frame.PayloadType() == PingFrame { 300 if _, err := handler.WritePong(b[:n]); err != nil { 301 return nil, err 302 } 303 } 304 return nil, nil 305 } 306 return frame, nil 307 } 308 309 func (handler *hybiFrameHandler) WriteClose(status int) (err error) { 310 handler.conn.wio.Lock() 311 defer handler.conn.wio.Unlock() 312 w, err := handler.conn.frameWriterFactory.NewFrameWriter(CloseFrame) 313 if err != nil { 314 return err 315 } 316 msg := make([]byte, 2) 317 binary.BigEndian.PutUint16(msg, uint16(status)) 318 _, err = w.Write(msg) 319 w.Close() 320 return err 321 } 322 323 func (handler *hybiFrameHandler) WritePong(msg []byte) (n int, err error) { 324 handler.conn.wio.Lock() 325 defer handler.conn.wio.Unlock() 326 w, err := handler.conn.frameWriterFactory.NewFrameWriter(PongFrame) 327 if err != nil { 328 return 0, err 329 } 330 n, err = w.Write(msg) 331 w.Close() 332 return n, err 333 } 334 335 // newHybiConn creates a new WebSocket connection speaking hybi draft protocol. 336 func newHybiConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn { 337 if buf == nil { 338 br := bufio.NewReader(rwc) 339 bw := bufio.NewWriter(rwc) 340 buf = bufio.NewReadWriter(br, bw) 341 } 342 ws := &Conn{config: config, request: request, buf: buf, rwc: rwc, 343 frameReaderFactory: hybiFrameReaderFactory{buf.Reader}, 344 frameWriterFactory: hybiFrameWriterFactory{ 345 buf.Writer, request == nil}, 346 PayloadType: TextFrame, 347 defaultCloseStatus: closeStatusNormal} 348 ws.frameHandler = &hybiFrameHandler{conn: ws} 349 return ws 350 } 351 352 // generateMaskingKey generates a masking key for a frame. 353 func generateMaskingKey() (maskingKey []byte, err error) { 354 maskingKey = make([]byte, 4) 355 if _, err = io.ReadFull(rand.Reader, maskingKey); err != nil { 356 return 357 } 358 return 359 } 360 361 // generateNonce generates a nonce consisting of a randomly selected 16-byte 362 // value that has been base64-encoded. 363 func generateNonce() (nonce []byte) { 364 key := make([]byte, 16) 365 if _, err := io.ReadFull(rand.Reader, key); err != nil { 366 panic(err) 367 } 368 nonce = make([]byte, 24) 369 base64.StdEncoding.Encode(nonce, key) 370 return 371 } 372 373 // removeZone removes IPv6 zone identifer from host. 374 // E.g., "[fe80::1%en0]:8080" to "[fe80::1]:8080" 375 func removeZone(host string) string { 376 if !strings.HasPrefix(host, "[") { 377 return host 378 } 379 i := strings.LastIndex(host, "]") 380 if i < 0 { 381 return host 382 } 383 j := strings.LastIndex(host[:i], "%") 384 if j < 0 { 385 return host 386 } 387 return host[:j] + host[i:] 388 } 389 390 // getNonceAccept computes the base64-encoded SHA-1 of the concatenation of 391 // the nonce ("Sec-WebSocket-Key" value) with the websocket GUID string. 392 func getNonceAccept(nonce []byte) (expected []byte, err error) { 393 h := sha1.New() 394 if _, err = h.Write(nonce); err != nil { 395 return 396 } 397 if _, err = h.Write([]byte(websocketGUID)); err != nil { 398 return 399 } 400 expected = make([]byte, 28) 401 base64.StdEncoding.Encode(expected, h.Sum(nil)) 402 return 403 } 404 405 // Client handshake described in draft-ietf-hybi-thewebsocket-protocol-17 406 func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (err error) { 407 bw.WriteString("GET " + config.Location.RequestURI() + " HTTP/1.1\r\n") 408 409 // According to RFC 6874, an HTTP client, proxy, or other 410 // intermediary must remove any IPv6 zone identifier attached 411 // to an outgoing URI. 412 bw.WriteString("Host: " + removeZone(config.Location.Host) + "\r\n") 413 bw.WriteString("Upgrade: websocket\r\n") 414 bw.WriteString("Connection: Upgrade\r\n") 415 nonce := generateNonce() 416 if config.handshakeData != nil { 417 nonce = []byte(config.handshakeData["key"]) 418 } 419 bw.WriteString("Sec-WebSocket-Key: " + string(nonce) + "\r\n") 420 bw.WriteString("Origin: " + strings.ToLower(config.Origin.String()) + "\r\n") 421 422 if config.Version != ProtocolVersionHybi13 { 423 return ErrBadProtocolVersion 424 } 425 426 bw.WriteString("Sec-WebSocket-Version: " + fmt.Sprintf("%d", config.Version) + "\r\n") 427 if len(config.Protocol) > 0 { 428 bw.WriteString("Sec-WebSocket-Protocol: " + strings.Join(config.Protocol, ", ") + "\r\n") 429 } 430 // TODO(ukai): send Sec-WebSocket-Extensions. 431 err = config.Header.WriteSubset(bw, handshakeHeader) 432 if err != nil { 433 return err 434 } 435 436 bw.WriteString("\r\n") 437 if err = bw.Flush(); err != nil { 438 return err 439 } 440 441 resp, err := http.ReadResponse(br, &http.Request{Method: "GET"}) 442 if err != nil { 443 return err 444 } 445 if resp.StatusCode != 101 { 446 return ErrBadStatus 447 } 448 if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" || 449 strings.ToLower(resp.Header.Get("Connection")) != "upgrade" { 450 return ErrBadUpgrade 451 } 452 expectedAccept, err := getNonceAccept(nonce) 453 if err != nil { 454 return err 455 } 456 if resp.Header.Get("Sec-WebSocket-Accept") != string(expectedAccept) { 457 return ErrChallengeResponse 458 } 459 if resp.Header.Get("Sec-WebSocket-Extensions") != "" { 460 return ErrUnsupportedExtensions 461 } 462 offeredProtocol := resp.Header.Get("Sec-WebSocket-Protocol") 463 if offeredProtocol != "" { 464 protocolMatched := false 465 for i := 0; i < len(config.Protocol); i++ { 466 if config.Protocol[i] == offeredProtocol { 467 protocolMatched = true 468 break 469 } 470 } 471 if !protocolMatched { 472 return ErrBadWebSocketProtocol 473 } 474 config.Protocol = []string{offeredProtocol} 475 } 476 477 return nil 478 } 479 480 // newHybiClientConn creates a client WebSocket connection after handshake. 481 func newHybiClientConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser) *Conn { 482 return newHybiConn(config, buf, rwc, nil) 483 } 484 485 // A HybiServerHandshaker performs a server handshake using hybi draft protocol. 486 type hybiServerHandshaker struct { 487 *Config 488 accept []byte 489 } 490 491 func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error) { 492 c.Version = ProtocolVersionHybi13 493 if req.Method != "GET" { 494 return http.StatusMethodNotAllowed, ErrBadRequestMethod 495 } 496 // HTTP version can be safely ignored. 497 498 if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || 499 !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") { 500 return http.StatusBadRequest, ErrNotWebSocket 501 } 502 503 key := req.Header.Get("Sec-Websocket-Key") 504 if key == "" { 505 return http.StatusBadRequest, ErrChallengeResponse 506 } 507 version := req.Header.Get("Sec-Websocket-Version") 508 switch version { 509 case "13": 510 c.Version = ProtocolVersionHybi13 511 default: 512 return http.StatusBadRequest, ErrBadWebSocketVersion 513 } 514 var scheme string 515 if req.TLS != nil { 516 scheme = "wss" 517 } else { 518 scheme = "ws" 519 } 520 c.Location, err = url.ParseRequestURI(scheme + "://" + req.Host + req.URL.RequestURI()) 521 if err != nil { 522 return http.StatusBadRequest, err 523 } 524 protocol := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol")) 525 if protocol != "" { 526 protocols := strings.Split(protocol, ",") 527 for i := 0; i < len(protocols); i++ { 528 c.Protocol = append(c.Protocol, strings.TrimSpace(protocols[i])) 529 } 530 } 531 c.accept, err = getNonceAccept([]byte(key)) 532 if err != nil { 533 return http.StatusInternalServerError, err 534 } 535 return http.StatusSwitchingProtocols, nil 536 } 537 538 // Origin parses the Origin header in req. 539 // If the Origin header is not set, it returns nil and nil. 540 func Origin(config *Config, req *http.Request) (*url.URL, error) { 541 var origin string 542 switch config.Version { 543 case ProtocolVersionHybi13: 544 origin = req.Header.Get("Origin") 545 } 546 if origin == "" { 547 return nil, nil 548 } 549 return url.ParseRequestURI(origin) 550 } 551 552 func (c *hybiServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err error) { 553 if len(c.Protocol) > 0 { 554 if len(c.Protocol) != 1 { 555 // You need choose a Protocol in Handshake func in Server. 556 return ErrBadWebSocketProtocol 557 } 558 } 559 buf.WriteString("HTTP/1.1 101 Switching Protocols\r\n") 560 buf.WriteString("Upgrade: websocket\r\n") 561 buf.WriteString("Connection: Upgrade\r\n") 562 buf.WriteString("Sec-WebSocket-Accept: " + string(c.accept) + "\r\n") 563 if len(c.Protocol) > 0 { 564 buf.WriteString("Sec-WebSocket-Protocol: " + c.Protocol[0] + "\r\n") 565 } 566 // TODO(ukai): send Sec-WebSocket-Extensions. 567 if c.Header != nil { 568 err := c.Header.WriteSubset(buf, handshakeHeader) 569 if err != nil { 570 return err 571 } 572 } 573 buf.WriteString("\r\n") 574 return buf.Flush() 575 } 576 577 func (c *hybiServerHandshaker) NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn { 578 return newHybiServerConn(c.Config, buf, rwc, request) 579 } 580 581 // newHybiServerConn returns a new WebSocket connection speaking hybi draft protocol. 582 func newHybiServerConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn { 583 return newHybiConn(config, buf, rwc, request) 584 }