github.com/Andyfoo/golang/x/net@v0.0.0-20190901054642-57c1bf301704/websocket/hybi.go (about) 1 // Copyright 2011 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package websocket 6 7 // This file implements a protocol of hybi draft. 8 // http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17 9 10 import ( 11 "bufio" 12 "bytes" 13 "crypto/rand" 14 "crypto/sha1" 15 "encoding/base64" 16 "encoding/binary" 17 "fmt" 18 "io" 19 "io/ioutil" 20 "net/http" 21 "net/url" 22 "strings" 23 ) 24 25 const ( 26 websocketGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" 27 28 closeStatusNormal = 1000 29 closeStatusGoingAway = 1001 30 closeStatusProtocolError = 1002 31 closeStatusUnsupportedData = 1003 32 closeStatusFrameTooLarge = 1004 33 closeStatusNoStatusRcvd = 1005 34 closeStatusAbnormalClosure = 1006 35 closeStatusBadMessageData = 1007 36 closeStatusPolicyViolation = 1008 37 closeStatusTooBigData = 1009 38 closeStatusExtensionMismatch = 1010 39 40 maxControlFramePayloadLength = 125 41 ) 42 43 var ( 44 ErrBadMaskingKey = &ProtocolError{"bad masking key"} 45 ErrBadPongMessage = &ProtocolError{"bad pong message"} 46 ErrBadClosingStatus = &ProtocolError{"bad closing status"} 47 ErrUnsupportedExtensions = &ProtocolError{"unsupported extensions"} 48 ErrNotImplemented = &ProtocolError{"not implemented"} 49 50 handshakeHeader = map[string]bool{ 51 "Host": true, 52 "Upgrade": true, 53 "Connection": true, 54 "Sec-Websocket-Key": true, 55 "Sec-Websocket-Origin": true, 56 "Sec-Websocket-Version": true, 57 "Sec-Websocket-Protocol": true, 58 "Sec-Websocket-Accept": true, 59 } 60 ) 61 62 // A hybiFrameHeader is a frame header as defined in hybi draft. 63 type hybiFrameHeader struct { 64 Fin bool 65 Rsv [3]bool 66 OpCode byte 67 Length int64 68 MaskingKey []byte 69 70 data *bytes.Buffer 71 } 72 73 // A hybiFrameReader is a reader for hybi frame. 74 type hybiFrameReader struct { 75 reader io.Reader 76 77 header hybiFrameHeader 78 pos int64 79 length int 80 } 81 82 func (frame *hybiFrameReader) Read(msg []byte) (n int, err error) { 83 n, err = frame.reader.Read(msg) 84 if frame.header.MaskingKey != nil { 85 for i := 0; i < n; i++ { 86 msg[i] = msg[i] ^ frame.header.MaskingKey[frame.pos%4] 87 frame.pos++ 88 } 89 } 90 return n, err 91 } 92 93 func (frame *hybiFrameReader) PayloadType() byte { return frame.header.OpCode } 94 95 func (frame *hybiFrameReader) HeaderReader() io.Reader { 96 if frame.header.data == nil { 97 return nil 98 } 99 if frame.header.data.Len() == 0 { 100 return nil 101 } 102 return frame.header.data 103 } 104 105 func (frame *hybiFrameReader) TrailerReader() io.Reader { return nil } 106 107 func (frame *hybiFrameReader) Len() (n int) { return frame.length } 108 109 // A hybiFrameReaderFactory creates new frame reader based on its frame type. 110 type hybiFrameReaderFactory struct { 111 *bufio.Reader 112 } 113 114 // NewFrameReader reads a frame header from the connection, and creates new reader for the frame. 115 // See Section 5.2 Base Framing protocol for detail. 116 // http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17#section-5.2 117 func (buf hybiFrameReaderFactory) NewFrameReader() (frame frameReader, err error) { 118 hybiFrame := new(hybiFrameReader) 119 frame = hybiFrame 120 var header []byte 121 var b byte 122 // First byte. FIN/RSV1/RSV2/RSV3/OpCode(4bits) 123 b, err = buf.ReadByte() 124 if err != nil { 125 return 126 } 127 header = append(header, b) 128 hybiFrame.header.Fin = ((header[0] >> 7) & 1) != 0 129 for i := 0; i < 3; i++ { 130 j := uint(6 - i) 131 hybiFrame.header.Rsv[i] = ((header[0] >> j) & 1) != 0 132 } 133 hybiFrame.header.OpCode = header[0] & 0x0f 134 135 // Second byte. Mask/Payload len(7bits) 136 b, err = buf.ReadByte() 137 if err != nil { 138 return 139 } 140 header = append(header, b) 141 mask := (b & 0x80) != 0 142 b &= 0x7f 143 lengthFields := 0 144 switch { 145 case b <= 125: // Payload length 7bits. 146 hybiFrame.header.Length = int64(b) 147 case b == 126: // Payload length 7+16bits 148 lengthFields = 2 149 case b == 127: // Payload length 7+64bits 150 lengthFields = 8 151 } 152 for i := 0; i < lengthFields; i++ { 153 b, err = buf.ReadByte() 154 if err != nil { 155 return 156 } 157 if lengthFields == 8 && i == 0 { // MSB must be zero when 7+64 bits 158 b &= 0x7f 159 } 160 header = append(header, b) 161 hybiFrame.header.Length = hybiFrame.header.Length*256 + int64(b) 162 } 163 if mask { 164 // Masking key. 4 bytes. 165 for i := 0; i < 4; i++ { 166 b, err = buf.ReadByte() 167 if err != nil { 168 return 169 } 170 header = append(header, b) 171 hybiFrame.header.MaskingKey = append(hybiFrame.header.MaskingKey, b) 172 } 173 } 174 hybiFrame.reader = io.LimitReader(buf.Reader, hybiFrame.header.Length) 175 hybiFrame.header.data = bytes.NewBuffer(header) 176 hybiFrame.length = len(header) + int(hybiFrame.header.Length) 177 return 178 } 179 180 // A HybiFrameWriter is a writer for hybi frame. 181 type hybiFrameWriter struct { 182 writer *bufio.Writer 183 184 header *hybiFrameHeader 185 } 186 187 func (frame *hybiFrameWriter) Write(msg []byte) (n int, err error) { 188 var header []byte 189 var b byte 190 if frame.header.Fin { 191 b |= 0x80 192 } 193 for i := 0; i < 3; i++ { 194 if frame.header.Rsv[i] { 195 j := uint(6 - i) 196 b |= 1 << j 197 } 198 } 199 b |= frame.header.OpCode 200 header = append(header, b) 201 if frame.header.MaskingKey != nil { 202 b = 0x80 203 } else { 204 b = 0 205 } 206 lengthFields := 0 207 length := len(msg) 208 switch { 209 case length <= 125: 210 b |= byte(length) 211 case length < 65536: 212 b |= 126 213 lengthFields = 2 214 default: 215 b |= 127 216 lengthFields = 8 217 } 218 header = append(header, b) 219 for i := 0; i < lengthFields; i++ { 220 j := uint((lengthFields - i - 1) * 8) 221 b = byte((length >> j) & 0xff) 222 header = append(header, b) 223 } 224 if frame.header.MaskingKey != nil { 225 if len(frame.header.MaskingKey) != 4 { 226 return 0, ErrBadMaskingKey 227 } 228 header = append(header, frame.header.MaskingKey...) 229 frame.writer.Write(header) 230 data := make([]byte, length) 231 for i := range data { 232 data[i] = msg[i] ^ frame.header.MaskingKey[i%4] 233 } 234 frame.writer.Write(data) 235 err = frame.writer.Flush() 236 return length, err 237 } 238 frame.writer.Write(header) 239 frame.writer.Write(msg) 240 err = frame.writer.Flush() 241 return length, err 242 } 243 244 func (frame *hybiFrameWriter) Close() error { return nil } 245 246 type hybiFrameWriterFactory struct { 247 *bufio.Writer 248 needMaskingKey bool 249 } 250 251 func (buf hybiFrameWriterFactory) NewFrameWriter(payloadType byte) (frame frameWriter, err error) { 252 frameHeader := &hybiFrameHeader{Fin: true, OpCode: payloadType} 253 if buf.needMaskingKey { 254 frameHeader.MaskingKey, err = generateMaskingKey() 255 if err != nil { 256 return nil, err 257 } 258 } 259 return &hybiFrameWriter{writer: buf.Writer, header: frameHeader}, nil 260 } 261 262 type hybiFrameHandler struct { 263 conn *Conn 264 payloadType byte 265 } 266 267 func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (frameReader, error) { 268 if handler.conn.IsServerConn() { 269 // The client MUST mask all frames sent to the server. 270 if frame.(*hybiFrameReader).header.MaskingKey == nil { 271 handler.WriteClose(closeStatusProtocolError) 272 return nil, io.EOF 273 } 274 } else { 275 // The server MUST NOT mask all frames. 276 if frame.(*hybiFrameReader).header.MaskingKey != nil { 277 handler.WriteClose(closeStatusProtocolError) 278 return nil, io.EOF 279 } 280 } 281 if header := frame.HeaderReader(); header != nil { 282 io.Copy(ioutil.Discard, header) 283 } 284 switch frame.PayloadType() { 285 case ContinuationFrame: 286 frame.(*hybiFrameReader).header.OpCode = handler.payloadType 287 case TextFrame, BinaryFrame: 288 handler.payloadType = frame.PayloadType() 289 case CloseFrame: 290 return nil, io.EOF 291 case PingFrame, PongFrame: 292 b := make([]byte, maxControlFramePayloadLength) 293 n, err := io.ReadFull(frame, b) 294 if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { 295 return nil, err 296 } 297 io.Copy(ioutil.Discard, frame) 298 if frame.PayloadType() == PingFrame { 299 if _, err := handler.WritePong(b[:n]); err != nil { 300 return nil, err 301 } 302 } 303 return nil, nil 304 } 305 return frame, nil 306 } 307 308 func (handler *hybiFrameHandler) WriteClose(status int) (err error) { 309 handler.conn.wio.Lock() 310 defer handler.conn.wio.Unlock() 311 w, err := handler.conn.frameWriterFactory.NewFrameWriter(CloseFrame) 312 if err != nil { 313 return err 314 } 315 msg := make([]byte, 2) 316 binary.BigEndian.PutUint16(msg, uint16(status)) 317 _, err = w.Write(msg) 318 w.Close() 319 return err 320 } 321 322 func (handler *hybiFrameHandler) WritePong(msg []byte) (n int, err error) { 323 handler.conn.wio.Lock() 324 defer handler.conn.wio.Unlock() 325 w, err := handler.conn.frameWriterFactory.NewFrameWriter(PongFrame) 326 if err != nil { 327 return 0, err 328 } 329 n, err = w.Write(msg) 330 w.Close() 331 return n, err 332 } 333 334 // newHybiConn creates a new WebSocket connection speaking hybi draft protocol. 335 func newHybiConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn { 336 if buf == nil { 337 br := bufio.NewReader(rwc) 338 bw := bufio.NewWriter(rwc) 339 buf = bufio.NewReadWriter(br, bw) 340 } 341 ws := &Conn{config: config, request: request, buf: buf, rwc: rwc, 342 frameReaderFactory: hybiFrameReaderFactory{buf.Reader}, 343 frameWriterFactory: hybiFrameWriterFactory{ 344 buf.Writer, request == nil}, 345 PayloadType: TextFrame, 346 defaultCloseStatus: closeStatusNormal} 347 ws.frameHandler = &hybiFrameHandler{conn: ws} 348 return ws 349 } 350 351 // generateMaskingKey generates a masking key for a frame. 352 func generateMaskingKey() (maskingKey []byte, err error) { 353 maskingKey = make([]byte, 4) 354 if _, err = io.ReadFull(rand.Reader, maskingKey); err != nil { 355 return 356 } 357 return 358 } 359 360 // generateNonce generates a nonce consisting of a randomly selected 16-byte 361 // value that has been base64-encoded. 362 func generateNonce() (nonce []byte) { 363 key := make([]byte, 16) 364 if _, err := io.ReadFull(rand.Reader, key); err != nil { 365 panic(err) 366 } 367 nonce = make([]byte, 24) 368 base64.StdEncoding.Encode(nonce, key) 369 return 370 } 371 372 // removeZone removes IPv6 zone identifer from host. 373 // E.g., "[fe80::1%en0]:8080" to "[fe80::1]:8080" 374 func removeZone(host string) string { 375 if !strings.HasPrefix(host, "[") { 376 return host 377 } 378 i := strings.LastIndex(host, "]") 379 if i < 0 { 380 return host 381 } 382 j := strings.LastIndex(host[:i], "%") 383 if j < 0 { 384 return host 385 } 386 return host[:j] + host[i:] 387 } 388 389 // getNonceAccept computes the base64-encoded SHA-1 of the concatenation of 390 // the nonce ("Sec-WebSocket-Key" value) with the websocket GUID string. 391 func getNonceAccept(nonce []byte) (expected []byte, err error) { 392 h := sha1.New() 393 if _, err = h.Write(nonce); err != nil { 394 return 395 } 396 if _, err = h.Write([]byte(websocketGUID)); err != nil { 397 return 398 } 399 expected = make([]byte, 28) 400 base64.StdEncoding.Encode(expected, h.Sum(nil)) 401 return 402 } 403 404 // Client handshake described in draft-ietf-hybi-thewebsocket-protocol-17 405 func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (err error) { 406 bw.WriteString("GET " + config.Location.RequestURI() + " HTTP/1.1\r\n") 407 408 // According to RFC 6874, an HTTP client, proxy, or other 409 // intermediary must remove any IPv6 zone identifier attached 410 // to an outgoing URI. 411 bw.WriteString("Host: " + removeZone(config.Location.Host) + "\r\n") 412 bw.WriteString("Upgrade: websocket\r\n") 413 bw.WriteString("Connection: Upgrade\r\n") 414 nonce := generateNonce() 415 if config.handshakeData != nil { 416 nonce = []byte(config.handshakeData["key"]) 417 } 418 bw.WriteString("Sec-WebSocket-Key: " + string(nonce) + "\r\n") 419 bw.WriteString("Origin: " + strings.ToLower(config.Origin.String()) + "\r\n") 420 421 if config.Version != ProtocolVersionHybi13 { 422 return ErrBadProtocolVersion 423 } 424 425 bw.WriteString("Sec-WebSocket-Version: " + fmt.Sprintf("%d", config.Version) + "\r\n") 426 if len(config.Protocol) > 0 { 427 bw.WriteString("Sec-WebSocket-Protocol: " + strings.Join(config.Protocol, ", ") + "\r\n") 428 } 429 // TODO(ukai): send Sec-WebSocket-Extensions. 430 err = config.Header.WriteSubset(bw, handshakeHeader) 431 if err != nil { 432 return err 433 } 434 435 bw.WriteString("\r\n") 436 if err = bw.Flush(); err != nil { 437 return err 438 } 439 440 resp, err := http.ReadResponse(br, &http.Request{Method: "GET"}) 441 if err != nil { 442 return err 443 } 444 if resp.StatusCode != 101 { 445 return ErrBadStatus 446 } 447 if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" || 448 strings.ToLower(resp.Header.Get("Connection")) != "upgrade" { 449 return ErrBadUpgrade 450 } 451 expectedAccept, err := getNonceAccept(nonce) 452 if err != nil { 453 return err 454 } 455 if resp.Header.Get("Sec-WebSocket-Accept") != string(expectedAccept) { 456 return ErrChallengeResponse 457 } 458 if resp.Header.Get("Sec-WebSocket-Extensions") != "" { 459 return ErrUnsupportedExtensions 460 } 461 offeredProtocol := resp.Header.Get("Sec-WebSocket-Protocol") 462 if offeredProtocol != "" { 463 protocolMatched := false 464 for i := 0; i < len(config.Protocol); i++ { 465 if config.Protocol[i] == offeredProtocol { 466 protocolMatched = true 467 break 468 } 469 } 470 if !protocolMatched { 471 return ErrBadWebSocketProtocol 472 } 473 config.Protocol = []string{offeredProtocol} 474 } 475 476 return nil 477 } 478 479 // newHybiClientConn creates a client WebSocket connection after handshake. 480 func newHybiClientConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser) *Conn { 481 return newHybiConn(config, buf, rwc, nil) 482 } 483 484 // A HybiServerHandshaker performs a server handshake using hybi draft protocol. 485 type hybiServerHandshaker struct { 486 *Config 487 accept []byte 488 } 489 490 func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error) { 491 c.Version = ProtocolVersionHybi13 492 if req.Method != "GET" { 493 return http.StatusMethodNotAllowed, ErrBadRequestMethod 494 } 495 // HTTP version can be safely ignored. 496 497 if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" || 498 !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") { 499 return http.StatusBadRequest, ErrNotWebSocket 500 } 501 502 key := req.Header.Get("Sec-Websocket-Key") 503 if key == "" { 504 return http.StatusBadRequest, ErrChallengeResponse 505 } 506 version := req.Header.Get("Sec-Websocket-Version") 507 switch version { 508 case "13": 509 c.Version = ProtocolVersionHybi13 510 default: 511 return http.StatusBadRequest, ErrBadWebSocketVersion 512 } 513 var scheme string 514 if req.TLS != nil { 515 scheme = "wss" 516 } else { 517 scheme = "ws" 518 } 519 c.Location, err = url.ParseRequestURI(scheme + "://" + req.Host + req.URL.RequestURI()) 520 if err != nil { 521 return http.StatusBadRequest, err 522 } 523 protocol := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol")) 524 if protocol != "" { 525 protocols := strings.Split(protocol, ",") 526 for i := 0; i < len(protocols); i++ { 527 c.Protocol = append(c.Protocol, strings.TrimSpace(protocols[i])) 528 } 529 } 530 c.accept, err = getNonceAccept([]byte(key)) 531 if err != nil { 532 return http.StatusInternalServerError, err 533 } 534 return http.StatusSwitchingProtocols, nil 535 } 536 537 // Origin parses the Origin header in req. 538 // If the Origin header is not set, it returns nil and nil. 539 func Origin(config *Config, req *http.Request) (*url.URL, error) { 540 var origin string 541 switch config.Version { 542 case ProtocolVersionHybi13: 543 origin = req.Header.Get("Origin") 544 } 545 if origin == "" { 546 return nil, nil 547 } 548 return url.ParseRequestURI(origin) 549 } 550 551 func (c *hybiServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err error) { 552 if len(c.Protocol) > 0 { 553 if len(c.Protocol) != 1 { 554 // You need choose a Protocol in Handshake func in Server. 555 return ErrBadWebSocketProtocol 556 } 557 } 558 buf.WriteString("HTTP/1.1 101 Switching Protocols\r\n") 559 buf.WriteString("Upgrade: websocket\r\n") 560 buf.WriteString("Connection: Upgrade\r\n") 561 buf.WriteString("Sec-WebSocket-Accept: " + string(c.accept) + "\r\n") 562 if len(c.Protocol) > 0 { 563 buf.WriteString("Sec-WebSocket-Protocol: " + c.Protocol[0] + "\r\n") 564 } 565 // TODO(ukai): send Sec-WebSocket-Extensions. 566 if c.Header != nil { 567 err := c.Header.WriteSubset(buf, handshakeHeader) 568 if err != nil { 569 return err 570 } 571 } 572 buf.WriteString("\r\n") 573 return buf.Flush() 574 } 575 576 func (c *hybiServerHandshaker) NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn { 577 return newHybiServerConn(c.Config, buf, rwc, request) 578 } 579 580 // newHybiServerConn returns a new WebSocket connection speaking hybi draft protocol. 581 func newHybiServerConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn { 582 return newHybiConn(config, buf, rwc, request) 583 }