github.com/psiphon-Labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/crypto/ssh/channel.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"encoding/binary"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"log"
    13  	"sync"
    14  )
    15  
    16  const (
    17  	minPacketLength = 9
    18  	// channelMaxPacket contains the maximum number of bytes that will be
    19  	// sent in a single packet. As per RFC 4253, section 6.1, 32k is also
    20  	// the minimum.
    21  	channelMaxPacket = 1 << 15
    22  
    23  	// [Psiphon]
    24  	//
    25  	// Use getChannelWindowSize() instead of channelWindowSize.
    26  	//
    27  	// We follow OpenSSH here.
    28  	//channelWindowSize = 64 * channelMaxPacket
    29  )
    30  
    31  // [Psiphon]
    32  //
    33  // - Use a smaller initial/max channel window size.
    34  // - Testing with the full Psiphon stack shows that
    35  //   this smaller channel window size is more performant
    36  //   for low bandwidth connections while still adequate for
    37  //   higher bandwidth connections.
    38  // - In Psiphon, a single SSH connection is used for all
    39  //   client port forwards. Bulk data transfers with large
    40  //   channel windows can immediately backlog the connection
    41  //   with many large SSH packets, introducing large latency
    42  //   for opening new channels. For Psiphon, we don't wish to
    43  //   optimize for a single bulk transfer throughput.
    44  // - TODO: can we implement some sort of adaptive max
    45  //   channel window size, starting with this small initial
    46  //   value and only growing based on connection properties?
    47  // - channelWindowSize directly defines the local channel
    48  //   window initial and max size. We also cap remote channel
    49  //   window sizes via an extra customization in the
    50  //   channelOpenConfirmMsg handler. Both upstream and
    51  //   downstream bulk data transfers have the same latency
    52  //   issue.
    53  // - For packet tunnel, use a larger channel window size,
    54  //   since all tunneled traffic flows through a single
    55  //   channel; we still select a size smaller than the stock
    56  //   channelWindowSize due to client memory constraints.
    57  func getChannelWindowSize(chanType string) int {
    58  
    59  	// From "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol".
    60  	// Copied here to avoid import cycle.
    61  	packetTunnelChannelType := "tun@psiphon.ca"
    62  
    63  	if chanType == packetTunnelChannelType {
    64  		return 16 * channelMaxPacket
    65  	}
    66  	return 4 * channelMaxPacket
    67  }
    68  
    69  // NewChannel represents an incoming request to a channel. It must either be
    70  // accepted for use by calling Accept, or rejected by calling Reject.
    71  type NewChannel interface {
    72  	// Accept accepts the channel creation request. It returns the Channel
    73  	// and a Go channel containing SSH requests. The Go channel must be
    74  	// serviced otherwise the Channel will hang.
    75  	Accept() (Channel, <-chan *Request, error)
    76  
    77  	// Reject rejects the channel creation request. After calling
    78  	// this, no other methods on the Channel may be called.
    79  	Reject(reason RejectionReason, message string) error
    80  
    81  	// ChannelType returns the type of the channel, as supplied by the
    82  	// client.
    83  	ChannelType() string
    84  
    85  	// ExtraData returns the arbitrary payload for this channel, as supplied
    86  	// by the client. This data is specific to the channel type.
    87  	ExtraData() []byte
    88  }
    89  
    90  // A Channel is an ordered, reliable, flow-controlled, duplex stream
    91  // that is multiplexed over an SSH connection.
    92  type Channel interface {
    93  	// Read reads up to len(data) bytes from the channel.
    94  	Read(data []byte) (int, error)
    95  
    96  	// Write writes len(data) bytes to the channel.
    97  	Write(data []byte) (int, error)
    98  
    99  	// Close signals end of channel use. No data may be sent after this
   100  	// call.
   101  	Close() error
   102  
   103  	// CloseWrite signals the end of sending in-band
   104  	// data. Requests may still be sent, and the other side may
   105  	// still send data
   106  	CloseWrite() error
   107  
   108  	// SendRequest sends a channel request.  If wantReply is true,
   109  	// it will wait for a reply and return the result as a
   110  	// boolean, otherwise the return value will be false. Channel
   111  	// requests are out-of-band messages so they may be sent even
   112  	// if the data stream is closed or blocked by flow control.
   113  	// If the channel is closed before a reply is returned, io.EOF
   114  	// is returned.
   115  	SendRequest(name string, wantReply bool, payload []byte) (bool, error)
   116  
   117  	// Stderr returns an io.ReadWriter that writes to this channel
   118  	// with the extended data type set to stderr. Stderr may
   119  	// safely be read and written from a different goroutine than
   120  	// Read and Write respectively.
   121  	Stderr() io.ReadWriter
   122  }
   123  
   124  // Request is a request sent outside of the normal stream of
   125  // data. Requests can either be specific to an SSH channel, or they
   126  // can be global.
   127  type Request struct {
   128  	Type      string
   129  	WantReply bool
   130  	Payload   []byte
   131  
   132  	ch  *channel
   133  	mux *mux
   134  }
   135  
   136  // Reply sends a response to a request. It must be called for all requests
   137  // where WantReply is true and is a no-op otherwise. The payload argument is
   138  // ignored for replies to channel-specific requests.
   139  func (r *Request) Reply(ok bool, payload []byte) error {
   140  	if !r.WantReply {
   141  		return nil
   142  	}
   143  
   144  	if r.ch == nil {
   145  		return r.mux.ackRequest(ok, payload)
   146  	}
   147  
   148  	return r.ch.ackRequest(ok)
   149  }
   150  
   151  // RejectionReason is an enumeration used when rejecting channel creation
   152  // requests. See RFC 4254, section 5.1.
   153  type RejectionReason uint32
   154  
   155  const (
   156  	Prohibited RejectionReason = iota + 1
   157  	ConnectionFailed
   158  	UnknownChannelType
   159  	ResourceShortage
   160  )
   161  
   162  // String converts the rejection reason to human readable form.
   163  func (r RejectionReason) String() string {
   164  	switch r {
   165  	case Prohibited:
   166  		return "administratively prohibited"
   167  	case ConnectionFailed:
   168  		return "connect failed"
   169  	case UnknownChannelType:
   170  		return "unknown channel type"
   171  	case ResourceShortage:
   172  		return "resource shortage"
   173  	}
   174  	return fmt.Sprintf("unknown reason %d", int(r))
   175  }
   176  
   177  func min(a uint32, b int) uint32 {
   178  	if a < uint32(b) {
   179  		return a
   180  	}
   181  	return uint32(b)
   182  }
   183  
   184  type channelDirection uint8
   185  
   186  const (
   187  	channelInbound channelDirection = iota
   188  	channelOutbound
   189  )
   190  
   191  // channel is an implementation of the Channel interface that works
   192  // with the mux class.
   193  type channel struct {
   194  	// R/O after creation
   195  	chanType          string
   196  	extraData         []byte
   197  	localId, remoteId uint32
   198  
   199  	// maxIncomingPayload and maxRemotePayload are the maximum
   200  	// payload sizes of normal and extended data packets for
   201  	// receiving and sending, respectively. The wire packet will
   202  	// be 9 or 13 bytes larger (excluding encryption overhead).
   203  	maxIncomingPayload uint32
   204  	maxRemotePayload   uint32
   205  
   206  	mux *mux
   207  
   208  	// decided is set to true if an accept or reject message has been sent
   209  	// (for outbound channels) or received (for inbound channels).
   210  	decided bool
   211  
   212  	// direction contains either channelOutbound, for channels created
   213  	// locally, or channelInbound, for channels created by the peer.
   214  	direction channelDirection
   215  
   216  	// Pending internal channel messages.
   217  	msg chan interface{}
   218  
   219  	// Since requests have no ID, there can be only one request
   220  	// with WantReply=true outstanding.  This lock is held by a
   221  	// goroutine that has such an outgoing request pending.
   222  	sentRequestMu sync.Mutex
   223  
   224  	incomingRequests chan *Request
   225  
   226  	sentEOF bool
   227  
   228  	// thread-safe data
   229  	remoteWin  window
   230  	pending    *buffer
   231  	extPending *buffer
   232  
   233  	// windowMu protects myWindow, the flow-control window.
   234  	windowMu sync.Mutex
   235  	myWindow uint32
   236  
   237  	// writeMu serializes calls to mux.conn.writePacket() and
   238  	// protects sentClose and packetPool. This mutex must be
   239  	// different from windowMu, as writePacket can block if there
   240  	// is a key exchange pending.
   241  	writeMu   sync.Mutex
   242  	sentClose bool
   243  
   244  	// packetPool has a buffer for each extended channel ID to
   245  	// save allocations during writes.
   246  	packetPool map[uint32][]byte
   247  }
   248  
   249  // writePacket sends a packet. If the packet is a channel close, it updates
   250  // sentClose. This method takes the lock c.writeMu.
   251  func (ch *channel) writePacket(packet []byte) error {
   252  	ch.writeMu.Lock()
   253  	if ch.sentClose {
   254  		ch.writeMu.Unlock()
   255  		return io.EOF
   256  	}
   257  	ch.sentClose = (packet[0] == msgChannelClose)
   258  	err := ch.mux.conn.writePacket(packet)
   259  	ch.writeMu.Unlock()
   260  	return err
   261  }
   262  
   263  func (ch *channel) sendMessage(msg interface{}) error {
   264  	if debugMux {
   265  		log.Printf("send(%d): %#v", ch.mux.chanList.offset, msg)
   266  	}
   267  
   268  	p := Marshal(msg)
   269  	binary.BigEndian.PutUint32(p[1:], ch.remoteId)
   270  	return ch.writePacket(p)
   271  }
   272  
   273  // WriteExtended writes data to a specific extended stream. These streams are
   274  // used, for example, for stderr.
   275  func (ch *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) {
   276  	if ch.sentEOF {
   277  		return 0, io.EOF
   278  	}
   279  	// 1 byte message type, 4 bytes remoteId, 4 bytes data length
   280  	opCode := byte(msgChannelData)
   281  	headerLength := uint32(9)
   282  	if extendedCode > 0 {
   283  		headerLength += 4
   284  		opCode = msgChannelExtendedData
   285  	}
   286  
   287  	ch.writeMu.Lock()
   288  	packet := ch.packetPool[extendedCode]
   289  	// We don't remove the buffer from packetPool, so
   290  	// WriteExtended calls from different goroutines will be
   291  	// flagged as errors by the race detector.
   292  	ch.writeMu.Unlock()
   293  
   294  	for len(data) > 0 {
   295  		space := min(ch.maxRemotePayload, len(data))
   296  		if space, err = ch.remoteWin.reserve(space); err != nil {
   297  			return n, err
   298  		}
   299  		if want := headerLength + space; uint32(cap(packet)) < want {
   300  			packet = make([]byte, want)
   301  		} else {
   302  			packet = packet[:want]
   303  		}
   304  
   305  		todo := data[:space]
   306  
   307  		packet[0] = opCode
   308  		binary.BigEndian.PutUint32(packet[1:], ch.remoteId)
   309  		if extendedCode > 0 {
   310  			binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode))
   311  		}
   312  		binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo)))
   313  		copy(packet[headerLength:], todo)
   314  		if err = ch.writePacket(packet); err != nil {
   315  			return n, err
   316  		}
   317  
   318  		n += len(todo)
   319  		data = data[len(todo):]
   320  	}
   321  
   322  	ch.writeMu.Lock()
   323  	ch.packetPool[extendedCode] = packet
   324  	ch.writeMu.Unlock()
   325  
   326  	return n, err
   327  }
   328  
   329  func (ch *channel) handleData(packet []byte) error {
   330  	headerLen := 9
   331  	isExtendedData := packet[0] == msgChannelExtendedData
   332  	if isExtendedData {
   333  		headerLen = 13
   334  	}
   335  	if len(packet) < headerLen {
   336  		// malformed data packet
   337  		return parseError(packet[0])
   338  	}
   339  
   340  	var extended uint32
   341  	if isExtendedData {
   342  		extended = binary.BigEndian.Uint32(packet[5:])
   343  	}
   344  
   345  	length := binary.BigEndian.Uint32(packet[headerLen-4 : headerLen])
   346  	if length == 0 {
   347  		return nil
   348  	}
   349  	if length > ch.maxIncomingPayload {
   350  		// TODO(hanwen): should send Disconnect?
   351  		return errors.New("ssh: incoming packet exceeds maximum payload size")
   352  	}
   353  
   354  	data := packet[headerLen:]
   355  	if length != uint32(len(data)) {
   356  		return errors.New("ssh: wrong packet length")
   357  	}
   358  
   359  	ch.windowMu.Lock()
   360  	if ch.myWindow < length {
   361  		ch.windowMu.Unlock()
   362  		// TODO(hanwen): should send Disconnect with reason?
   363  		return errors.New("ssh: remote side wrote too much")
   364  	}
   365  	ch.myWindow -= length
   366  	ch.windowMu.Unlock()
   367  
   368  	if extended == 1 {
   369  		ch.extPending.write(data)
   370  	} else if extended > 0 {
   371  		// discard other extended data.
   372  	} else {
   373  		ch.pending.write(data)
   374  	}
   375  	return nil
   376  }
   377  
   378  func (c *channel) adjustWindow(n uint32) error {
   379  	c.windowMu.Lock()
   380  	// Since myWindow is managed on our side, and can never exceed
   381  	// the initial window setting, we don't worry about overflow.
   382  	c.myWindow += uint32(n)
   383  	c.windowMu.Unlock()
   384  	return c.sendMessage(windowAdjustMsg{
   385  		AdditionalBytes: uint32(n),
   386  	})
   387  }
   388  
   389  func (c *channel) ReadExtended(data []byte, extended uint32) (n int, err error) {
   390  	switch extended {
   391  	case 1:
   392  		n, err = c.extPending.Read(data)
   393  	case 0:
   394  		n, err = c.pending.Read(data)
   395  	default:
   396  		return 0, fmt.Errorf("ssh: extended code %d unimplemented", extended)
   397  	}
   398  
   399  	if n > 0 {
   400  		err = c.adjustWindow(uint32(n))
   401  		// sendWindowAdjust can return io.EOF if the remote
   402  		// peer has closed the connection, however we want to
   403  		// defer forwarding io.EOF to the caller of Read until
   404  		// the buffer has been drained.
   405  		if n > 0 && err == io.EOF {
   406  			err = nil
   407  		}
   408  	}
   409  
   410  	return n, err
   411  }
   412  
   413  func (c *channel) close() {
   414  	c.pending.eof()
   415  	c.extPending.eof()
   416  	close(c.msg)
   417  	close(c.incomingRequests)
   418  	c.writeMu.Lock()
   419  	// This is not necessary for a normal channel teardown, but if
   420  	// there was another error, it is.
   421  	c.sentClose = true
   422  	c.writeMu.Unlock()
   423  	// Unblock writers.
   424  	c.remoteWin.close()
   425  }
   426  
   427  // responseMessageReceived is called when a success or failure message is
   428  // received on a channel to check that such a message is reasonable for the
   429  // given channel.
   430  func (ch *channel) responseMessageReceived() error {
   431  	if ch.direction == channelInbound {
   432  		return errors.New("ssh: channel response message received on inbound channel")
   433  	}
   434  	if ch.decided {
   435  		return errors.New("ssh: duplicate response received for channel")
   436  	}
   437  	ch.decided = true
   438  	return nil
   439  }
   440  
   441  func (ch *channel) handlePacket(packet []byte) error {
   442  	switch packet[0] {
   443  	case msgChannelData, msgChannelExtendedData:
   444  		return ch.handleData(packet)
   445  	case msgChannelClose:
   446  		ch.sendMessage(channelCloseMsg{PeersID: ch.remoteId})
   447  		ch.mux.chanList.remove(ch.localId)
   448  		ch.close()
   449  		return nil
   450  	case msgChannelEOF:
   451  		// RFC 4254 is mute on how EOF affects dataExt messages but
   452  		// it is logical to signal EOF at the same time.
   453  		ch.extPending.eof()
   454  		ch.pending.eof()
   455  		return nil
   456  	}
   457  
   458  	decoded, err := decode(packet)
   459  	if err != nil {
   460  		return err
   461  	}
   462  
   463  	switch msg := decoded.(type) {
   464  	case *channelOpenFailureMsg:
   465  		if err := ch.responseMessageReceived(); err != nil {
   466  			return err
   467  		}
   468  		ch.mux.chanList.remove(msg.PeersID)
   469  		ch.msg <- msg
   470  	case *channelOpenConfirmMsg:
   471  		if err := ch.responseMessageReceived(); err != nil {
   472  			return err
   473  		}
   474  		if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
   475  			return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize)
   476  		}
   477  		ch.remoteId = msg.MyID
   478  		ch.maxRemotePayload = msg.MaxPacketSize
   479  
   480  		// [Psiphon]
   481  		//
   482  		// - Use a smaller initial/max channel window size.
   483  		// - See comments above channelWindowSize definition.
   484  
   485  		//ch.remoteWin.add(msg.MyWindow)
   486  		ch.remoteWin.add(min(msg.MyWindow, getChannelWindowSize(ch.chanType)))
   487  
   488  		ch.msg <- msg
   489  	case *windowAdjustMsg:
   490  		if !ch.remoteWin.add(msg.AdditionalBytes) {
   491  			return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes)
   492  		}
   493  	case *channelRequestMsg:
   494  		req := Request{
   495  			Type:      msg.Request,
   496  			WantReply: msg.WantReply,
   497  			Payload:   msg.RequestSpecificData,
   498  			ch:        ch,
   499  		}
   500  
   501  		ch.incomingRequests <- &req
   502  	default:
   503  		ch.msg <- msg
   504  	}
   505  	return nil
   506  }
   507  
   508  func (m *mux) newChannel(chanType string, direction channelDirection, extraData []byte) *channel {
   509  	ch := &channel{
   510  		remoteWin:        window{Cond: newCond()},
   511  		myWindow:         uint32(getChannelWindowSize(chanType)),
   512  		pending:          newBuffer(),
   513  		extPending:       newBuffer(),
   514  		direction:        direction,
   515  		incomingRequests: make(chan *Request, chanSize),
   516  		msg:              make(chan interface{}, chanSize),
   517  		chanType:         chanType,
   518  		extraData:        extraData,
   519  		mux:              m,
   520  		packetPool:       make(map[uint32][]byte),
   521  	}
   522  	ch.localId = m.chanList.add(ch)
   523  	return ch
   524  }
   525  
   526  var errUndecided = errors.New("ssh: must Accept or Reject channel")
   527  var errDecidedAlready = errors.New("ssh: can call Accept or Reject only once")
   528  
   529  type extChannel struct {
   530  	code uint32
   531  	ch   *channel
   532  }
   533  
   534  func (e *extChannel) Write(data []byte) (n int, err error) {
   535  	return e.ch.WriteExtended(data, e.code)
   536  }
   537  
   538  func (e *extChannel) Read(data []byte) (n int, err error) {
   539  	return e.ch.ReadExtended(data, e.code)
   540  }
   541  
   542  func (ch *channel) Accept() (Channel, <-chan *Request, error) {
   543  	if ch.decided {
   544  		return nil, nil, errDecidedAlready
   545  	}
   546  	ch.maxIncomingPayload = channelMaxPacket
   547  	confirm := channelOpenConfirmMsg{
   548  		PeersID:       ch.remoteId,
   549  		MyID:          ch.localId,
   550  		MyWindow:      ch.myWindow,
   551  		MaxPacketSize: ch.maxIncomingPayload,
   552  	}
   553  	ch.decided = true
   554  	if err := ch.sendMessage(confirm); err != nil {
   555  		return nil, nil, err
   556  	}
   557  
   558  	return ch, ch.incomingRequests, nil
   559  }
   560  
   561  func (ch *channel) Reject(reason RejectionReason, message string) error {
   562  	if ch.decided {
   563  		return errDecidedAlready
   564  	}
   565  	reject := channelOpenFailureMsg{
   566  		PeersID:  ch.remoteId,
   567  		Reason:   reason,
   568  		Message:  message,
   569  		Language: "en",
   570  	}
   571  	ch.decided = true
   572  
   573  	// [Psiphon]
   574  	// Don't leave reference to channel in chanList
   575  	ch.mux.chanList.remove(ch.localId)
   576  
   577  	return ch.sendMessage(reject)
   578  }
   579  
   580  func (ch *channel) Read(data []byte) (int, error) {
   581  	if !ch.decided {
   582  		return 0, errUndecided
   583  	}
   584  	return ch.ReadExtended(data, 0)
   585  }
   586  
   587  func (ch *channel) Write(data []byte) (int, error) {
   588  	if !ch.decided {
   589  		return 0, errUndecided
   590  	}
   591  	return ch.WriteExtended(data, 0)
   592  }
   593  
   594  func (ch *channel) CloseWrite() error {
   595  	if !ch.decided {
   596  		return errUndecided
   597  	}
   598  	ch.sentEOF = true
   599  	return ch.sendMessage(channelEOFMsg{
   600  		PeersID: ch.remoteId})
   601  }
   602  
   603  func (ch *channel) Close() error {
   604  	if !ch.decided {
   605  		return errUndecided
   606  	}
   607  
   608  	return ch.sendMessage(channelCloseMsg{
   609  		PeersID: ch.remoteId})
   610  }
   611  
   612  // Extended returns an io.ReadWriter that sends and receives data on the given,
   613  // SSH extended stream. Such streams are used, for example, for stderr.
   614  func (ch *channel) Extended(code uint32) io.ReadWriter {
   615  	if !ch.decided {
   616  		return nil
   617  	}
   618  	return &extChannel{code, ch}
   619  }
   620  
   621  func (ch *channel) Stderr() io.ReadWriter {
   622  	return ch.Extended(1)
   623  }
   624  
   625  func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
   626  	if !ch.decided {
   627  		return false, errUndecided
   628  	}
   629  
   630  	if wantReply {
   631  		ch.sentRequestMu.Lock()
   632  		defer ch.sentRequestMu.Unlock()
   633  	}
   634  
   635  	msg := channelRequestMsg{
   636  		PeersID:             ch.remoteId,
   637  		Request:             name,
   638  		WantReply:           wantReply,
   639  		RequestSpecificData: payload,
   640  	}
   641  
   642  	if err := ch.sendMessage(msg); err != nil {
   643  		return false, err
   644  	}
   645  
   646  	if wantReply {
   647  		m, ok := (<-ch.msg)
   648  		if !ok {
   649  			return false, io.EOF
   650  		}
   651  		switch m.(type) {
   652  		case *channelRequestFailureMsg:
   653  			return false, nil
   654  		case *channelRequestSuccessMsg:
   655  			return true, nil
   656  		default:
   657  			return false, fmt.Errorf("ssh: unexpected response to channel request: %#v", m)
   658  		}
   659  	}
   660  
   661  	return false, nil
   662  }
   663  
   664  // ackRequest either sends an ack or nack to the channel request.
   665  func (ch *channel) ackRequest(ok bool) error {
   666  	if !ch.decided {
   667  		return errUndecided
   668  	}
   669  
   670  	var msg interface{}
   671  	if !ok {
   672  		msg = channelRequestFailureMsg{
   673  			PeersID: ch.remoteId,
   674  		}
   675  	} else {
   676  		msg = channelRequestSuccessMsg{
   677  			PeersID: ch.remoteId,
   678  		}
   679  	}
   680  	return ch.sendMessage(msg)
   681  }
   682  
   683  func (ch *channel) ChannelType() string {
   684  	return ch.chanType
   685  }
   686  
   687  func (ch *channel) ExtraData() []byte {
   688  	return ch.extraData
   689  }