github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/client/alpn_websocket.go (about) 1 /* 2 Copyright 2024 Gravitational, Inc. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package client 18 19 import ( 20 "crypto/rand" 21 "crypto/sha1" 22 "encoding/base64" 23 "io" 24 "net" 25 "net/http" 26 "sync" 27 28 "github.com/gobwas/ws" 29 "github.com/gravitational/trace" 30 31 "github.com/gravitational/teleport/api/constants" 32 ) 33 34 func applyWebSocketUpgradeHeaders(req *http.Request, alpnUpgradeType, challengeKey string) { 35 // Set standard WebSocket upgrade type. 36 req.Header.Add(constants.WebAPIConnUpgradeHeader, constants.WebAPIConnUpgradeTypeWebSocket) 37 38 // Set "Connection" header to meet RFC spec: 39 // https://datatracker.ietf.org/doc/html/rfc2616#section-14.42 40 // Quote: "the upgrade keyword MUST be supplied within a Connection header 41 // field (section 14.10) whenever Upgrade is present in an HTTP/1.1 42 // message." 43 req.Header.Set(constants.WebAPIConnUpgradeConnectionHeader, constants.WebAPIConnUpgradeConnectionType) 44 45 // Set alpnUpgradeType as sub protocol. 46 req.Header.Set(websocketHeaderKeyProtocol, alpnUpgradeType) 47 req.Header.Set(websocketHeaderKeyVersion, "13") 48 req.Header.Set(websocketHeaderKeyChallengeKey, challengeKey) 49 } 50 51 func computeWebSocketAcceptKey(challengeKey string) string { 52 h := sha1.New() 53 h.Write([]byte(challengeKey)) 54 h.Write([]byte(websocketAcceptKeyMagicString)) 55 return base64.StdEncoding.EncodeToString(h.Sum(nil)) 56 } 57 58 func generateWebSocketChallengeKey() (string, error) { 59 // Quote from https://www.rfc-editor.org/rfc/rfc6455: 60 // 61 // A |Sec-WebSocket-Key| header field with a base64-encoded (see Section 4 62 // of [RFC4648]) value that, when decoded, is 16 bytes in length. 63 p := make([]byte, 16) 64 if _, err := io.ReadFull(rand.Reader, p); err != nil { 65 return "", trace.Wrap(err) 66 } 67 return base64.StdEncoding.EncodeToString(p), nil 68 } 69 70 func checkWebSocketUpgradeResponse(resp *http.Response, alpnUpgradeType, challengeKey string) error { 71 if alpnUpgradeType != resp.Header.Get(websocketHeaderKeyProtocol) { 72 return trace.BadParameter("WebSocket handshake failed: Sec-WebSocket-Protocol does not match") 73 } 74 if computeWebSocketAcceptKey(challengeKey) != resp.Header.Get(websocketHeaderKeyAccept) { 75 return trace.BadParameter("WebSocket handshake failed: invalid Sec-WebSocket-Accept") 76 } 77 return nil 78 } 79 80 type websocketALPNClientConn struct { 81 net.Conn 82 readBuffer []byte 83 readMutex sync.Mutex 84 writeMutex sync.Mutex 85 } 86 87 func newWebSocketALPNClientConn(conn net.Conn) *websocketALPNClientConn { 88 return &websocketALPNClientConn{ 89 Conn: conn, 90 } 91 } 92 93 func (c *websocketALPNClientConn) Read(b []byte) (int, error) { 94 c.readMutex.Lock() 95 defer c.readMutex.Unlock() 96 97 n, err := c.readLocked(b) 98 return n, trace.Wrap(err) 99 } 100 101 func (c *websocketALPNClientConn) readLocked(b []byte) (int, error) { 102 if len(c.readBuffer) > 0 { 103 n := copy(b, c.readBuffer) 104 if n < len(c.readBuffer) { 105 c.readBuffer = c.readBuffer[n:] 106 } else { 107 c.readBuffer = nil 108 } 109 return n, nil 110 } 111 112 for { 113 frame, err := ws.ReadFrame(c.Conn) 114 if err != nil { 115 return 0, trace.Wrap(err) 116 } 117 118 switch frame.Header.OpCode { 119 case ws.OpClose: 120 return 0, io.EOF 121 case ws.OpPing: 122 pong := ws.NewPongFrame(frame.Payload) 123 if err := c.writeFrame(pong); err != nil { 124 return 0, trace.Wrap(err) 125 } 126 case ws.OpBinary: 127 c.readBuffer = frame.Payload 128 return c.readLocked(b) 129 } 130 } 131 } 132 133 func (c *websocketALPNClientConn) Write(b []byte) (int, error) { 134 frame := ws.NewFrame(ws.OpBinary, true, b) 135 return len(b), trace.Wrap(c.writeFrame(frame)) 136 } 137 138 func (c *websocketALPNClientConn) writeFrame(frame ws.Frame) error { 139 c.writeMutex.Lock() 140 defer c.writeMutex.Unlock() 141 // By RFC standard, all client frames must be masked: 142 // https://datatracker.ietf.org/doc/html/rfc6455#section-5.1 143 frame.Header.Masked = true 144 return trace.Wrap(ws.WriteFrame(c.Conn, frame)) 145 } 146 147 const ( 148 websocketHeaderKeyProtocol = "Sec-WebSocket-Protocol" 149 websocketHeaderKeyVersion = "Sec-WebSocket-Version" 150 websocketHeaderKeyChallengeKey = "Sec-WebSocket-Key" 151 websocketHeaderKeyAccept = "Sec-WebSocket-Accept" 152 153 // websocketAcceptKeyMagicString is the magic string used for computing 154 // the accept key during WebSocket handshake. 155 // 156 // RFC reference: 157 // https://www.rfc-editor.org/rfc/rfc6455 158 // 159 // Server side uses gorilla: 160 // https://github.com/gorilla/websocket/blob/dcea2f088ce10b1b0722c4eb995a4e145b5e9047/util.go#L17-L24 161 websocketAcceptKeyMagicString = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" 162 )