k8s.io/apimachinery@v0.29.2/pkg/util/httpstream/wsstream/conn.go (about)

     1  /*
     2  Copyright 2015 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package wsstream
    18  
    19  import (
    20  	"encoding/base64"
    21  	"fmt"
    22  	"io"
    23  	"net/http"
    24  	"strings"
    25  	"time"
    26  
    27  	"golang.org/x/net/websocket"
    28  
    29  	"k8s.io/apimachinery/pkg/util/httpstream"
    30  	"k8s.io/apimachinery/pkg/util/remotecommand"
    31  	"k8s.io/apimachinery/pkg/util/runtime"
    32  	"k8s.io/klog/v2"
    33  )
    34  
// WebSocketProtocolHeader is the name of the HTTP request header from which the
// comma-separated list of requested websocket subprotocols is read.
const WebSocketProtocolHeader = "Sec-Websocket-Protocol"
    36  
// ChannelWebSocketProtocol is the Websocket subprotocol "channel.k8s.io". It prepends each binary
// message with a byte indicating the channel number (zero indexed) the message was sent on.
// Messages in both directions should prefix their messages with this channel byte. When used for
// remote execution, the channel numbers are by convention defined to match the POSIX
// file-descriptors assigned to STDIN, STDOUT, and STDERR (0, 1, and 2). No other conversion is
// performed on the raw subprotocol - writes are sent as they are received by the server.
//
// Example client session:
//
//	CONNECT http://server.com with subprotocol "channel.k8s.io"
//	WRITE []byte{0, 102, 111, 111, 10} # send "foo\n" on channel 0 (STDIN)
//	READ  []byte{1, 10}                # receive "\n" on channel 1 (STDOUT)
//	CLOSE
const ChannelWebSocketProtocol = "channel.k8s.io"
    51  
// Base64ChannelWebSocketProtocol is the Websocket subprotocol "base64.channel.k8s.io". Each
// message is a character indicating the channel number (zero indexed) the message was sent on,
// followed by the base64 encoding of the payload. Messages in both directions should prefix
// their messages with this channel char. When used for remote execution, the channel numbers are
// by convention defined to match the POSIX file-descriptors assigned to STDIN, STDOUT, and STDERR
// ('0', '1', and '2'). The data received on the server is base64 decoded (and must be valid) and
// data written by the server to the client is base64 encoded.
//
// Example client session:
//
//	CONNECT http://server.com with subprotocol "base64.channel.k8s.io"
//	WRITE []byte{48, 90, 109, 57, 118, 67, 103, 61, 61} # send "foo\n" (base64: "Zm9vCg==") on channel '0' (STDIN)
//	READ  []byte{49, 67, 103, 61, 61} # receive "\n" (base64: "Cg==") on channel '1' (STDOUT)
//	CLOSE
const Base64ChannelWebSocketProtocol = "base64.channel.k8s.io"
    66  
// codecType identifies how channel frames are encoded on the websocket.
type codecType int

const (
	// rawCodec frames are a single channel-number byte followed by the unmodified payload.
	rawCodec codecType = iota
	// base64Codec frames are a channel-number character ('0' + channel number) followed by
	// the standard base64 encoding of the payload.
	base64Codec
)
    73  
// ChannelType describes in which directions a channel of a Conn may be used.
type ChannelType int

const (
	// IgnoreChannel channels drop data in both directions: writes report success without
	// sending anything, reads return io.EOF, and frames received for the channel are discarded.
	IgnoreChannel ChannelType = iota
	// ReadChannel channels expose frames received from the websocket through Read; data
	// written to them is discarded.
	ReadChannel
	// WriteChannel channels send written data over the websocket; frames received for them
	// are discarded.
	WriteChannel
	// ReadWriteChannel channels can be both read from and written to.
	ReadWriteChannel
)
    82  
    83  // IsWebSocketRequest returns true if the incoming request contains connection upgrade headers
    84  // for WebSockets.
    85  func IsWebSocketRequest(req *http.Request) bool {
    86  	if !strings.EqualFold(req.Header.Get("Upgrade"), "websocket") {
    87  		return false
    88  	}
    89  	return httpstream.IsUpgradeRequest(req)
    90  }
    91  
    92  // IsWebSocketRequestWithStreamCloseProtocol returns true if the request contains headers
    93  // identifying that it is requesting a websocket upgrade with a remotecommand protocol
    94  // version that supports the "CLOSE" signal; false otherwise.
    95  func IsWebSocketRequestWithStreamCloseProtocol(req *http.Request) bool {
    96  	if !IsWebSocketRequest(req) {
    97  		return false
    98  	}
    99  	requestedProtocols := strings.TrimSpace(req.Header.Get(WebSocketProtocolHeader))
   100  	for _, requestedProtocol := range strings.Split(requestedProtocols, ",") {
   101  		if protocolSupportsStreamClose(strings.TrimSpace(requestedProtocol)) {
   102  			return true
   103  		}
   104  	}
   105  
   106  	return false
   107  }
   108  
   109  // IgnoreReceives reads from a WebSocket until it is closed, then returns. If timeout is set, the
   110  // read and write deadlines are pushed every time a new message is received.
   111  func IgnoreReceives(ws *websocket.Conn, timeout time.Duration) {
   112  	defer runtime.HandleCrash()
   113  	var data []byte
   114  	for {
   115  		resetTimeout(ws, timeout)
   116  		if err := websocket.Message.Receive(ws, &data); err != nil {
   117  			return
   118  		}
   119  	}
   120  }
   121  
   122  // handshake ensures the provided user protocol matches one of the allowed protocols. It returns
   123  // no error if no protocol is specified.
   124  func handshake(config *websocket.Config, req *http.Request, allowed []string) error {
   125  	protocols := config.Protocol
   126  	if len(protocols) == 0 {
   127  		protocols = []string{""}
   128  	}
   129  
   130  	for _, protocol := range protocols {
   131  		for _, allow := range allowed {
   132  			if allow == protocol {
   133  				config.Protocol = []string{protocol}
   134  				return nil
   135  			}
   136  		}
   137  	}
   138  
   139  	return fmt.Errorf("requested protocol(s) are not supported: %v; supports %v", config.Protocol, allowed)
   140  }
   141  
// ChannelProtocolConfig describes a websocket subprotocol with channels.
//
// Binary selects the frame encoding used once the subprotocol is negotiated: true means
// the raw codec (channel byte + unmodified payload), false the base64 codec.
// Channels lists the type of every channel; the slice index is the channel number.
type ChannelProtocolConfig struct {
	Binary   bool
	Channels []ChannelType
}
   147  
   148  // NewDefaultChannelProtocols returns a channel protocol map with the
   149  // subprotocols "", "channel.k8s.io", "base64.channel.k8s.io" and the given
   150  // channels.
   151  func NewDefaultChannelProtocols(channels []ChannelType) map[string]ChannelProtocolConfig {
   152  	return map[string]ChannelProtocolConfig{
   153  		"":                             {Binary: true, Channels: channels},
   154  		ChannelWebSocketProtocol:       {Binary: true, Channels: channels},
   155  		Base64ChannelWebSocketProtocol: {Binary: false, Channels: channels},
   156  	}
   157  }
   158  
// Conn supports sending multiple binary channels over a websocket connection.
//
// Field overview:
//   - protocols: the subprotocols this connection is willing to negotiate, keyed by name.
//   - selectedProtocol: the subprotocol negotiated for the connection; set by initialize.
//   - channels: one entry per channel of the selected protocol; set by initialize.
//   - codec: frame encoding derived from the selected protocol's Binary flag.
//   - ready: closed by initialize once the fields above and ws are populated.
//   - ws: the underlying websocket connection; nil until initialize has run.
//   - timeout: idle timeout applied by resetTimeout; zero means no deadline is set.
type Conn struct {
	protocols        map[string]ChannelProtocolConfig
	selectedProtocol string
	channels         []*websocketChannel
	codec            codecType
	ready            chan struct{}
	ws               *websocket.Conn
	timeout          time.Duration
}
   169  
   170  // NewConn creates a WebSocket connection that supports a set of channels. Channels begin each
   171  // web socket message with a single byte indicating the channel number (0-N). 255 is reserved for
   172  // future use. The channel types for each channel are passed as an array, supporting the different
   173  // duplex modes. Read and Write refer to whether the channel can be used as a Reader or Writer.
   174  //
   175  // The protocols parameter maps subprotocol names to ChannelProtocols. The empty string subprotocol
   176  // name is used if websocket.Config.Protocol is empty.
   177  func NewConn(protocols map[string]ChannelProtocolConfig) *Conn {
   178  	return &Conn{
   179  		ready:     make(chan struct{}),
   180  		protocols: protocols,
   181  	}
   182  }
   183  
// SetIdleTimeout sets the interval for both reads and writes before timeout. If not specified,
// there is no timeout on the connection.
//
// The value is only stored here; it is applied by resetTimeout, which pushes the connection
// deadline this far into the future before every receive and every write.
func (conn *Conn) SetIdleTimeout(duration time.Duration) {
	conn.timeout = duration
}
   189  
   190  // SetWriteDeadline sets a timeout on writing to the websocket connection. The
   191  // passed "duration" identifies how far into the future the write must complete
   192  // by before the timeout fires.
   193  func (conn *Conn) SetWriteDeadline(duration time.Duration) {
   194  	conn.ws.SetWriteDeadline(time.Now().Add(duration)) //nolint:errcheck
   195  }
   196  
   197  // Open the connection and create channels for reading and writing. It returns
   198  // the selected subprotocol, a slice of channels and an error.
   199  func (conn *Conn) Open(w http.ResponseWriter, req *http.Request) (string, []io.ReadWriteCloser, error) {
   200  	// serveHTTPComplete is channel that is closed/selected when "websocket#ServeHTTP" finishes.
   201  	serveHTTPComplete := make(chan struct{})
   202  	// Ensure panic in spawned goroutine is propagated into the parent goroutine.
   203  	panicChan := make(chan any, 1)
   204  	go func() {
   205  		// If websocket server returns, propagate panic if necessary. Otherwise,
   206  		// signal HTTPServe finished by closing "serveHTTPComplete".
   207  		defer func() {
   208  			if p := recover(); p != nil {
   209  				panicChan <- p
   210  			} else {
   211  				close(serveHTTPComplete)
   212  			}
   213  		}()
   214  		websocket.Server{Handshake: conn.handshake, Handler: conn.handle}.ServeHTTP(w, req)
   215  	}()
   216  
   217  	// In normal circumstances, "websocket.Server#ServeHTTP" calls "initialize" which closes
   218  	// "conn.ready" and then blocks until serving is complete.
   219  	select {
   220  	case <-conn.ready:
   221  		klog.V(8).Infof("websocket server initialized--serving")
   222  	case <-serveHTTPComplete:
   223  		// websocket server returned before completing initialization; cleanup and return error.
   224  		conn.closeNonThreadSafe() //nolint:errcheck
   225  		return "", nil, fmt.Errorf("websocket server finished before becoming ready")
   226  	case p := <-panicChan:
   227  		panic(p)
   228  	}
   229  
   230  	rwc := make([]io.ReadWriteCloser, len(conn.channels))
   231  	for i := range conn.channels {
   232  		rwc[i] = conn.channels[i]
   233  	}
   234  	return conn.selectedProtocol, rwc, nil
   235  }
   236  
   237  func (conn *Conn) initialize(ws *websocket.Conn) {
   238  	negotiated := ws.Config().Protocol
   239  	conn.selectedProtocol = negotiated[0]
   240  	p := conn.protocols[conn.selectedProtocol]
   241  	if p.Binary {
   242  		conn.codec = rawCodec
   243  	} else {
   244  		conn.codec = base64Codec
   245  	}
   246  	conn.ws = ws
   247  	conn.channels = make([]*websocketChannel, len(p.Channels))
   248  	for i, t := range p.Channels {
   249  		switch t {
   250  		case ReadChannel:
   251  			conn.channels[i] = newWebsocketChannel(conn, byte(i), true, false)
   252  		case WriteChannel:
   253  			conn.channels[i] = newWebsocketChannel(conn, byte(i), false, true)
   254  		case ReadWriteChannel:
   255  			conn.channels[i] = newWebsocketChannel(conn, byte(i), true, true)
   256  		case IgnoreChannel:
   257  			conn.channels[i] = newWebsocketChannel(conn, byte(i), false, false)
   258  		}
   259  	}
   260  
   261  	close(conn.ready)
   262  }
   263  
   264  func (conn *Conn) handshake(config *websocket.Config, req *http.Request) error {
   265  	supportedProtocols := make([]string, 0, len(conn.protocols))
   266  	for p := range conn.protocols {
   267  		supportedProtocols = append(supportedProtocols, p)
   268  	}
   269  	return handshake(config, req, supportedProtocols)
   270  }
   271  
   272  func (conn *Conn) resetTimeout() {
   273  	if conn.timeout > 0 {
   274  		conn.ws.SetDeadline(time.Now().Add(conn.timeout))
   275  	}
   276  }
   277  
   278  // closeNonThreadSafe cleans up by closing streams and the websocket
   279  // connection *without* waiting for the "ready" channel.
   280  func (conn *Conn) closeNonThreadSafe() error {
   281  	for _, s := range conn.channels {
   282  		s.Close()
   283  	}
   284  	var err error
   285  	if conn.ws != nil {
   286  		err = conn.ws.Close()
   287  	}
   288  	return err
   289  }
   290  
// Close is only valid after Open has been called. It blocks until the connection has
// been initialized (conn.ready is closed), then closes all channels and the websocket
// connection, returning the error from closing the websocket.
func (conn *Conn) Close() error {
	// Wait for initialization so the channels and websocket being closed are fully set up.
	<-conn.ready
	return conn.closeNonThreadSafe()
}
   296  
// protocolSupportsStreamClose returns true if the passed protocol
// supports the stream close signal (currently only V5 remotecommand);
// false otherwise. The comparison is an exact, case-sensitive string match.
func protocolSupportsStreamClose(protocol string) bool {
	return protocol == remotecommand.StreamProtocolV5Name
}
   303  
   304  // handle implements a websocket handler.
   305  func (conn *Conn) handle(ws *websocket.Conn) {
   306  	conn.initialize(ws)
   307  	defer conn.Close()
   308  	supportsStreamClose := protocolSupportsStreamClose(conn.selectedProtocol)
   309  
   310  	for {
   311  		conn.resetTimeout()
   312  		var data []byte
   313  		if err := websocket.Message.Receive(ws, &data); err != nil {
   314  			if err != io.EOF {
   315  				klog.Errorf("Error on socket receive: %v", err)
   316  			}
   317  			break
   318  		}
   319  		if len(data) == 0 {
   320  			continue
   321  		}
   322  		if supportsStreamClose && data[0] == remotecommand.StreamClose {
   323  			if len(data) != 2 {
   324  				klog.Errorf("Single channel byte should follow stream close signal. Got %d bytes", len(data)-1)
   325  				break
   326  			} else {
   327  				channel := data[1]
   328  				if int(channel) >= len(conn.channels) {
   329  					klog.Errorf("Close is targeted for a channel %d that is not valid, possible protocol error", channel)
   330  					break
   331  				}
   332  				klog.V(4).Infof("Received half-close signal from client; close %d stream", channel)
   333  				conn.channels[channel].Close() // After first Close, other closes are noop.
   334  			}
   335  			continue
   336  		}
   337  		channel := data[0]
   338  		if conn.codec == base64Codec {
   339  			channel = channel - '0'
   340  		}
   341  		data = data[1:]
   342  		if int(channel) >= len(conn.channels) {
   343  			klog.V(6).Infof("Frame is targeted for a reader %d that is not valid, possible protocol error", channel)
   344  			continue
   345  		}
   346  		if _, err := conn.channels[channel].DataFromSocket(data); err != nil {
   347  			klog.Errorf("Unable to write frame to %d: %v\n%s", channel, err, string(data))
   348  			continue
   349  		}
   350  	}
   351  }
   352  
   353  // write multiplexes the specified channel onto the websocket
   354  func (conn *Conn) write(num byte, data []byte) (int, error) {
   355  	conn.resetTimeout()
   356  	switch conn.codec {
   357  	case rawCodec:
   358  		frame := make([]byte, len(data)+1)
   359  		frame[0] = num
   360  		copy(frame[1:], data)
   361  		if err := websocket.Message.Send(conn.ws, frame); err != nil {
   362  			return 0, err
   363  		}
   364  	case base64Codec:
   365  		frame := string('0'+num) + base64.StdEncoding.EncodeToString(data)
   366  		if err := websocket.Message.Send(conn.ws, frame); err != nil {
   367  			return 0, err
   368  		}
   369  	}
   370  	return len(data), nil
   371  }
   372  
// websocketChannel represents a channel in a connection.
//
// Field overview:
//   - conn: the connection the channel belongs to; writes are sent through it.
//   - num: the channel number used as the frame prefix.
//   - r, w: the two ends of an in-memory pipe; data received from the websocket is
//     written to w and handed to callers of Read through r.
//   - read, write: whether reading from / writing to the channel is enabled.
type websocketChannel struct {
	conn *Conn
	num  byte
	r    io.Reader
	w    io.WriteCloser

	read, write bool
}
   382  
   383  // newWebsocketChannel creates a pipe for writing to a websocket. Do not write to this pipe
   384  // prior to the connection being opened. It may be no, half, or full duplex depending on
   385  // read and write.
   386  func newWebsocketChannel(conn *Conn, num byte, read, write bool) *websocketChannel {
   387  	r, w := io.Pipe()
   388  	return &websocketChannel{conn, num, r, w, read, write}
   389  }
   390  
   391  func (p *websocketChannel) Write(data []byte) (int, error) {
   392  	if !p.write {
   393  		return len(data), nil
   394  	}
   395  	return p.conn.write(p.num, data)
   396  }
   397  
   398  // DataFromSocket is invoked by the connection receiver to move data from the connection
   399  // into a specific channel.
   400  func (p *websocketChannel) DataFromSocket(data []byte) (int, error) {
   401  	if !p.read {
   402  		return len(data), nil
   403  	}
   404  
   405  	switch p.conn.codec {
   406  	case rawCodec:
   407  		return p.w.Write(data)
   408  	case base64Codec:
   409  		dst := make([]byte, len(data))
   410  		n, err := base64.StdEncoding.Decode(dst, data)
   411  		if err != nil {
   412  			return 0, err
   413  		}
   414  		return p.w.Write(dst[:n])
   415  	}
   416  	return 0, nil
   417  }
   418  
   419  func (p *websocketChannel) Read(data []byte) (int, error) {
   420  	if !p.read {
   421  		return 0, io.EOF
   422  	}
   423  	return p.r.Read(data)
   424  }
   425  
// Close closes the write end of the channel's pipe and returns the result of that
// close. It does not close the underlying websocket connection.
func (p *websocketChannel) Close() error {
	return p.w.Close()
}