github.com/timstclair/heapster@v0.20.0-alpha1/Godeps/_workspace/src/golang.org/x/crypto/ssh/channel.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"encoding/binary"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"log"
    13  	"sync"
    14  )
    15  
    16  const (
    17  	minPacketLength = 9
    18  	// channelMaxPacket contains the maximum number of bytes that will be
    19  	// sent in a single packet. As per RFC 4253, section 6.1, 32k is also
    20  	// the minimum.
    21  	channelMaxPacket = 1 << 15
    22  	// We follow OpenSSH here.
    23  	channelWindowSize = 64 * channelMaxPacket
    24  )
    25  
    26  // NewChannel represents an incoming request to a channel. It must either be
    27  // accepted for use by calling Accept, or rejected by calling Reject.
    28  type NewChannel interface {
    29  	// Accept accepts the channel creation request. It returns the Channel
    30  	// and a Go channel containing SSH requests. The Go channel must be
    31  	// serviced otherwise the Channel will hang.
    32  	Accept() (Channel, <-chan *Request, error)
    33  
    34  	// Reject rejects the channel creation request. After calling
    35  	// this, no other methods on the Channel may be called.
    36  	Reject(reason RejectionReason, message string) error
    37  
    38  	// ChannelType returns the type of the channel, as supplied by the
    39  	// client.
    40  	ChannelType() string
    41  
    42  	// ExtraData returns the arbitrary payload for this channel, as supplied
    43  	// by the client. This data is specific to the channel type.
    44  	ExtraData() []byte
    45  }
    46  
    47  // A Channel is an ordered, reliable, flow-controlled, duplex stream
    48  // that is multiplexed over an SSH connection.
    49  type Channel interface {
    50  	// Read reads up to len(data) bytes from the channel.
    51  	Read(data []byte) (int, error)
    52  
    53  	// Write writes len(data) bytes to the channel.
    54  	Write(data []byte) (int, error)
    55  
    56  	// Close signals end of channel use. No data may be sent after this
    57  	// call.
    58  	Close() error
    59  
    60  	// CloseWrite signals the end of sending in-band
    61  	// data. Requests may still be sent, and the other side may
    62  	// still send data
    63  	CloseWrite() error
    64  
    65  	// SendRequest sends a channel request.  If wantReply is true,
    66  	// it will wait for a reply and return the result as a
    67  	// boolean, otherwise the return value will be false. Channel
    68  	// requests are out-of-band messages so they may be sent even
    69  	// if the data stream is closed or blocked by flow control.
    70  	SendRequest(name string, wantReply bool, payload []byte) (bool, error)
    71  
    72  	// Stderr returns an io.ReadWriter that writes to this channel
    73  	// with the extended data type set to stderr. Stderr may
    74  	// safely be read and written from a different goroutine than
    75  	// Read and Write respectively.
    76  	Stderr() io.ReadWriter
    77  }
    78  
    79  // Request is a request sent outside of the normal stream of
    80  // data. Requests can either be specific to an SSH channel, or they
    81  // can be global.
    82  type Request struct {
    83  	Type      string
    84  	WantReply bool
    85  	Payload   []byte
    86  
    87  	ch  *channel
    88  	mux *mux
    89  }
    90  
    91  // Reply sends a response to a request. It must be called for all requests
    92  // where WantReply is true and is a no-op otherwise. The payload argument is
    93  // ignored for replies to channel-specific requests.
    94  func (r *Request) Reply(ok bool, payload []byte) error {
    95  	if !r.WantReply {
    96  		return nil
    97  	}
    98  
    99  	if r.ch == nil {
   100  		return r.mux.ackRequest(ok, payload)
   101  	}
   102  
   103  	return r.ch.ackRequest(ok)
   104  }
   105  
   106  // RejectionReason is an enumeration used when rejecting channel creation
   107  // requests. See RFC 4254, section 5.1.
   108  type RejectionReason uint32
   109  
   110  const (
   111  	Prohibited RejectionReason = iota + 1
   112  	ConnectionFailed
   113  	UnknownChannelType
   114  	ResourceShortage
   115  )
   116  
   117  // String converts the rejection reason to human readable form.
   118  func (r RejectionReason) String() string {
   119  	switch r {
   120  	case Prohibited:
   121  		return "administratively prohibited"
   122  	case ConnectionFailed:
   123  		return "connect failed"
   124  	case UnknownChannelType:
   125  		return "unknown channel type"
   126  	case ResourceShortage:
   127  		return "resource shortage"
   128  	}
   129  	return fmt.Sprintf("unknown reason %d", int(r))
   130  }
   131  
   132  func min(a uint32, b int) uint32 {
   133  	if a < uint32(b) {
   134  		return a
   135  	}
   136  	return uint32(b)
   137  }
   138  
   139  type channelDirection uint8
   140  
   141  const (
   142  	channelInbound channelDirection = iota
   143  	channelOutbound
   144  )
   145  
   146  // channel is an implementation of the Channel interface that works
   147  // with the mux class.
   148  type channel struct {
   149  	// R/O after creation
   150  	chanType          string
   151  	extraData         []byte
   152  	localId, remoteId uint32
   153  
   154  	// maxIncomingPayload and maxRemotePayload are the maximum
   155  	// payload sizes of normal and extended data packets for
   156  	// receiving and sending, respectively. The wire packet will
   157  	// be 9 or 13 bytes larger (excluding encryption overhead).
   158  	maxIncomingPayload uint32
   159  	maxRemotePayload   uint32
   160  
   161  	mux *mux
   162  
   163  	// decided is set to true if an accept or reject message has been sent
   164  	// (for outbound channels) or received (for inbound channels).
   165  	decided bool
   166  
   167  	// direction contains either channelOutbound, for channels created
   168  	// locally, or channelInbound, for channels created by the peer.
   169  	direction channelDirection
   170  
   171  	// Pending internal channel messages.
   172  	msg chan interface{}
   173  
   174  	// Since requests have no ID, there can be only one request
   175  	// with WantReply=true outstanding.  This lock is held by a
   176  	// goroutine that has such an outgoing request pending.
   177  	sentRequestMu sync.Mutex
   178  
   179  	incomingRequests chan *Request
   180  
   181  	sentEOF bool
   182  
   183  	// thread-safe data
   184  	remoteWin  window
   185  	pending    *buffer
   186  	extPending *buffer
   187  
   188  	// windowMu protects myWindow, the flow-control window.
   189  	windowMu sync.Mutex
   190  	myWindow uint32
   191  
   192  	// writeMu serializes calls to mux.conn.writePacket() and
   193  	// protects sentClose and packetPool. This mutex must be
   194  	// different from windowMu, as writePacket can block if there
   195  	// is a key exchange pending.
   196  	writeMu   sync.Mutex
   197  	sentClose bool
   198  
   199  	// packetPool has a buffer for each extended channel ID to
   200  	// save allocations during writes.
   201  	packetPool map[uint32][]byte
   202  }
   203  
   204  // writePacket sends a packet. If the packet is a channel close, it updates
   205  // sentClose. This method takes the lock c.writeMu.
   206  func (c *channel) writePacket(packet []byte) error {
   207  	c.writeMu.Lock()
   208  	if c.sentClose {
   209  		c.writeMu.Unlock()
   210  		return io.EOF
   211  	}
   212  	c.sentClose = (packet[0] == msgChannelClose)
   213  	err := c.mux.conn.writePacket(packet)
   214  	c.writeMu.Unlock()
   215  	return err
   216  }
   217  
   218  func (c *channel) sendMessage(msg interface{}) error {
   219  	if debugMux {
   220  		log.Printf("send %d: %#v", c.mux.chanList.offset, msg)
   221  	}
   222  
   223  	p := Marshal(msg)
   224  	binary.BigEndian.PutUint32(p[1:], c.remoteId)
   225  	return c.writePacket(p)
   226  }
   227  
   228  // WriteExtended writes data to a specific extended stream. These streams are
   229  // used, for example, for stderr.
   230  func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) {
   231  	if c.sentEOF {
   232  		return 0, io.EOF
   233  	}
   234  	// 1 byte message type, 4 bytes remoteId, 4 bytes data length
   235  	opCode := byte(msgChannelData)
   236  	headerLength := uint32(9)
   237  	if extendedCode > 0 {
   238  		headerLength += 4
   239  		opCode = msgChannelExtendedData
   240  	}
   241  
   242  	c.writeMu.Lock()
   243  	packet := c.packetPool[extendedCode]
   244  	// We don't remove the buffer from packetPool, so
   245  	// WriteExtended calls from different goroutines will be
   246  	// flagged as errors by the race detector.
   247  	c.writeMu.Unlock()
   248  
   249  	for len(data) > 0 {
   250  		space := min(c.maxRemotePayload, len(data))
   251  		if space, err = c.remoteWin.reserve(space); err != nil {
   252  			return n, err
   253  		}
   254  		if want := headerLength + space; uint32(cap(packet)) < want {
   255  			packet = make([]byte, want)
   256  		} else {
   257  			packet = packet[:want]
   258  		}
   259  
   260  		todo := data[:space]
   261  
   262  		packet[0] = opCode
   263  		binary.BigEndian.PutUint32(packet[1:], c.remoteId)
   264  		if extendedCode > 0 {
   265  			binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode))
   266  		}
   267  		binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo)))
   268  		copy(packet[headerLength:], todo)
   269  		if err = c.writePacket(packet); err != nil {
   270  			return n, err
   271  		}
   272  
   273  		n += len(todo)
   274  		data = data[len(todo):]
   275  	}
   276  
   277  	c.writeMu.Lock()
   278  	c.packetPool[extendedCode] = packet
   279  	c.writeMu.Unlock()
   280  
   281  	return n, err
   282  }
   283  
   284  func (c *channel) handleData(packet []byte) error {
   285  	headerLen := 9
   286  	isExtendedData := packet[0] == msgChannelExtendedData
   287  	if isExtendedData {
   288  		headerLen = 13
   289  	}
   290  	if len(packet) < headerLen {
   291  		// malformed data packet
   292  		return parseError(packet[0])
   293  	}
   294  
   295  	var extended uint32
   296  	if isExtendedData {
   297  		extended = binary.BigEndian.Uint32(packet[5:])
   298  	}
   299  
   300  	length := binary.BigEndian.Uint32(packet[headerLen-4 : headerLen])
   301  	if length == 0 {
   302  		return nil
   303  	}
   304  	if length > c.maxIncomingPayload {
   305  		// TODO(hanwen): should send Disconnect?
   306  		return errors.New("ssh: incoming packet exceeds maximum payload size")
   307  	}
   308  
   309  	data := packet[headerLen:]
   310  	if length != uint32(len(data)) {
   311  		return errors.New("ssh: wrong packet length")
   312  	}
   313  
   314  	c.windowMu.Lock()
   315  	if c.myWindow < length {
   316  		c.windowMu.Unlock()
   317  		// TODO(hanwen): should send Disconnect with reason?
   318  		return errors.New("ssh: remote side wrote too much")
   319  	}
   320  	c.myWindow -= length
   321  	c.windowMu.Unlock()
   322  
   323  	if extended == 1 {
   324  		c.extPending.write(data)
   325  	} else if extended > 0 {
   326  		// discard other extended data.
   327  	} else {
   328  		c.pending.write(data)
   329  	}
   330  	return nil
   331  }
   332  
   333  func (c *channel) adjustWindow(n uint32) error {
   334  	c.windowMu.Lock()
   335  	// Since myWindow is managed on our side, and can never exceed
   336  	// the initial window setting, we don't worry about overflow.
   337  	c.myWindow += uint32(n)
   338  	c.windowMu.Unlock()
   339  	return c.sendMessage(windowAdjustMsg{
   340  		AdditionalBytes: uint32(n),
   341  	})
   342  }
   343  
   344  func (c *channel) ReadExtended(data []byte, extended uint32) (n int, err error) {
   345  	switch extended {
   346  	case 1:
   347  		n, err = c.extPending.Read(data)
   348  	case 0:
   349  		n, err = c.pending.Read(data)
   350  	default:
   351  		return 0, fmt.Errorf("ssh: extended code %d unimplemented", extended)
   352  	}
   353  
   354  	if n > 0 {
   355  		err = c.adjustWindow(uint32(n))
   356  		// sendWindowAdjust can return io.EOF if the remote
   357  		// peer has closed the connection, however we want to
   358  		// defer forwarding io.EOF to the caller of Read until
   359  		// the buffer has been drained.
   360  		if n > 0 && err == io.EOF {
   361  			err = nil
   362  		}
   363  	}
   364  
   365  	return n, err
   366  }
   367  
   368  func (c *channel) close() {
   369  	c.pending.eof()
   370  	c.extPending.eof()
   371  	close(c.msg)
   372  	close(c.incomingRequests)
   373  	c.writeMu.Lock()
   374  	// This is not necesary for a normal channel teardown, but if
   375  	// there was another error, it is.
   376  	c.sentClose = true
   377  	c.writeMu.Unlock()
   378  	// Unblock writers.
   379  	c.remoteWin.close()
   380  }
   381  
   382  // responseMessageReceived is called when a success or failure message is
   383  // received on a channel to check that such a message is reasonable for the
   384  // given channel.
   385  func (c *channel) responseMessageReceived() error {
   386  	if c.direction == channelInbound {
   387  		return errors.New("ssh: channel response message received on inbound channel")
   388  	}
   389  	if c.decided {
   390  		return errors.New("ssh: duplicate response received for channel")
   391  	}
   392  	c.decided = true
   393  	return nil
   394  }
   395  
   396  func (c *channel) handlePacket(packet []byte) error {
   397  	switch packet[0] {
   398  	case msgChannelData, msgChannelExtendedData:
   399  		return c.handleData(packet)
   400  	case msgChannelClose:
   401  		c.sendMessage(channelCloseMsg{PeersId: c.remoteId})
   402  		c.mux.chanList.remove(c.localId)
   403  		c.close()
   404  		return nil
   405  	case msgChannelEOF:
   406  		// RFC 4254 is mute on how EOF affects dataExt messages but
   407  		// it is logical to signal EOF at the same time.
   408  		c.extPending.eof()
   409  		c.pending.eof()
   410  		return nil
   411  	}
   412  
   413  	decoded, err := decode(packet)
   414  	if err != nil {
   415  		return err
   416  	}
   417  
   418  	switch msg := decoded.(type) {
   419  	case *channelOpenFailureMsg:
   420  		if err := c.responseMessageReceived(); err != nil {
   421  			return err
   422  		}
   423  		c.mux.chanList.remove(msg.PeersId)
   424  		c.msg <- msg
   425  	case *channelOpenConfirmMsg:
   426  		if err := c.responseMessageReceived(); err != nil {
   427  			return err
   428  		}
   429  		if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
   430  			return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize)
   431  		}
   432  		c.remoteId = msg.MyId
   433  		c.maxRemotePayload = msg.MaxPacketSize
   434  		c.remoteWin.add(msg.MyWindow)
   435  		c.msg <- msg
   436  	case *windowAdjustMsg:
   437  		if !c.remoteWin.add(msg.AdditionalBytes) {
   438  			return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes)
   439  		}
   440  	case *channelRequestMsg:
   441  		req := Request{
   442  			Type:      msg.Request,
   443  			WantReply: msg.WantReply,
   444  			Payload:   msg.RequestSpecificData,
   445  			ch:        c,
   446  		}
   447  
   448  		c.incomingRequests <- &req
   449  	default:
   450  		c.msg <- msg
   451  	}
   452  	return nil
   453  }
   454  
   455  func (m *mux) newChannel(chanType string, direction channelDirection, extraData []byte) *channel {
   456  	ch := &channel{
   457  		remoteWin:        window{Cond: newCond()},
   458  		myWindow:         channelWindowSize,
   459  		pending:          newBuffer(),
   460  		extPending:       newBuffer(),
   461  		direction:        direction,
   462  		incomingRequests: make(chan *Request, 16),
   463  		msg:              make(chan interface{}, 16),
   464  		chanType:         chanType,
   465  		extraData:        extraData,
   466  		mux:              m,
   467  		packetPool:       make(map[uint32][]byte),
   468  	}
   469  	ch.localId = m.chanList.add(ch)
   470  	return ch
   471  }
   472  
   473  var errUndecided = errors.New("ssh: must Accept or Reject channel")
   474  var errDecidedAlready = errors.New("ssh: can call Accept or Reject only once")
   475  
   476  type extChannel struct {
   477  	code uint32
   478  	ch   *channel
   479  }
   480  
   481  func (e *extChannel) Write(data []byte) (n int, err error) {
   482  	return e.ch.WriteExtended(data, e.code)
   483  }
   484  
   485  func (e *extChannel) Read(data []byte) (n int, err error) {
   486  	return e.ch.ReadExtended(data, e.code)
   487  }
   488  
   489  func (c *channel) Accept() (Channel, <-chan *Request, error) {
   490  	if c.decided {
   491  		return nil, nil, errDecidedAlready
   492  	}
   493  	c.maxIncomingPayload = channelMaxPacket
   494  	confirm := channelOpenConfirmMsg{
   495  		PeersId:       c.remoteId,
   496  		MyId:          c.localId,
   497  		MyWindow:      c.myWindow,
   498  		MaxPacketSize: c.maxIncomingPayload,
   499  	}
   500  	c.decided = true
   501  	if err := c.sendMessage(confirm); err != nil {
   502  		return nil, nil, err
   503  	}
   504  
   505  	return c, c.incomingRequests, nil
   506  }
   507  
   508  func (ch *channel) Reject(reason RejectionReason, message string) error {
   509  	if ch.decided {
   510  		return errDecidedAlready
   511  	}
   512  	reject := channelOpenFailureMsg{
   513  		PeersId:  ch.remoteId,
   514  		Reason:   reason,
   515  		Message:  message,
   516  		Language: "en",
   517  	}
   518  	ch.decided = true
   519  	return ch.sendMessage(reject)
   520  }
   521  
   522  func (ch *channel) Read(data []byte) (int, error) {
   523  	if !ch.decided {
   524  		return 0, errUndecided
   525  	}
   526  	return ch.ReadExtended(data, 0)
   527  }
   528  
   529  func (ch *channel) Write(data []byte) (int, error) {
   530  	if !ch.decided {
   531  		return 0, errUndecided
   532  	}
   533  	return ch.WriteExtended(data, 0)
   534  }
   535  
   536  func (ch *channel) CloseWrite() error {
   537  	if !ch.decided {
   538  		return errUndecided
   539  	}
   540  	ch.sentEOF = true
   541  	return ch.sendMessage(channelEOFMsg{
   542  		PeersId: ch.remoteId})
   543  }
   544  
   545  func (ch *channel) Close() error {
   546  	if !ch.decided {
   547  		return errUndecided
   548  	}
   549  
   550  	return ch.sendMessage(channelCloseMsg{
   551  		PeersId: ch.remoteId})
   552  }
   553  
   554  // Extended returns an io.ReadWriter that sends and receives data on the given,
   555  // SSH extended stream. Such streams are used, for example, for stderr.
   556  func (ch *channel) Extended(code uint32) io.ReadWriter {
   557  	if !ch.decided {
   558  		return nil
   559  	}
   560  	return &extChannel{code, ch}
   561  }
   562  
   563  func (ch *channel) Stderr() io.ReadWriter {
   564  	return ch.Extended(1)
   565  }
   566  
   567  func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
   568  	if !ch.decided {
   569  		return false, errUndecided
   570  	}
   571  
   572  	if wantReply {
   573  		ch.sentRequestMu.Lock()
   574  		defer ch.sentRequestMu.Unlock()
   575  	}
   576  
   577  	msg := channelRequestMsg{
   578  		PeersId:             ch.remoteId,
   579  		Request:             name,
   580  		WantReply:           wantReply,
   581  		RequestSpecificData: payload,
   582  	}
   583  
   584  	if err := ch.sendMessage(msg); err != nil {
   585  		return false, err
   586  	}
   587  
   588  	if wantReply {
   589  		m, ok := (<-ch.msg)
   590  		if !ok {
   591  			return false, io.EOF
   592  		}
   593  		switch m.(type) {
   594  		case *channelRequestFailureMsg:
   595  			return false, nil
   596  		case *channelRequestSuccessMsg:
   597  			return true, nil
   598  		default:
   599  			return false, fmt.Errorf("ssh: unexpected response to channel request: %#v", m)
   600  		}
   601  	}
   602  
   603  	return false, nil
   604  }
   605  
   606  // ackRequest either sends an ack or nack to the channel request.
   607  func (ch *channel) ackRequest(ok bool) error {
   608  	if !ch.decided {
   609  		return errUndecided
   610  	}
   611  
   612  	var msg interface{}
   613  	if !ok {
   614  		msg = channelRequestFailureMsg{
   615  			PeersId: ch.remoteId,
   616  		}
   617  	} else {
   618  		msg = channelRequestSuccessMsg{
   619  			PeersId: ch.remoteId,
   620  		}
   621  	}
   622  	return ch.sendMessage(msg)
   623  }
   624  
   625  func (ch *channel) ChannelType() string {
   626  	return ch.chanType
   627  }
   628  
   629  func (ch *channel) ExtraData() []byte {
   630  	return ch.extraData
   631  }