github.com/glycerine/xcryptossh@v7.0.4+incompatible/mux.go (about) 1 // Copyright 2013 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package ssh 6 7 import ( 8 "context" 9 "encoding/binary" 10 "fmt" 11 "io" 12 "log" 13 "net" 14 "sync" 15 "sync/atomic" 16 ) 17 18 // debugMux, if set, causes messages in the connection protocol to be 19 // logged. 20 const debugMux = false 21 22 // chanList is a thread safe channel list. 23 type chanList struct { 24 // protects concurrent access to chans 25 sync.Mutex 26 27 // chans are indexed by the local id of the channel, which the 28 // other side should send in the PeersId field. 29 chans []*channel 30 31 // This is a debugging aid: it offsets all IDs by this 32 // amount. This helps distinguish otherwise identical 33 // server/client muxes 34 offset uint32 35 } 36 37 // Assigns a channel ID to the given channel. 38 func (c *chanList) add(ch *channel) uint32 { 39 c.Lock() 40 defer c.Unlock() 41 for i := range c.chans { 42 if c.chans[i] == nil { 43 c.chans[i] = ch 44 return uint32(i) + c.offset 45 } 46 } 47 c.chans = append(c.chans, ch) 48 return uint32(len(c.chans)-1) + c.offset 49 } 50 51 // getChan returns the channel for the given ID. 52 func (c *chanList) getChan(id uint32) *channel { 53 id -= c.offset 54 55 c.Lock() 56 defer c.Unlock() 57 if id < uint32(len(c.chans)) { 58 return c.chans[id] 59 } 60 return nil 61 } 62 63 func (c *chanList) remove(id uint32) { 64 id -= c.offset 65 c.Lock() 66 if id < uint32(len(c.chans)) { 67 c.chans[id] = nil 68 } 69 c.Unlock() 70 } 71 72 // dropAll forgets all channels it knows, returning them in a slice. 73 func (c *chanList) dropAll() []*channel { 74 c.Lock() 75 defer c.Unlock() 76 var r []*channel 77 78 for _, ch := range c.chans { 79 if ch == nil { 80 continue 81 } 82 r = append(r, ch) 83 } 84 c.chans = nil 85 return r 86 } 87 88 // mux represents the state for the SSH connection protocol, which 89 // multiplexes many channels onto a single packet transport. 90 type mux struct { 91 conn packetConn 92 chanList chanList 93 94 incomingChannels chan NewChannel 95 96 globalSentMu sync.Mutex 97 globalResponses chan interface{} 98 incomingRequests chan *Request 99 100 errCond *sync.Cond 101 err error 102 103 halt *Halter 104 } 105 106 // When debugging, each new chanList instantiation has a different 107 // offset. 108 var globalOff uint32 109 110 func (m *mux) Wait() error { 111 m.errCond.L.Lock() 112 defer m.errCond.L.Unlock() 113 for m.err == nil { 114 m.errCond.Wait() 115 } 116 return m.err 117 } 118 119 // newMux returns a mux that runs over the given connection. 120 func newMux(ctx context.Context, p packetConn, halt *Halter) *mux { 121 // idle is nil on server 122 m := &mux{ 123 conn: p, 124 incomingChannels: make(chan NewChannel, chanSize), 125 globalResponses: make(chan interface{}, 1), 126 incomingRequests: make(chan *Request, chanSize), 127 errCond: newCond(), 128 halt: halt, 129 } 130 131 if debugMux { 132 m.chanList.offset = atomic.AddUint32(&globalOff, 1) 133 } 134 135 go m.loop(ctx) 136 return m 137 } 138 139 func (m *mux) sendMessage(msg interface{}) error { 140 p := Marshal(msg) 141 if debugMux { 142 log.Printf("send global(%d): %#v", m.chanList.offset, msg) 143 } 144 return m.conn.writePacket(p) 145 } 146 147 // SendRequest sends a global request, and returns the 148 // reply. This is the ssh.Conn implimentation, described 149 // in connection.go. If wantReply is true, it returns the 150 // response status and payload. See also RFC4254, section 4. 151 func (m *mux) SendRequest(ctx context.Context, name string, wantReply bool, payload []byte) (bool, []byte, error) { 152 if wantReply { 153 m.globalSentMu.Lock() 154 defer m.globalSentMu.Unlock() 155 } 156 157 if err := m.sendMessage(globalRequestMsg{ 158 Type: name, 159 WantReply: wantReply, 160 Data: payload, 161 }); err != nil { 162 return false, nil, err 163 } 164 165 if !wantReply { 166 return false, nil, nil 167 } 168 169 select { 170 case msg, ok := <-m.globalResponses: 171 if !ok { 172 return false, nil, io.EOF 173 } 174 switch msg := msg.(type) { 175 case *globalRequestFailureMsg: 176 return false, msg.Data, nil 177 case *globalRequestSuccessMsg: 178 return true, msg.Data, nil 179 default: 180 return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg) 181 } 182 183 case <-m.halt.ReqStopChan(): 184 return false, nil, io.EOF 185 case <-ctx.Done(): 186 return false, nil, io.EOF 187 } 188 } 189 190 // ackRequest must be called after processing a global request that 191 // has WantReply set. 192 func (m *mux) ackRequest(ok bool, data []byte) error { 193 if ok { 194 return m.sendMessage(globalRequestSuccessMsg{Data: data}) 195 } 196 return m.sendMessage(globalRequestFailureMsg{Data: data}) 197 } 198 199 func (m *mux) Close() error { 200 return m.conn.Close() 201 } 202 203 // loop runs the connection machine. It will process packets until an 204 // error is encountered. To synchronize on loop exit, use mux.Wait. 205 func (m *mux) loop(ctx context.Context) { 206 var err error 207 for err == nil { 208 err = m.onePacket(ctx) 209 210 // We can't have timeout errors here cause us to 211 // leave the loop and close down, because we need to be able to 212 // resume from a timeout where we left off. 213 if err != nil { 214 nerr, ok := err.(net.Error) 215 if ok && nerr.Timeout() { 216 err = nil 217 } 218 } 219 } 220 for _, ch := range m.chanList.dropAll() { 221 ch.close() 222 } 223 224 close(m.incomingChannels) 225 close(m.incomingRequests) 226 close(m.globalResponses) 227 228 m.conn.Close() 229 230 m.errCond.L.Lock() 231 m.err = err 232 m.errCond.Broadcast() 233 m.errCond.L.Unlock() 234 235 if debugMux { 236 log.Println("loop exit", err) 237 } 238 } 239 240 // onePacket reads and processes one packet. 241 func (m *mux) onePacket(ctx context.Context) error { 242 packet, err := m.conn.readPacket(ctx) 243 if err != nil { 244 return err 245 } 246 247 if debugMux { 248 if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData { 249 log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet)) 250 } else { 251 p, _ := decode(packet) 252 log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet)) 253 } 254 } 255 256 switch packet[0] { 257 case msgChannelOpen: 258 return m.handleChannelOpen(ctx, packet) 259 case msgGlobalRequest, msgRequestSuccess, msgRequestFailure: 260 return m.handleGlobalPacket(ctx, packet) 261 } 262 263 // assume a channel packet. 264 if len(packet) < 5 { 265 return parseError(packet[0]) 266 } 267 id := binary.BigEndian.Uint32(packet[1:]) 268 ch := m.chanList.getChan(id) 269 if ch == nil { 270 return fmt.Errorf("ssh: invalid channel %d", id) 271 } 272 273 return ch.handlePacket(packet) 274 } 275 276 func (m *mux) handleGlobalPacket(ctx context.Context, packet []byte) error { 277 msg, err := decode(packet) 278 if err != nil { 279 return err 280 } 281 282 switch msg := msg.(type) { 283 case *globalRequestMsg: 284 select { 285 case m.incomingRequests <- &Request{ 286 Type: msg.Type, 287 WantReply: msg.WantReply, 288 Payload: msg.Data, 289 mux: m, 290 }: 291 // just the send 292 case <-m.halt.ReqStopChan(): 293 return io.EOF 294 case <-ctx.Done(): 295 return io.EOF 296 } 297 case *globalRequestSuccessMsg, *globalRequestFailureMsg: 298 select { 299 case m.globalResponses <- msg: 300 case <-m.halt.ReqStopChan(): 301 return io.EOF 302 case <-ctx.Done(): 303 return io.EOF 304 } 305 default: 306 panic(fmt.Sprintf("not a global message %#v", msg)) 307 } 308 309 return nil 310 } 311 312 // handleChannelOpen schedules a channel to be Accept()ed. 313 func (m *mux) handleChannelOpen(ctx context.Context, packet []byte) error { 314 var msg channelOpenMsg 315 if err := Unmarshal(packet, &msg); err != nil { 316 return err 317 } 318 319 if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { 320 failMsg := channelOpenFailureMsg{ 321 PeersId: msg.PeersId, 322 Reason: ConnectionFailed, 323 Message: "invalid request", 324 Language: "en_US.UTF-8", 325 } 326 return m.sendMessage(failMsg) 327 } 328 329 c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData) 330 c.remoteId = msg.PeersId 331 c.maxRemotePayload = msg.MaxPacketSize 332 c.remoteWin.add(msg.PeersWindow) 333 select { 334 case m.incomingChannels <- c: 335 case <-m.halt.ReqStopChan(): 336 return io.EOF 337 case <-ctx.Done(): 338 return io.EOF 339 } 340 return nil 341 } 342 343 func (m *mux) OpenChannel(ctx context.Context, chanType string, extra []byte, parentHalt *Halter) (Channel, <-chan *Request, error) { 344 ch, err := m.openChannel(ctx, chanType, extra, parentHalt) 345 if err != nil { 346 return nil, nil, err 347 } 348 349 return ch, ch.incomingRequests, nil 350 } 351 352 func (m *mux) openChannel(ctx context.Context, chanType string, extra []byte, parentHalt *Halter) (*channel, error) { 353 ch := m.newChannel(chanType, channelOutbound, extra) 354 355 ch.maxIncomingPayload = channelMaxPacket 356 357 open := channelOpenMsg{ 358 ChanType: chanType, 359 PeersWindow: ch.myWindow, 360 MaxPacketSize: ch.maxIncomingPayload, 361 TypeSpecificData: extra, 362 PeersId: ch.localId, 363 } 364 if err := m.sendMessage(open); err != nil { 365 ch.idleR.Halt.RequestStop() 366 ch.idleW.Halt.RequestStop() 367 return nil, err 368 } 369 370 var done chan struct{} 371 if m.halt != nil { 372 done = m.halt.ReqStopChan() 373 } 374 375 select { 376 case msg := <-ch.msg: 377 switch msgt := msg.(type) { 378 case *channelOpenConfirmMsg: 379 if parentHalt != nil { 380 parentHalt.AddDownstream(ch.halt) 381 } 382 return ch, nil 383 case *channelOpenFailureMsg: 384 ch.idleR.Halt.RequestStop() 385 ch.idleW.Halt.RequestStop() 386 return nil, &OpenChannelError{msgt.Reason, msgt.Message} 387 default: 388 ch.idleR.Halt.RequestStop() 389 ch.idleW.Halt.RequestStop() 390 return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msgt) 391 } 392 case <-done: 393 return nil, io.EOF 394 case <-ctx.Done(): 395 return nil, io.EOF 396 } 397 }