github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/client/alpn_websocket.go (about)

     1  /*
     2  Copyright 2024 Gravitational, Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package client
    18  
    19  import (
    20  	"crypto/rand"
    21  	"crypto/sha1"
    22  	"encoding/base64"
    23  	"io"
    24  	"net"
    25  	"net/http"
    26  	"sync"
    27  
    28  	"github.com/gobwas/ws"
    29  	"github.com/gravitational/trace"
    30  
    31  	"github.com/gravitational/teleport/api/constants"
    32  )
    33  
    34  func applyWebSocketUpgradeHeaders(req *http.Request, alpnUpgradeType, challengeKey string) {
    35  	// Set standard WebSocket upgrade type.
    36  	req.Header.Add(constants.WebAPIConnUpgradeHeader, constants.WebAPIConnUpgradeTypeWebSocket)
    37  
    38  	// Set "Connection" header to meet RFC spec:
    39  	// https://datatracker.ietf.org/doc/html/rfc2616#section-14.42
    40  	// Quote: "the upgrade keyword MUST be supplied within a Connection header
    41  	// field (section 14.10) whenever Upgrade is present in an HTTP/1.1
    42  	// message."
    43  	req.Header.Set(constants.WebAPIConnUpgradeConnectionHeader, constants.WebAPIConnUpgradeConnectionType)
    44  
    45  	// Set alpnUpgradeType as sub protocol.
    46  	req.Header.Set(websocketHeaderKeyProtocol, alpnUpgradeType)
    47  	req.Header.Set(websocketHeaderKeyVersion, "13")
    48  	req.Header.Set(websocketHeaderKeyChallengeKey, challengeKey)
    49  }
    50  
    51  func computeWebSocketAcceptKey(challengeKey string) string {
    52  	h := sha1.New()
    53  	h.Write([]byte(challengeKey))
    54  	h.Write([]byte(websocketAcceptKeyMagicString))
    55  	return base64.StdEncoding.EncodeToString(h.Sum(nil))
    56  }
    57  
    58  func generateWebSocketChallengeKey() (string, error) {
    59  	// Quote from https://www.rfc-editor.org/rfc/rfc6455:
    60  	//
    61  	// A |Sec-WebSocket-Key| header field with a base64-encoded (see Section 4
    62  	// of [RFC4648]) value that, when decoded, is 16 bytes in length.
    63  	p := make([]byte, 16)
    64  	if _, err := io.ReadFull(rand.Reader, p); err != nil {
    65  		return "", trace.Wrap(err)
    66  	}
    67  	return base64.StdEncoding.EncodeToString(p), nil
    68  }
    69  
    70  func checkWebSocketUpgradeResponse(resp *http.Response, alpnUpgradeType, challengeKey string) error {
    71  	if alpnUpgradeType != resp.Header.Get(websocketHeaderKeyProtocol) {
    72  		return trace.BadParameter("WebSocket handshake failed: Sec-WebSocket-Protocol does not match")
    73  	}
    74  	if computeWebSocketAcceptKey(challengeKey) != resp.Header.Get(websocketHeaderKeyAccept) {
    75  		return trace.BadParameter("WebSocket handshake failed: invalid Sec-WebSocket-Accept")
    76  	}
    77  	return nil
    78  }
    79  
    80  type websocketALPNClientConn struct {
    81  	net.Conn
    82  	readBuffer []byte
    83  	readMutex  sync.Mutex
    84  	writeMutex sync.Mutex
    85  }
    86  
    87  func newWebSocketALPNClientConn(conn net.Conn) *websocketALPNClientConn {
    88  	return &websocketALPNClientConn{
    89  		Conn: conn,
    90  	}
    91  }
    92  
    93  func (c *websocketALPNClientConn) Read(b []byte) (int, error) {
    94  	c.readMutex.Lock()
    95  	defer c.readMutex.Unlock()
    96  
    97  	n, err := c.readLocked(b)
    98  	return n, trace.Wrap(err)
    99  }
   100  
   101  func (c *websocketALPNClientConn) readLocked(b []byte) (int, error) {
   102  	if len(c.readBuffer) > 0 {
   103  		n := copy(b, c.readBuffer)
   104  		if n < len(c.readBuffer) {
   105  			c.readBuffer = c.readBuffer[n:]
   106  		} else {
   107  			c.readBuffer = nil
   108  		}
   109  		return n, nil
   110  	}
   111  
   112  	for {
   113  		frame, err := ws.ReadFrame(c.Conn)
   114  		if err != nil {
   115  			return 0, trace.Wrap(err)
   116  		}
   117  
   118  		switch frame.Header.OpCode {
   119  		case ws.OpClose:
   120  			return 0, io.EOF
   121  		case ws.OpPing:
   122  			pong := ws.NewPongFrame(frame.Payload)
   123  			if err := c.writeFrame(pong); err != nil {
   124  				return 0, trace.Wrap(err)
   125  			}
   126  		case ws.OpBinary:
   127  			c.readBuffer = frame.Payload
   128  			return c.readLocked(b)
   129  		}
   130  	}
   131  }
   132  
   133  func (c *websocketALPNClientConn) Write(b []byte) (int, error) {
   134  	frame := ws.NewFrame(ws.OpBinary, true, b)
   135  	return len(b), trace.Wrap(c.writeFrame(frame))
   136  }
   137  
   138  func (c *websocketALPNClientConn) writeFrame(frame ws.Frame) error {
   139  	c.writeMutex.Lock()
   140  	defer c.writeMutex.Unlock()
   141  	// By RFC standard, all client frames must be masked:
   142  	// https://datatracker.ietf.org/doc/html/rfc6455#section-5.1
   143  	frame.Header.Masked = true
   144  	return trace.Wrap(ws.WriteFrame(c.Conn, frame))
   145  }
   146  
   147  const (
   148  	websocketHeaderKeyProtocol     = "Sec-WebSocket-Protocol"
   149  	websocketHeaderKeyVersion      = "Sec-WebSocket-Version"
   150  	websocketHeaderKeyChallengeKey = "Sec-WebSocket-Key"
   151  	websocketHeaderKeyAccept       = "Sec-WebSocket-Accept"
   152  
   153  	// websocketAcceptKeyMagicString is the magic string used for computing
   154  	// the accept key during WebSocket handshake.
   155  	//
   156  	// RFC reference:
   157  	// https://www.rfc-editor.org/rfc/rfc6455
   158  	//
   159  	// Server side uses gorilla:
   160  	// https://github.com/gorilla/websocket/blob/dcea2f088ce10b1b0722c4eb995a4e145b5e9047/util.go#L17-L24
   161  	websocketAcceptKeyMagicString = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
   162  )