github.com/technosophos/deis@v1.7.1-0.20150915173815-f9005256004b/Godeps/_workspace/src/golang.org/x/crypto/ssh/mux.go (about) 1 // Copyright 2013 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package ssh 6 7 import ( 8 "encoding/binary" 9 "fmt" 10 "io" 11 "log" 12 "sync" 13 "sync/atomic" 14 ) 15 16 // debugMux, if set, causes messages in the connection protocol to be 17 // logged. 18 const debugMux = false 19 20 // chanList is a thread safe channel list. 21 type chanList struct { 22 // protects concurrent access to chans 23 sync.Mutex 24 25 // chans are indexed by the local id of the channel, which the 26 // other side should send in the PeersId field. 27 chans []*channel 28 29 // This is a debugging aid: it offsets all IDs by this 30 // amount. This helps distinguish otherwise identical 31 // server/client muxes 32 offset uint32 33 } 34 35 // Assigns a channel ID to the given channel. 36 func (c *chanList) add(ch *channel) uint32 { 37 c.Lock() 38 defer c.Unlock() 39 for i := range c.chans { 40 if c.chans[i] == nil { 41 c.chans[i] = ch 42 return uint32(i) + c.offset 43 } 44 } 45 c.chans = append(c.chans, ch) 46 return uint32(len(c.chans)-1) + c.offset 47 } 48 49 // getChan returns the channel for the given ID. 50 func (c *chanList) getChan(id uint32) *channel { 51 id -= c.offset 52 53 c.Lock() 54 defer c.Unlock() 55 if id < uint32(len(c.chans)) { 56 return c.chans[id] 57 } 58 return nil 59 } 60 61 func (c *chanList) remove(id uint32) { 62 id -= c.offset 63 c.Lock() 64 if id < uint32(len(c.chans)) { 65 c.chans[id] = nil 66 } 67 c.Unlock() 68 } 69 70 // dropAll forgets all channels it knows, returning them in a slice. 71 func (c *chanList) dropAll() []*channel { 72 c.Lock() 73 defer c.Unlock() 74 var r []*channel 75 76 for _, ch := range c.chans { 77 if ch == nil { 78 continue 79 } 80 r = append(r, ch) 81 } 82 c.chans = nil 83 return r 84 } 85 86 // mux represents the state for the SSH connection protocol, which 87 // multiplexes many channels onto a single packet transport. 88 type mux struct { 89 conn packetConn 90 chanList chanList 91 92 incomingChannels chan NewChannel 93 94 globalSentMu sync.Mutex 95 globalResponses chan interface{} 96 incomingRequests chan *Request 97 98 errCond *sync.Cond 99 err error 100 } 101 102 // When debugging, each new chanList instantiation has a different 103 // offset. 104 var globalOff uint32 105 106 func (m *mux) Wait() error { 107 m.errCond.L.Lock() 108 defer m.errCond.L.Unlock() 109 for m.err == nil { 110 m.errCond.Wait() 111 } 112 return m.err 113 } 114 115 // newMux returns a mux that runs over the given connection. 116 func newMux(p packetConn) *mux { 117 m := &mux{ 118 conn: p, 119 incomingChannels: make(chan NewChannel, 16), 120 globalResponses: make(chan interface{}, 1), 121 incomingRequests: make(chan *Request, 16), 122 errCond: newCond(), 123 } 124 if debugMux { 125 m.chanList.offset = atomic.AddUint32(&globalOff, 1) 126 } 127 128 go m.loop() 129 return m 130 } 131 132 func (m *mux) sendMessage(msg interface{}) error { 133 p := Marshal(msg) 134 return m.conn.writePacket(p) 135 } 136 137 func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) { 138 if wantReply { 139 m.globalSentMu.Lock() 140 defer m.globalSentMu.Unlock() 141 } 142 143 if err := m.sendMessage(globalRequestMsg{ 144 Type: name, 145 WantReply: wantReply, 146 Data: payload, 147 }); err != nil { 148 return false, nil, err 149 } 150 151 if !wantReply { 152 return false, nil, nil 153 } 154 155 msg, ok := <-m.globalResponses 156 if !ok { 157 return false, nil, io.EOF 158 } 159 switch msg := msg.(type) { 160 case *globalRequestFailureMsg: 161 return false, msg.Data, nil 162 case *globalRequestSuccessMsg: 163 return true, msg.Data, nil 164 default: 165 return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg) 166 } 167 } 168 169 // ackRequest must be called after processing a global request that 170 // has WantReply set. 171 func (m *mux) ackRequest(ok bool, data []byte) error { 172 if ok { 173 return m.sendMessage(globalRequestSuccessMsg{Data: data}) 174 } 175 return m.sendMessage(globalRequestFailureMsg{Data: data}) 176 } 177 178 // TODO(hanwen): Disconnect is a transport layer message. We should 179 // probably send and receive Disconnect somewhere in the transport 180 // code. 181 182 // Disconnect sends a disconnect message. 183 func (m *mux) Disconnect(reason uint32, message string) error { 184 return m.sendMessage(disconnectMsg{ 185 Reason: reason, 186 Message: message, 187 }) 188 } 189 190 func (m *mux) Close() error { 191 return m.conn.Close() 192 } 193 194 // loop runs the connection machine. It will process packets until an 195 // error is encountered. To synchronize on loop exit, use mux.Wait. 196 func (m *mux) loop() { 197 var err error 198 for err == nil { 199 err = m.onePacket() 200 } 201 202 for _, ch := range m.chanList.dropAll() { 203 ch.close() 204 } 205 206 close(m.incomingChannels) 207 close(m.incomingRequests) 208 close(m.globalResponses) 209 210 m.conn.Close() 211 212 m.errCond.L.Lock() 213 m.err = err 214 m.errCond.Broadcast() 215 m.errCond.L.Unlock() 216 217 if debugMux { 218 log.Println("loop exit", err) 219 } 220 } 221 222 // onePacket reads and processes one packet. 223 func (m *mux) onePacket() error { 224 packet, err := m.conn.readPacket() 225 if err != nil { 226 return err 227 } 228 229 if debugMux { 230 if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData { 231 log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet)) 232 } else { 233 p, _ := decode(packet) 234 log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet)) 235 } 236 } 237 238 switch packet[0] { 239 case msgNewKeys: 240 // Ignore notification of key change. 241 return nil 242 case msgDisconnect: 243 return m.handleDisconnect(packet) 244 case msgChannelOpen: 245 return m.handleChannelOpen(packet) 246 case msgGlobalRequest, msgRequestSuccess, msgRequestFailure: 247 return m.handleGlobalPacket(packet) 248 } 249 250 // assume a channel packet. 251 if len(packet) < 5 { 252 return parseError(packet[0]) 253 } 254 id := binary.BigEndian.Uint32(packet[1:]) 255 ch := m.chanList.getChan(id) 256 if ch == nil { 257 return fmt.Errorf("ssh: invalid channel %d", id) 258 } 259 260 return ch.handlePacket(packet) 261 } 262 263 func (m *mux) handleDisconnect(packet []byte) error { 264 var d disconnectMsg 265 if err := Unmarshal(packet, &d); err != nil { 266 return err 267 } 268 269 if debugMux { 270 log.Printf("caught disconnect: %v", d) 271 } 272 return &d 273 } 274 275 func (m *mux) handleGlobalPacket(packet []byte) error { 276 msg, err := decode(packet) 277 if err != nil { 278 return err 279 } 280 281 switch msg := msg.(type) { 282 case *globalRequestMsg: 283 m.incomingRequests <- &Request{ 284 Type: msg.Type, 285 WantReply: msg.WantReply, 286 Payload: msg.Data, 287 mux: m, 288 } 289 case *globalRequestSuccessMsg, *globalRequestFailureMsg: 290 m.globalResponses <- msg 291 default: 292 panic(fmt.Sprintf("not a global message %#v", msg)) 293 } 294 295 return nil 296 } 297 298 // handleChannelOpen schedules a channel to be Accept()ed. 299 func (m *mux) handleChannelOpen(packet []byte) error { 300 var msg channelOpenMsg 301 if err := Unmarshal(packet, &msg); err != nil { 302 return err 303 } 304 305 if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { 306 failMsg := channelOpenFailureMsg{ 307 PeersId: msg.PeersId, 308 Reason: ConnectionFailed, 309 Message: "invalid request", 310 Language: "en_US.UTF-8", 311 } 312 return m.sendMessage(failMsg) 313 } 314 315 c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData) 316 c.remoteId = msg.PeersId 317 c.maxRemotePayload = msg.MaxPacketSize 318 c.remoteWin.add(msg.PeersWindow) 319 m.incomingChannels <- c 320 return nil 321 } 322 323 func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) { 324 ch, err := m.openChannel(chanType, extra) 325 if err != nil { 326 return nil, nil, err 327 } 328 329 return ch, ch.incomingRequests, nil 330 } 331 332 func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) { 333 ch := m.newChannel(chanType, channelOutbound, extra) 334 335 ch.maxIncomingPayload = channelMaxPacket 336 337 open := channelOpenMsg{ 338 ChanType: chanType, 339 PeersWindow: ch.myWindow, 340 MaxPacketSize: ch.maxIncomingPayload, 341 TypeSpecificData: extra, 342 PeersId: ch.localId, 343 } 344 if err := m.sendMessage(open); err != nil { 345 return nil, err 346 } 347 348 switch msg := (<-ch.msg).(type) { 349 case *channelOpenConfirmMsg: 350 return ch, nil 351 case *channelOpenFailureMsg: 352 return nil, &OpenChannelError{msg.Reason, msg.Message} 353 default: 354 return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg) 355 } 356 }