github.com/slspeek/camlistore_namedsearch@v0.0.0-20140519202248-ed6f70f7721a/third_party/code.google.com/p/go.crypto/ssh/channel.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"errors"
     9  	"io"
    10  	"sync"
    11  )
    12  
    13  // A Channel is an ordered, reliable, duplex stream that is multiplexed over an
    14  // SSH connection.
    15  type Channel interface {
    16  	// Accept accepts the channel creation request.
    17  	Accept() error
    18  	// Reject rejects the channel creation request. After calling this, no
    19  	// other methods on the Channel may be called. If they are then the
    20  	// peer is likely to signal a protocol error and drop the connection.
    21  	Reject(reason RejectionReason, message string) error
    22  
    23  	// Read may return a ChannelRequest as an error.
    24  	Read(data []byte) (int, error)
    25  	Write(data []byte) (int, error)
    26  	Close() error
    27  
    28  	// AckRequest either sends an ack or nack to the channel request.
    29  	AckRequest(ok bool) error
    30  
    31  	// ChannelType returns the type of the channel, as supplied by the
    32  	// client.
    33  	ChannelType() string
    34  	// ExtraData returns the arbitary payload for this channel, as supplied
    35  	// by the client. This data is specific to the channel type.
    36  	ExtraData() []byte
    37  }
    38  
    39  // ChannelRequest represents a request sent on a channel, outside of the normal
    40  // stream of bytes. It may result from calling Read on a Channel.
    41  type ChannelRequest struct {
    42  	Request   string
    43  	WantReply bool
    44  	Payload   []byte
    45  }
    46  
    47  func (c ChannelRequest) Error() string {
    48  	return "channel request received"
    49  }
    50  
    51  // RejectionReason is an enumeration used when rejecting channel creation
    52  // requests. See RFC 4254, section 5.1.
    53  type RejectionReason int
    54  
    55  const (
    56  	Prohibited RejectionReason = iota + 1
    57  	ConnectionFailed
    58  	UnknownChannelType
    59  	ResourceShortage
    60  )
    61  
    62  type channel struct {
    63  	// immutable once created
    64  	chanType  string
    65  	extraData []byte
    66  
    67  	theyClosed  bool
    68  	theySentEOF bool
    69  	weClosed    bool
    70  	dead        bool
    71  
    72  	serverConn            *ServerConn
    73  	myId, theirId         uint32
    74  	myWindow, theirWindow uint32
    75  	maxPacketSize         uint32
    76  	err                   error
    77  
    78  	pendingRequests []ChannelRequest
    79  	pendingData     []byte
    80  	head, length    int
    81  
    82  	// This lock is inferior to serverConn.lock
    83  	lock sync.Mutex
    84  	cond *sync.Cond
    85  }
    86  
    87  func (c *channel) Accept() error {
    88  	c.serverConn.lock.Lock()
    89  	defer c.serverConn.lock.Unlock()
    90  
    91  	if c.serverConn.err != nil {
    92  		return c.serverConn.err
    93  	}
    94  
    95  	confirm := channelOpenConfirmMsg{
    96  		PeersId:       c.theirId,
    97  		MyId:          c.myId,
    98  		MyWindow:      c.myWindow,
    99  		MaxPacketSize: c.maxPacketSize,
   100  	}
   101  	return c.serverConn.writePacket(marshal(msgChannelOpenConfirm, confirm))
   102  }
   103  
   104  func (c *channel) Reject(reason RejectionReason, message string) error {
   105  	c.serverConn.lock.Lock()
   106  	defer c.serverConn.lock.Unlock()
   107  
   108  	if c.serverConn.err != nil {
   109  		return c.serverConn.err
   110  	}
   111  
   112  	reject := channelOpenFailureMsg{
   113  		PeersId:  c.theirId,
   114  		Reason:   uint32(reason),
   115  		Message:  message,
   116  		Language: "en",
   117  	}
   118  	return c.serverConn.writePacket(marshal(msgChannelOpenFailure, reject))
   119  }
   120  
   121  func (c *channel) handlePacket(packet interface{}) {
   122  	c.lock.Lock()
   123  	defer c.lock.Unlock()
   124  
   125  	switch packet := packet.(type) {
   126  	case *channelRequestMsg:
   127  		req := ChannelRequest{
   128  			Request:   packet.Request,
   129  			WantReply: packet.WantReply,
   130  			Payload:   packet.RequestSpecificData,
   131  		}
   132  
   133  		c.pendingRequests = append(c.pendingRequests, req)
   134  		c.cond.Signal()
   135  	case *channelCloseMsg:
   136  		c.theyClosed = true
   137  		c.cond.Signal()
   138  	case *channelEOFMsg:
   139  		c.theySentEOF = true
   140  		c.cond.Signal()
   141  	default:
   142  		panic("unknown packet type")
   143  	}
   144  }
   145  
   146  func (c *channel) handleData(data []byte) {
   147  	c.lock.Lock()
   148  	defer c.lock.Unlock()
   149  
   150  	// The other side should never send us more than our window.
   151  	if len(data)+c.length > len(c.pendingData) {
   152  		// TODO(agl): we should tear down the channel with a protocol
   153  		// error.
   154  		return
   155  	}
   156  
   157  	c.myWindow -= uint32(len(data))
   158  	for i := 0; i < 2; i++ {
   159  		tail := c.head + c.length
   160  		if tail > len(c.pendingData) {
   161  			tail -= len(c.pendingData)
   162  		}
   163  		n := copy(c.pendingData[tail:], data)
   164  		data = data[n:]
   165  		c.length += n
   166  	}
   167  
   168  	c.cond.Signal()
   169  }
   170  
   171  func (c *channel) Read(data []byte) (n int, err error) {
   172  	c.lock.Lock()
   173  	defer c.lock.Unlock()
   174  
   175  	if c.err != nil {
   176  		return 0, c.err
   177  	}
   178  
   179  	if c.myWindow <= uint32(len(c.pendingData))/2 {
   180  		packet := marshal(msgChannelWindowAdjust, windowAdjustMsg{
   181  			PeersId:         c.theirId,
   182  			AdditionalBytes: uint32(len(c.pendingData)) - c.myWindow,
   183  		})
   184  		if err := c.serverConn.writePacket(packet); err != nil {
   185  			return 0, err
   186  		}
   187  	}
   188  
   189  	for {
   190  		if c.theySentEOF || c.theyClosed || c.dead {
   191  			return 0, io.EOF
   192  		}
   193  
   194  		if len(c.pendingRequests) > 0 {
   195  			req := c.pendingRequests[0]
   196  			if len(c.pendingRequests) == 1 {
   197  				c.pendingRequests = nil
   198  			} else {
   199  				oldPendingRequests := c.pendingRequests
   200  				c.pendingRequests = make([]ChannelRequest, len(oldPendingRequests)-1)
   201  				copy(c.pendingRequests, oldPendingRequests[1:])
   202  			}
   203  
   204  			return 0, req
   205  		}
   206  
   207  		if c.length > 0 {
   208  			tail := c.head + c.length
   209  			if tail > len(c.pendingData) {
   210  				tail -= len(c.pendingData)
   211  			}
   212  			n = copy(data, c.pendingData[c.head:tail])
   213  			c.head += n
   214  			c.length -= n
   215  			if c.head == len(c.pendingData) {
   216  				c.head = 0
   217  			}
   218  			return
   219  		}
   220  
   221  		c.cond.Wait()
   222  	}
   223  
   224  	panic("unreachable")
   225  }
   226  
   227  func (c *channel) Write(data []byte) (n int, err error) {
   228  	for len(data) > 0 {
   229  		c.lock.Lock()
   230  		if c.dead || c.weClosed {
   231  			return 0, io.EOF
   232  		}
   233  
   234  		if c.theirWindow == 0 {
   235  			c.cond.Wait()
   236  			continue
   237  		}
   238  		c.lock.Unlock()
   239  
   240  		todo := data
   241  		if uint32(len(todo)) > c.theirWindow {
   242  			todo = todo[:c.theirWindow]
   243  		}
   244  
   245  		packet := make([]byte, 1+4+4+len(todo))
   246  		packet[0] = msgChannelData
   247  		packet[1] = byte(c.theirId >> 24)
   248  		packet[2] = byte(c.theirId >> 16)
   249  		packet[3] = byte(c.theirId >> 8)
   250  		packet[4] = byte(c.theirId)
   251  		packet[5] = byte(len(todo) >> 24)
   252  		packet[6] = byte(len(todo) >> 16)
   253  		packet[7] = byte(len(todo) >> 8)
   254  		packet[8] = byte(len(todo))
   255  		copy(packet[9:], todo)
   256  
   257  		c.serverConn.lock.Lock()
   258  		if err = c.serverConn.writePacket(packet); err != nil {
   259  			c.serverConn.lock.Unlock()
   260  			return
   261  		}
   262  		c.serverConn.lock.Unlock()
   263  
   264  		n += len(todo)
   265  		data = data[len(todo):]
   266  	}
   267  
   268  	return
   269  }
   270  
   271  func (c *channel) Close() error {
   272  	c.serverConn.lock.Lock()
   273  	defer c.serverConn.lock.Unlock()
   274  
   275  	if c.serverConn.err != nil {
   276  		return c.serverConn.err
   277  	}
   278  
   279  	if c.weClosed {
   280  		return errors.New("ssh: channel already closed")
   281  	}
   282  	c.weClosed = true
   283  
   284  	closeMsg := channelCloseMsg{
   285  		PeersId: c.theirId,
   286  	}
   287  	return c.serverConn.writePacket(marshal(msgChannelClose, closeMsg))
   288  }
   289  
   290  func (c *channel) AckRequest(ok bool) error {
   291  	c.serverConn.lock.Lock()
   292  	defer c.serverConn.lock.Unlock()
   293  
   294  	if c.serverConn.err != nil {
   295  		return c.serverConn.err
   296  	}
   297  
   298  	if ok {
   299  		ack := channelRequestSuccessMsg{
   300  			PeersId: c.theirId,
   301  		}
   302  		return c.serverConn.writePacket(marshal(msgChannelSuccess, ack))
   303  	} else {
   304  		ack := channelRequestFailureMsg{
   305  			PeersId: c.theirId,
   306  		}
   307  		return c.serverConn.writePacket(marshal(msgChannelFailure, ack))
   308  	}
   309  	panic("unreachable")
   310  }
   311  
   312  func (c *channel) ChannelType() string {
   313  	return c.chanType
   314  }
   315  
   316  func (c *channel) ExtraData() []byte {
   317  	return c.extraData
   318  }