github.com/glycerine/xcryptossh@v7.0.4+incompatible/mux.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"context"
     9  	"encoding/binary"
    10  	"fmt"
    11  	"io"
    12  	"log"
    13  	"net"
    14  	"sync"
    15  	"sync/atomic"
    16  )
    17  
// debugMux, if set, causes messages in the connection protocol to be
// logged via the standard log package. It also makes newMux give each
// mux's chanList a distinct ID offset (taken from globalOff), so that
// otherwise identical client and server muxes can be told apart in logs.
const debugMux = false
    21  
// chanList is a thread safe channel list: a table of channels indexed by
// local channel ID. Slots freed by remove (set to nil) are reused by add.
type chanList struct {
	// protects concurrent access to chans
	sync.Mutex

	// chans are indexed by the local id of the channel, which the
	// other side should send in the PeersId field. A nil entry is a
	// free slot.
	chans []*channel

	// This is a debugging aid: it offsets all IDs by this
	// amount. This helps distinguish otherwise identical
	// server/client muxes. In this file it is only assigned by
	// newMux (when debugMux is set) and is read without the lock.
	offset uint32
}
    36  
    37  // Assigns a channel ID to the given channel.
    38  func (c *chanList) add(ch *channel) uint32 {
    39  	c.Lock()
    40  	defer c.Unlock()
    41  	for i := range c.chans {
    42  		if c.chans[i] == nil {
    43  			c.chans[i] = ch
    44  			return uint32(i) + c.offset
    45  		}
    46  	}
    47  	c.chans = append(c.chans, ch)
    48  	return uint32(len(c.chans)-1) + c.offset
    49  }
    50  
    51  // getChan returns the channel for the given ID.
    52  func (c *chanList) getChan(id uint32) *channel {
    53  	id -= c.offset
    54  
    55  	c.Lock()
    56  	defer c.Unlock()
    57  	if id < uint32(len(c.chans)) {
    58  		return c.chans[id]
    59  	}
    60  	return nil
    61  }
    62  
    63  func (c *chanList) remove(id uint32) {
    64  	id -= c.offset
    65  	c.Lock()
    66  	if id < uint32(len(c.chans)) {
    67  		c.chans[id] = nil
    68  	}
    69  	c.Unlock()
    70  }
    71  
    72  // dropAll forgets all channels it knows, returning them in a slice.
    73  func (c *chanList) dropAll() []*channel {
    74  	c.Lock()
    75  	defer c.Unlock()
    76  	var r []*channel
    77  
    78  	for _, ch := range c.chans {
    79  		if ch == nil {
    80  			continue
    81  		}
    82  		r = append(r, ch)
    83  	}
    84  	c.chans = nil
    85  	return r
    86  }
    87  
// mux represents the state for the SSH connection protocol, which
// multiplexes many channels onto a single packet transport.
type mux struct {
	// conn is the packet transport that loop reads from and that
	// sendMessage writes to.
	conn     packetConn
	// chanList maps local channel IDs to channels on this connection.
	chanList chanList

	// incomingChannels carries channel-open requests from the peer
	// (queued by handleChannelOpen); closed when loop exits.
	incomingChannels chan NewChannel

	// globalSentMu serializes SendRequest calls that want a reply, so at
	// most one such call is waiting on globalResponses at a time.
	globalSentMu     sync.Mutex
	// globalResponses carries the peer's success/failure replies to our
	// global requests; closed when loop exits.
	globalResponses  chan interface{}
	// incomingRequests carries global requests sent by the peer; closed
	// when loop exits.
	incomingRequests chan *Request

	// errCond's lock guards err; it is broadcast once loop has exited.
	errCond *sync.Cond
	// err is the error that terminated loop; nil while loop is running.
	err     error

	// halt, once a stop is requested, aborts the mux's blocking
	// channel sends/receives with io.EOF.
	// NOTE(review): openChannel treats halt as possibly nil; confirm
	// whether nil is a supported value.
	halt *Halter
}
   105  
// globalOff is a process-wide counter used only when debugMux is set:
// newMux atomically increments it and uses the result as the new mux's
// chanList offset, so each chanList instantiation has a different offset.
var globalOff uint32
   109  
   110  func (m *mux) Wait() error {
   111  	m.errCond.L.Lock()
   112  	defer m.errCond.L.Unlock()
   113  	for m.err == nil {
   114  		m.errCond.Wait()
   115  	}
   116  	return m.err
   117  }
   118  
   119  // newMux returns a mux that runs over the given connection.
   120  func newMux(ctx context.Context, p packetConn, halt *Halter) *mux {
   121  	// idle is nil on server
   122  	m := &mux{
   123  		conn:             p,
   124  		incomingChannels: make(chan NewChannel, chanSize),
   125  		globalResponses:  make(chan interface{}, 1),
   126  		incomingRequests: make(chan *Request, chanSize),
   127  		errCond:          newCond(),
   128  		halt:             halt,
   129  	}
   130  
   131  	if debugMux {
   132  		m.chanList.offset = atomic.AddUint32(&globalOff, 1)
   133  	}
   134  
   135  	go m.loop(ctx)
   136  	return m
   137  }
   138  
   139  func (m *mux) sendMessage(msg interface{}) error {
   140  	p := Marshal(msg)
   141  	if debugMux {
   142  		log.Printf("send global(%d): %#v", m.chanList.offset, msg)
   143  	}
   144  	return m.conn.writePacket(p)
   145  }
   146  
   147  // SendRequest sends a global request, and returns the
   148  // reply. This is the ssh.Conn implimentation, described
   149  // in connection.go. If wantReply is true, it returns the
   150  // response status and payload. See also RFC4254, section 4.
   151  func (m *mux) SendRequest(ctx context.Context, name string, wantReply bool, payload []byte) (bool, []byte, error) {
   152  	if wantReply {
   153  		m.globalSentMu.Lock()
   154  		defer m.globalSentMu.Unlock()
   155  	}
   156  
   157  	if err := m.sendMessage(globalRequestMsg{
   158  		Type:      name,
   159  		WantReply: wantReply,
   160  		Data:      payload,
   161  	}); err != nil {
   162  		return false, nil, err
   163  	}
   164  
   165  	if !wantReply {
   166  		return false, nil, nil
   167  	}
   168  
   169  	select {
   170  	case msg, ok := <-m.globalResponses:
   171  		if !ok {
   172  			return false, nil, io.EOF
   173  		}
   174  		switch msg := msg.(type) {
   175  		case *globalRequestFailureMsg:
   176  			return false, msg.Data, nil
   177  		case *globalRequestSuccessMsg:
   178  			return true, msg.Data, nil
   179  		default:
   180  			return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg)
   181  		}
   182  
   183  	case <-m.halt.ReqStopChan():
   184  		return false, nil, io.EOF
   185  	case <-ctx.Done():
   186  		return false, nil, io.EOF
   187  	}
   188  }
   189  
   190  // ackRequest must be called after processing a global request that
   191  // has WantReply set.
   192  func (m *mux) ackRequest(ok bool, data []byte) error {
   193  	if ok {
   194  		return m.sendMessage(globalRequestSuccessMsg{Data: data})
   195  	}
   196  	return m.sendMessage(globalRequestFailureMsg{Data: data})
   197  }
   198  
   199  func (m *mux) Close() error {
   200  	return m.conn.Close()
   201  }
   202  
   203  // loop runs the connection machine. It will process packets until an
   204  // error is encountered. To synchronize on loop exit, use mux.Wait.
   205  func (m *mux) loop(ctx context.Context) {
   206  	var err error
   207  	for err == nil {
   208  		err = m.onePacket(ctx)
   209  
   210  		// We can't have timeout errors here cause us to
   211  		// leave the loop and close down, because we need to be able to
   212  		// resume from a timeout where we left off.
   213  		if err != nil {
   214  			nerr, ok := err.(net.Error)
   215  			if ok && nerr.Timeout() {
   216  				err = nil
   217  			}
   218  		}
   219  	}
   220  	for _, ch := range m.chanList.dropAll() {
   221  		ch.close()
   222  	}
   223  
   224  	close(m.incomingChannels)
   225  	close(m.incomingRequests)
   226  	close(m.globalResponses)
   227  
   228  	m.conn.Close()
   229  
   230  	m.errCond.L.Lock()
   231  	m.err = err
   232  	m.errCond.Broadcast()
   233  	m.errCond.L.Unlock()
   234  
   235  	if debugMux {
   236  		log.Println("loop exit", err)
   237  	}
   238  }
   239  
   240  // onePacket reads and processes one packet.
   241  func (m *mux) onePacket(ctx context.Context) error {
   242  	packet, err := m.conn.readPacket(ctx)
   243  	if err != nil {
   244  		return err
   245  	}
   246  
   247  	if debugMux {
   248  		if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData {
   249  			log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet))
   250  		} else {
   251  			p, _ := decode(packet)
   252  			log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet))
   253  		}
   254  	}
   255  
   256  	switch packet[0] {
   257  	case msgChannelOpen:
   258  		return m.handleChannelOpen(ctx, packet)
   259  	case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
   260  		return m.handleGlobalPacket(ctx, packet)
   261  	}
   262  
   263  	// assume a channel packet.
   264  	if len(packet) < 5 {
   265  		return parseError(packet[0])
   266  	}
   267  	id := binary.BigEndian.Uint32(packet[1:])
   268  	ch := m.chanList.getChan(id)
   269  	if ch == nil {
   270  		return fmt.Errorf("ssh: invalid channel %d", id)
   271  	}
   272  
   273  	return ch.handlePacket(packet)
   274  }
   275  
   276  func (m *mux) handleGlobalPacket(ctx context.Context, packet []byte) error {
   277  	msg, err := decode(packet)
   278  	if err != nil {
   279  		return err
   280  	}
   281  
   282  	switch msg := msg.(type) {
   283  	case *globalRequestMsg:
   284  		select {
   285  		case m.incomingRequests <- &Request{
   286  			Type:      msg.Type,
   287  			WantReply: msg.WantReply,
   288  			Payload:   msg.Data,
   289  			mux:       m,
   290  		}:
   291  			// just the send
   292  		case <-m.halt.ReqStopChan():
   293  			return io.EOF
   294  		case <-ctx.Done():
   295  			return io.EOF
   296  		}
   297  	case *globalRequestSuccessMsg, *globalRequestFailureMsg:
   298  		select {
   299  		case m.globalResponses <- msg:
   300  		case <-m.halt.ReqStopChan():
   301  			return io.EOF
   302  		case <-ctx.Done():
   303  			return io.EOF
   304  		}
   305  	default:
   306  		panic(fmt.Sprintf("not a global message %#v", msg))
   307  	}
   308  
   309  	return nil
   310  }
   311  
   312  // handleChannelOpen schedules a channel to be Accept()ed.
   313  func (m *mux) handleChannelOpen(ctx context.Context, packet []byte) error {
   314  	var msg channelOpenMsg
   315  	if err := Unmarshal(packet, &msg); err != nil {
   316  		return err
   317  	}
   318  
   319  	if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
   320  		failMsg := channelOpenFailureMsg{
   321  			PeersId:  msg.PeersId,
   322  			Reason:   ConnectionFailed,
   323  			Message:  "invalid request",
   324  			Language: "en_US.UTF-8",
   325  		}
   326  		return m.sendMessage(failMsg)
   327  	}
   328  
   329  	c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData)
   330  	c.remoteId = msg.PeersId
   331  	c.maxRemotePayload = msg.MaxPacketSize
   332  	c.remoteWin.add(msg.PeersWindow)
   333  	select {
   334  	case m.incomingChannels <- c:
   335  	case <-m.halt.ReqStopChan():
   336  		return io.EOF
   337  	case <-ctx.Done():
   338  		return io.EOF
   339  	}
   340  	return nil
   341  }
   342  
   343  func (m *mux) OpenChannel(ctx context.Context, chanType string, extra []byte, parentHalt *Halter) (Channel, <-chan *Request, error) {
   344  	ch, err := m.openChannel(ctx, chanType, extra, parentHalt)
   345  	if err != nil {
   346  		return nil, nil, err
   347  	}
   348  
   349  	return ch, ch.incomingRequests, nil
   350  }
   351  
   352  func (m *mux) openChannel(ctx context.Context, chanType string, extra []byte, parentHalt *Halter) (*channel, error) {
   353  	ch := m.newChannel(chanType, channelOutbound, extra)
   354  
   355  	ch.maxIncomingPayload = channelMaxPacket
   356  
   357  	open := channelOpenMsg{
   358  		ChanType:         chanType,
   359  		PeersWindow:      ch.myWindow,
   360  		MaxPacketSize:    ch.maxIncomingPayload,
   361  		TypeSpecificData: extra,
   362  		PeersId:          ch.localId,
   363  	}
   364  	if err := m.sendMessage(open); err != nil {
   365  		ch.idleR.Halt.RequestStop()
   366  		ch.idleW.Halt.RequestStop()
   367  		return nil, err
   368  	}
   369  
   370  	var done chan struct{}
   371  	if m.halt != nil {
   372  		done = m.halt.ReqStopChan()
   373  	}
   374  
   375  	select {
   376  	case msg := <-ch.msg:
   377  		switch msgt := msg.(type) {
   378  		case *channelOpenConfirmMsg:
   379  			if parentHalt != nil {
   380  				parentHalt.AddDownstream(ch.halt)
   381  			}
   382  			return ch, nil
   383  		case *channelOpenFailureMsg:
   384  			ch.idleR.Halt.RequestStop()
   385  			ch.idleW.Halt.RequestStop()
   386  			return nil, &OpenChannelError{msgt.Reason, msgt.Message}
   387  		default:
   388  			ch.idleR.Halt.RequestStop()
   389  			ch.idleW.Halt.RequestStop()
   390  			return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msgt)
   391  		}
   392  	case <-done:
   393  		return nil, io.EOF
   394  	case <-ctx.Done():
   395  		return nil, io.EOF
   396  	}
   397  }