github.com/psiphon-labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/crypto/ssh/mux.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"encoding/binary"
     9  	"fmt"
    10  	"io"
    11  	"log"
    12  	"sync"
    13  	"sync/atomic"
    14  )
    15  
    16  // debugMux, if set, causes messages in the connection protocol to be
    17  // logged.
    18  const debugMux = false
    19  
    20  // chanList is a thread safe channel list.
    21  type chanList struct {
    22  	// protects concurrent access to chans
    23  	sync.Mutex
    24  
    25  	// chans are indexed by the local id of the channel, which the
    26  	// other side should send in the PeersId field.
    27  	chans []*channel
    28  
    29  	// This is a debugging aid: it offsets all IDs by this
    30  	// amount. This helps distinguish otherwise identical
    31  	// server/client muxes
    32  	offset uint32
    33  }
    34  
    35  // Assigns a channel ID to the given channel.
    36  func (c *chanList) add(ch *channel) uint32 {
    37  	c.Lock()
    38  	defer c.Unlock()
    39  	for i := range c.chans {
    40  		if c.chans[i] == nil {
    41  			c.chans[i] = ch
    42  			return uint32(i) + c.offset
    43  		}
    44  	}
    45  	c.chans = append(c.chans, ch)
    46  	return uint32(len(c.chans)-1) + c.offset
    47  }
    48  
    49  // getChan returns the channel for the given ID.
    50  func (c *chanList) getChan(id uint32) *channel {
    51  	id -= c.offset
    52  
    53  	c.Lock()
    54  	defer c.Unlock()
    55  	if id < uint32(len(c.chans)) {
    56  		return c.chans[id]
    57  	}
    58  	return nil
    59  }
    60  
    61  func (c *chanList) remove(id uint32) {
    62  	id -= c.offset
    63  	c.Lock()
    64  	if id < uint32(len(c.chans)) {
    65  		c.chans[id] = nil
    66  	}
    67  	c.Unlock()
    68  }
    69  
    70  // dropAll forgets all channels it knows, returning them in a slice.
    71  func (c *chanList) dropAll() []*channel {
    72  	c.Lock()
    73  	defer c.Unlock()
    74  	var r []*channel
    75  
    76  	for _, ch := range c.chans {
    77  		if ch == nil {
    78  			continue
    79  		}
    80  		r = append(r, ch)
    81  	}
    82  	c.chans = nil
    83  	return r
    84  }
    85  
    86  // mux represents the state for the SSH connection protocol, which
    87  // multiplexes many channels onto a single packet transport.
    88  type mux struct {
    89  	conn     packetConn
    90  	chanList chanList
    91  
    92  	incomingChannels chan NewChannel
    93  
    94  	globalSentMu     sync.Mutex
    95  	globalResponses  chan interface{}
    96  	incomingRequests chan *Request
    97  
    98  	errCond *sync.Cond
    99  	err     error
   100  }
   101  
   102  // When debugging, each new chanList instantiation has a different
   103  // offset.
   104  var globalOff uint32
   105  
   106  func (m *mux) Wait() error {
   107  	m.errCond.L.Lock()
   108  	defer m.errCond.L.Unlock()
   109  	for m.err == nil {
   110  		m.errCond.Wait()
   111  	}
   112  	return m.err
   113  }
   114  
   115  // newMux returns a mux that runs over the given connection.
   116  func newMux(p packetConn) *mux {
   117  	m := &mux{
   118  		conn:             p,
   119  		incomingChannels: make(chan NewChannel, chanSize),
   120  		globalResponses:  make(chan interface{}, 1),
   121  		incomingRequests: make(chan *Request, chanSize),
   122  		errCond:          newCond(),
   123  	}
   124  	if debugMux {
   125  		m.chanList.offset = atomic.AddUint32(&globalOff, 1)
   126  	}
   127  
   128  	go m.loop()
   129  	return m
   130  }
   131  
   132  func (m *mux) sendMessage(msg interface{}) error {
   133  	p := Marshal(msg)
   134  	if debugMux {
   135  		log.Printf("send global(%d): %#v", m.chanList.offset, msg)
   136  	}
   137  	return m.conn.writePacket(p)
   138  }
   139  
   140  func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) {
   141  	if wantReply {
   142  		m.globalSentMu.Lock()
   143  		defer m.globalSentMu.Unlock()
   144  	}
   145  
   146  	if err := m.sendMessage(globalRequestMsg{
   147  		Type:      name,
   148  		WantReply: wantReply,
   149  		Data:      payload,
   150  	}); err != nil {
   151  		return false, nil, err
   152  	}
   153  
   154  	if !wantReply {
   155  		return false, nil, nil
   156  	}
   157  
   158  	msg, ok := <-m.globalResponses
   159  	if !ok {
   160  		return false, nil, io.EOF
   161  	}
   162  	switch msg := msg.(type) {
   163  	case *globalRequestFailureMsg:
   164  		return false, msg.Data, nil
   165  	case *globalRequestSuccessMsg:
   166  		return true, msg.Data, nil
   167  	default:
   168  		return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg)
   169  	}
   170  }
   171  
   172  // ackRequest must be called after processing a global request that
   173  // has WantReply set.
   174  func (m *mux) ackRequest(ok bool, data []byte) error {
   175  	if ok {
   176  		return m.sendMessage(globalRequestSuccessMsg{Data: data})
   177  	}
   178  	return m.sendMessage(globalRequestFailureMsg{Data: data})
   179  }
   180  
   181  func (m *mux) Close() error {
   182  	return m.conn.Close()
   183  }
   184  
   185  // loop runs the connection machine. It will process packets until an
   186  // error is encountered. To synchronize on loop exit, use mux.Wait.
   187  func (m *mux) loop() {
   188  	var err error
   189  	for err == nil {
   190  		err = m.onePacket()
   191  	}
   192  
   193  	for _, ch := range m.chanList.dropAll() {
   194  		ch.close()
   195  	}
   196  
   197  	close(m.incomingChannels)
   198  	close(m.incomingRequests)
   199  	close(m.globalResponses)
   200  
   201  	m.conn.Close()
   202  
   203  	m.errCond.L.Lock()
   204  	m.err = err
   205  	m.errCond.Broadcast()
   206  	m.errCond.L.Unlock()
   207  
   208  	if debugMux {
   209  		log.Println("loop exit", err)
   210  	}
   211  }
   212  
   213  // onePacket reads and processes one packet.
   214  func (m *mux) onePacket() error {
   215  	packet, err := m.conn.readPacket()
   216  	if err != nil {
   217  		return err
   218  	}
   219  
   220  	if debugMux {
   221  		if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData {
   222  			log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet))
   223  		} else {
   224  			p, _ := decode(packet)
   225  			log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet))
   226  		}
   227  	}
   228  
   229  	switch packet[0] {
   230  	case msgChannelOpen:
   231  		return m.handleChannelOpen(packet)
   232  	case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
   233  		return m.handleGlobalPacket(packet)
   234  	}
   235  
   236  	// assume a channel packet.
   237  	if len(packet) < 5 {
   238  		return parseError(packet[0])
   239  	}
   240  	id := binary.BigEndian.Uint32(packet[1:])
   241  	ch := m.chanList.getChan(id)
   242  	if ch == nil {
   243  		return m.handleUnknownChannelPacket(id, packet)
   244  	}
   245  
   246  	return ch.handlePacket(packet)
   247  }
   248  
   249  func (m *mux) handleGlobalPacket(packet []byte) error {
   250  	msg, err := decode(packet)
   251  	if err != nil {
   252  		return err
   253  	}
   254  
   255  	switch msg := msg.(type) {
   256  	case *globalRequestMsg:
   257  		m.incomingRequests <- &Request{
   258  			Type:      msg.Type,
   259  			WantReply: msg.WantReply,
   260  			Payload:   msg.Data,
   261  			mux:       m,
   262  		}
   263  	case *globalRequestSuccessMsg, *globalRequestFailureMsg:
   264  		m.globalResponses <- msg
   265  	default:
   266  		panic(fmt.Sprintf("not a global message %#v", msg))
   267  	}
   268  
   269  	return nil
   270  }
   271  
   272  // handleChannelOpen schedules a channel to be Accept()ed.
   273  func (m *mux) handleChannelOpen(packet []byte) error {
   274  	var msg channelOpenMsg
   275  	if err := Unmarshal(packet, &msg); err != nil {
   276  		return err
   277  	}
   278  
   279  	if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
   280  		failMsg := channelOpenFailureMsg{
   281  			PeersID:  msg.PeersID,
   282  			Reason:   ConnectionFailed,
   283  			Message:  "invalid request",
   284  			Language: "en_US.UTF-8",
   285  		}
   286  		return m.sendMessage(failMsg)
   287  	}
   288  
   289  	c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData)
   290  	c.remoteId = msg.PeersID
   291  	c.maxRemotePayload = msg.MaxPacketSize
   292  	c.remoteWin.add(msg.PeersWindow)
   293  	m.incomingChannels <- c
   294  	return nil
   295  }
   296  
   297  func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) {
   298  	ch, err := m.openChannel(chanType, extra)
   299  	if err != nil {
   300  		return nil, nil, err
   301  	}
   302  
   303  	return ch, ch.incomingRequests, nil
   304  }
   305  
   306  func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) {
   307  	ch := m.newChannel(chanType, channelOutbound, extra)
   308  
   309  	ch.maxIncomingPayload = channelMaxPacket
   310  
   311  	open := channelOpenMsg{
   312  		ChanType:         chanType,
   313  		PeersWindow:      ch.myWindow,
   314  		MaxPacketSize:    ch.maxIncomingPayload,
   315  		TypeSpecificData: extra,
   316  		PeersID:          ch.localId,
   317  	}
   318  	if err := m.sendMessage(open); err != nil {
   319  		return nil, err
   320  	}
   321  
   322  	switch msg := (<-ch.msg).(type) {
   323  	case *channelOpenConfirmMsg:
   324  		return ch, nil
   325  	case *channelOpenFailureMsg:
   326  		return nil, &OpenChannelError{msg.Reason, msg.Message}
   327  	default:
   328  		return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg)
   329  	}
   330  }
   331  
   332  func (m *mux) handleUnknownChannelPacket(id uint32, packet []byte) error {
   333  	msg, err := decode(packet)
   334  	if err != nil {
   335  		return err
   336  	}
   337  
   338  	switch msg := msg.(type) {
   339  	// RFC 4254 section 5.4 says unrecognized channel requests should
   340  	// receive a failure response.
   341  	case *channelRequestMsg:
   342  		if msg.WantReply {
   343  			return m.sendMessage(channelRequestFailureMsg{
   344  				PeersID: msg.PeersID,
   345  			})
   346  		}
   347  		return nil
   348  	default:
   349  		return fmt.Errorf("ssh: invalid channel %d", id)
   350  	}
   351  }