github.com/timstclair/heapster@v0.20.0-alpha1/Godeps/_workspace/src/golang.org/x/crypto/ssh/mux.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
import (
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"log"
	"sync"
	"sync/atomic"
)
    15  
// debugMux, if set, causes messages in the connection protocol to be
// logged (each decoded packet, caught disconnects and loop exit). It
// also makes newMux give every mux a distinct channel-ID offset (see
// globalOff) so otherwise identical client and server muxes can be told
// apart in the logs.
const debugMux = false
    19  
// chanList is a thread safe channel list: a mutex-guarded table of
// channels addressed by local channel ID. add fills the first nil slot
// before growing the slice and remove sets a slot back to nil, so a
// freed ID is reused by the next add.
type chanList struct {
	// protects concurrent access to chans
	sync.Mutex

	// chans[i] holds the channel whose local id is i+offset; the other
	// side should send that id in the PeersId field. A nil entry is a
	// free slot.
	chans []*channel

	// This is a debugging aid: it offsets all IDs by this
	// amount. This helps distinguish otherwise identical
	// server/client muxes. newMux only sets it when debugMux is true.
	// getChan and remove read it without holding the mutex, so it must
	// not change once the list is in use.
	offset uint32
}
    34  
    35  // Assigns a channel ID to the given channel.
    36  func (c *chanList) add(ch *channel) uint32 {
    37  	c.Lock()
    38  	defer c.Unlock()
    39  	for i := range c.chans {
    40  		if c.chans[i] == nil {
    41  			c.chans[i] = ch
    42  			return uint32(i) + c.offset
    43  		}
    44  	}
    45  	c.chans = append(c.chans, ch)
    46  	return uint32(len(c.chans)-1) + c.offset
    47  }
    48  
    49  // getChan returns the channel for the given ID.
    50  func (c *chanList) getChan(id uint32) *channel {
    51  	id -= c.offset
    52  
    53  	c.Lock()
    54  	defer c.Unlock()
    55  	if id < uint32(len(c.chans)) {
    56  		return c.chans[id]
    57  	}
    58  	return nil
    59  }
    60  
    61  func (c *chanList) remove(id uint32) {
    62  	id -= c.offset
    63  	c.Lock()
    64  	if id < uint32(len(c.chans)) {
    65  		c.chans[id] = nil
    66  	}
    67  	c.Unlock()
    68  }
    69  
    70  // dropAll forgets all channels it knows, returning them in a slice.
    71  func (c *chanList) dropAll() []*channel {
    72  	c.Lock()
    73  	defer c.Unlock()
    74  	var r []*channel
    75  
    76  	for _, ch := range c.chans {
    77  		if ch == nil {
    78  			continue
    79  		}
    80  		r = append(r, ch)
    81  	}
    82  	c.chans = nil
    83  	return r
    84  }
    85  
// mux represents the state for the SSH connection protocol, which
// multiplexes many channels onto a single packet transport.
//
// newMux starts loop in its own goroutine; loop (via onePacket) is the
// only code in this file that reads packets from conn.
type mux struct {
	// conn is the packet transport; chanList tracks the channels open
	// on it, addressed by local channel ID.
	conn     packetConn
	chanList chanList

	// incomingChannels receives channels opened by the peer (queued by
	// handleChannelOpen). loop closes it on exit.
	incomingChannels chan NewChannel

	// globalSentMu is held by SendRequest for the whole round trip of a
	// request that wants a reply, so the reply it reads from
	// globalResponses is not consumed by another such request.
	// globalResponses carries the peer's success/failure replies from
	// handleGlobalPacket to SendRequest; incomingRequests carries the
	// peer's own global requests. loop closes both on exit.
	globalSentMu     sync.Mutex
	globalResponses  chan interface{}
	incomingRequests chan *Request

	// err is the error that ended loop. loop sets it once, holding
	// errCond.L, and then broadcasts errCond to wake Wait.
	errCond *sync.Cond
	err     error
}
   101  
// When debugging, each new chanList instantiation has a different
// offset: globalOff holds the most recently assigned one, and newMux
// atomically increments it (only when debugMux is true) to obtain the
// new mux's chanList.offset.
var globalOff uint32
   105  
   106  func (m *mux) Wait() error {
   107  	m.errCond.L.Lock()
   108  	defer m.errCond.L.Unlock()
   109  	for m.err == nil {
   110  		m.errCond.Wait()
   111  	}
   112  	return m.err
   113  }
   114  
   115  // newMux returns a mux that runs over the given connection.
   116  func newMux(p packetConn) *mux {
   117  	m := &mux{
   118  		conn:             p,
   119  		incomingChannels: make(chan NewChannel, 16),
   120  		globalResponses:  make(chan interface{}, 1),
   121  		incomingRequests: make(chan *Request, 16),
   122  		errCond:          newCond(),
   123  	}
   124  	if debugMux {
   125  		m.chanList.offset = atomic.AddUint32(&globalOff, 1)
   126  	}
   127  
   128  	go m.loop()
   129  	return m
   130  }
   131  
   132  func (m *mux) sendMessage(msg interface{}) error {
   133  	p := Marshal(msg)
   134  	return m.conn.writePacket(p)
   135  }
   136  
   137  func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) {
   138  	if wantReply {
   139  		m.globalSentMu.Lock()
   140  		defer m.globalSentMu.Unlock()
   141  	}
   142  
   143  	if err := m.sendMessage(globalRequestMsg{
   144  		Type:      name,
   145  		WantReply: wantReply,
   146  		Data:      payload,
   147  	}); err != nil {
   148  		return false, nil, err
   149  	}
   150  
   151  	if !wantReply {
   152  		return false, nil, nil
   153  	}
   154  
   155  	msg, ok := <-m.globalResponses
   156  	if !ok {
   157  		return false, nil, io.EOF
   158  	}
   159  	switch msg := msg.(type) {
   160  	case *globalRequestFailureMsg:
   161  		return false, msg.Data, nil
   162  	case *globalRequestSuccessMsg:
   163  		return true, msg.Data, nil
   164  	default:
   165  		return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg)
   166  	}
   167  }
   168  
   169  // ackRequest must be called after processing a global request that
   170  // has WantReply set.
   171  func (m *mux) ackRequest(ok bool, data []byte) error {
   172  	if ok {
   173  		return m.sendMessage(globalRequestSuccessMsg{Data: data})
   174  	}
   175  	return m.sendMessage(globalRequestFailureMsg{Data: data})
   176  }
   177  
   178  // TODO(hanwen): Disconnect is a transport layer message. We should
   179  // probably send and receive Disconnect somewhere in the transport
   180  // code.
   181  
   182  // Disconnect sends a disconnect message.
   183  func (m *mux) Disconnect(reason uint32, message string) error {
   184  	return m.sendMessage(disconnectMsg{
   185  		Reason:  reason,
   186  		Message: message,
   187  	})
   188  }
   189  
   190  func (m *mux) Close() error {
   191  	return m.conn.Close()
   192  }
   193  
// loop runs the connection machine. It will process packets until an
// error is encountered. To synchronize on loop exit, use mux.Wait.
//
// Teardown happens in a fixed order: every registered channel is
// dropped and closed; incomingChannels, incomingRequests and
// globalResponses are closed (a SendRequest waiting for a reply then
// sees the close and returns io.EOF); the packet connection is closed;
// and only then is m.err published and errCond broadcast, so Wait
// returns after all of this cleanup has run.
func (m *mux) loop() {
	var err error
	// onePacket returns nil for each packet it handles; the first
	// non-nil error ends the read loop.
	for err == nil {
		err = m.onePacket()
	}

	for _, ch := range m.chanList.dropAll() {
		ch.close()
	}

	close(m.incomingChannels)
	close(m.incomingRequests)
	close(m.globalResponses)

	// The error from Close is discarded; err is what gets reported.
	m.conn.Close()

	m.errCond.L.Lock()
	m.err = err
	m.errCond.Broadcast()
	m.errCond.L.Unlock()

	if debugMux {
		log.Println("loop exit", err)
	}
}
   221  
   222  // onePacket reads and processes one packet.
   223  func (m *mux) onePacket() error {
   224  	packet, err := m.conn.readPacket()
   225  	if err != nil {
   226  		return err
   227  	}
   228  
   229  	if debugMux {
   230  		if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData {
   231  			log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet))
   232  		} else {
   233  			p, _ := decode(packet)
   234  			log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet))
   235  		}
   236  	}
   237  
   238  	switch packet[0] {
   239  	case msgNewKeys:
   240  		// Ignore notification of key change.
   241  		return nil
   242  	case msgDisconnect:
   243  		return m.handleDisconnect(packet)
   244  	case msgChannelOpen:
   245  		return m.handleChannelOpen(packet)
   246  	case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
   247  		return m.handleGlobalPacket(packet)
   248  	}
   249  
   250  	// assume a channel packet.
   251  	if len(packet) < 5 {
   252  		return parseError(packet[0])
   253  	}
   254  	id := binary.BigEndian.Uint32(packet[1:])
   255  	ch := m.chanList.getChan(id)
   256  	if ch == nil {
   257  		return fmt.Errorf("ssh: invalid channel %d", id)
   258  	}
   259  
   260  	return ch.handlePacket(packet)
   261  }
   262  
   263  func (m *mux) handleDisconnect(packet []byte) error {
   264  	var d disconnectMsg
   265  	if err := Unmarshal(packet, &d); err != nil {
   266  		return err
   267  	}
   268  
   269  	if debugMux {
   270  		log.Printf("caught disconnect: %v", d)
   271  	}
   272  	return &d
   273  }
   274  
   275  func (m *mux) handleGlobalPacket(packet []byte) error {
   276  	msg, err := decode(packet)
   277  	if err != nil {
   278  		return err
   279  	}
   280  
   281  	switch msg := msg.(type) {
   282  	case *globalRequestMsg:
   283  		m.incomingRequests <- &Request{
   284  			Type:      msg.Type,
   285  			WantReply: msg.WantReply,
   286  			Payload:   msg.Data,
   287  			mux:       m,
   288  		}
   289  	case *globalRequestSuccessMsg, *globalRequestFailureMsg:
   290  		m.globalResponses <- msg
   291  	default:
   292  		panic(fmt.Sprintf("not a global message %#v", msg))
   293  	}
   294  
   295  	return nil
   296  }
   297  
   298  // handleChannelOpen schedules a channel to be Accept()ed.
   299  func (m *mux) handleChannelOpen(packet []byte) error {
   300  	var msg channelOpenMsg
   301  	if err := Unmarshal(packet, &msg); err != nil {
   302  		return err
   303  	}
   304  
   305  	if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
   306  		failMsg := channelOpenFailureMsg{
   307  			PeersId:  msg.PeersId,
   308  			Reason:   ConnectionFailed,
   309  			Message:  "invalid request",
   310  			Language: "en_US.UTF-8",
   311  		}
   312  		return m.sendMessage(failMsg)
   313  	}
   314  
   315  	c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData)
   316  	c.remoteId = msg.PeersId
   317  	c.maxRemotePayload = msg.MaxPacketSize
   318  	c.remoteWin.add(msg.PeersWindow)
   319  	m.incomingChannels <- c
   320  	return nil
   321  }
   322  
   323  func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) {
   324  	ch, err := m.openChannel(chanType, extra)
   325  	if err != nil {
   326  		return nil, nil, err
   327  	}
   328  
   329  	return ch, ch.incomingRequests, nil
   330  }
   331  
   332  func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) {
   333  	ch := m.newChannel(chanType, channelOutbound, extra)
   334  
   335  	ch.maxIncomingPayload = channelMaxPacket
   336  
   337  	open := channelOpenMsg{
   338  		ChanType:         chanType,
   339  		PeersWindow:      ch.myWindow,
   340  		MaxPacketSize:    ch.maxIncomingPayload,
   341  		TypeSpecificData: extra,
   342  		PeersId:          ch.localId,
   343  	}
   344  	if err := m.sendMessage(open); err != nil {
   345  		return nil, err
   346  	}
   347  
   348  	switch msg := (<-ch.msg).(type) {
   349  	case *channelOpenConfirmMsg:
   350  		return ch, nil
   351  	case *channelOpenFailureMsg:
   352  		return nil, &OpenChannelError{msg.Reason, msg.Message}
   353  	default:
   354  		return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg)
   355  	}
   356  }