github.com/kaixiang/packer@v0.5.2-0.20140114230416-1f5786b0d7f1/packer/rpc/muxconn.go (about) 1 package rpc 2 3 import ( 4 "encoding/binary" 5 "fmt" 6 "io" 7 "log" 8 "sync" 9 "time" 10 ) 11 12 // MuxConn is able to multiplex multiple streams on top of any 13 // io.ReadWriteCloser. These streams act like TCP connections (Dial, Accept, 14 // Close, full duplex, etc.). 15 // 16 // The underlying io.ReadWriteCloser is expected to guarantee delivery 17 // and ordering, such as TCP. Congestion control and such aren't implemented 18 // by the streams, so that is also up to the underlying connection. 19 // 20 // MuxConn works using a fairly dumb multiplexing technique of simply 21 // framing every piece of data sent into a prefix + data format. Streams 22 // are established using a subset of the TCP protocol. Only a subset is 23 // necessary since we assume ordering on the underlying RWC. 24 type MuxConn struct { 25 curId uint32 26 rwc io.ReadWriteCloser 27 streamsAccept map[uint32]*Stream 28 streamsDial map[uint32]*Stream 29 muAccept sync.RWMutex 30 muDial sync.RWMutex 31 wlock sync.Mutex 32 doneCh chan struct{} 33 } 34 35 type muxPacketFrom byte 36 type muxPacketType byte 37 38 const ( 39 muxPacketFromAccept muxPacketFrom = iota 40 muxPacketFromDial 41 ) 42 43 const ( 44 muxPacketSyn muxPacketType = iota 45 muxPacketSynAck 46 muxPacketAck 47 muxPacketFin 48 muxPacketData 49 ) 50 51 func (f muxPacketFrom) String() string { 52 switch f { 53 case muxPacketFromAccept: 54 return "accept" 55 case muxPacketFromDial: 56 return "dial" 57 default: 58 panic("unknown from type") 59 } 60 } 61 62 // Create a new MuxConn around any io.ReadWriteCloser. 63 func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn { 64 m := &MuxConn{ 65 rwc: rwc, 66 streamsAccept: make(map[uint32]*Stream), 67 streamsDial: make(map[uint32]*Stream), 68 doneCh: make(chan struct{}), 69 } 70 71 go m.cleaner() 72 go m.loop() 73 74 return m 75 } 76 77 // Close closes the underlying io.ReadWriteCloser. This will also close 78 // all streams that are open. 79 func (m *MuxConn) Close() error { 80 m.muAccept.Lock() 81 m.muDial.Lock() 82 defer m.muAccept.Unlock() 83 defer m.muDial.Unlock() 84 85 // Close all the streams 86 for _, w := range m.streamsAccept { 87 w.Close() 88 } 89 for _, w := range m.streamsDial { 90 w.Close() 91 } 92 m.streamsAccept = make(map[uint32]*Stream) 93 m.streamsDial = make(map[uint32]*Stream) 94 95 // Close the actual connection. This will also force the loop 96 // to end since it'll read EOF or closed connection. 97 return m.rwc.Close() 98 } 99 100 // Accept accepts a multiplexed connection with the given ID. This 101 // will block until a request is made to connect. 102 func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) { 103 //log.Printf("[TRACE] %p: Accept on stream ID: %d", m, id) 104 105 // Get the stream. It is okay if it is already in the list of streams 106 // because we may have prematurely received a syn for it. 107 m.muAccept.Lock() 108 stream, ok := m.streamsAccept[id] 109 if !ok { 110 stream = newStream(muxPacketFromAccept, id, m) 111 m.streamsAccept[id] = stream 112 } 113 m.muAccept.Unlock() 114 115 stream.mu.Lock() 116 defer stream.mu.Unlock() 117 118 // If the stream isn't closed, then it is already open somehow 119 if stream.state != streamStateSynRecv && stream.state != streamStateClosed { 120 panic(fmt.Sprintf( 121 "Stream %d already open in bad state: %d", id, stream.state)) 122 } 123 124 if stream.state == streamStateClosed { 125 // Go into the listening state and wait for a syn 126 stream.setState(streamStateListen) 127 if err := stream.waitState(streamStateSynRecv); err != nil { 128 return nil, err 129 } 130 } 131 132 if stream.state == streamStateSynRecv { 133 // Send a syn-ack 134 if _, err := stream.write(muxPacketSynAck, nil); err != nil { 135 return nil, err 136 } 137 } 138 139 if err := stream.waitState(streamStateEstablished); err != nil { 140 return nil, err 141 } 142 143 return stream, nil 144 } 145 146 // Dial opens a connection to the remote end using the given stream ID. 147 // An Accept on the remote end will only work with if the IDs match. 148 func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) { 149 //log.Printf("[TRACE] %p: Dial on stream ID: %d", m, id) 150 151 m.muDial.Lock() 152 153 // If we have any streams with this ID, then it is a failure. The 154 // reaper should clear out old streams once in awhile. 155 if stream, ok := m.streamsDial[id]; ok { 156 m.muDial.Unlock() 157 panic(fmt.Sprintf( 158 "Stream %d already open for dial. State: %d", 159 id, stream.state)) 160 } 161 162 // Create the new stream and put it in our list. We can then 163 // unlock because dialing will no longer be allowed on that ID. 164 stream := newStream(muxPacketFromDial, id, m) 165 m.streamsDial[id] = stream 166 167 // Don't let anyone else mess with this stream 168 stream.mu.Lock() 169 defer stream.mu.Unlock() 170 171 m.muDial.Unlock() 172 173 // Open a connection 174 if _, err := stream.write(muxPacketSyn, nil); err != nil { 175 return nil, err 176 } 177 178 // It is safe to set the state after the write above because 179 // we hold the stream lock. 180 stream.setState(streamStateSynSent) 181 182 if err := stream.waitState(streamStateEstablished); err != nil { 183 return nil, err 184 } 185 186 stream.write(muxPacketAck, nil) 187 return stream, nil 188 } 189 190 // NextId returns the next available listen stream ID that isn't currently 191 // taken. 192 func (m *MuxConn) NextId() uint32 { 193 m.muAccept.Lock() 194 defer m.muAccept.Unlock() 195 196 for { 197 // We never use stream ID 0 because 0 is the zero value of a uint32 198 // and we want to reserve that for "not in use" 199 if m.curId == 0 { 200 m.curId = 1 201 } 202 203 result := m.curId 204 m.curId += 1 205 if _, ok := m.streamsAccept[result]; !ok { 206 return result 207 } 208 } 209 } 210 211 func (m *MuxConn) cleaner() { 212 checks := []struct { 213 Map *map[uint32]*Stream 214 Lock *sync.RWMutex 215 }{ 216 {&m.streamsAccept, &m.muAccept}, 217 {&m.streamsDial, &m.muDial}, 218 } 219 220 for { 221 done := false 222 select { 223 case <-time.After(500 * time.Millisecond): 224 case <-m.doneCh: 225 done = true 226 } 227 228 for _, check := range checks { 229 check.Lock.Lock() 230 for id, s := range *check.Map { 231 s.mu.Lock() 232 233 if done && s.state != streamStateClosed { 234 s.closeWriter() 235 } 236 237 if s.state == streamStateClosed { 238 // Only clean up the streams that have been closed 239 // for a certain amount of time. 240 since := time.Now().UTC().Sub(s.stateUpdated) 241 if since > 2*time.Second { 242 delete(*check.Map, id) 243 } 244 } 245 246 s.mu.Unlock() 247 } 248 check.Lock.Unlock() 249 } 250 251 if done { 252 return 253 } 254 } 255 } 256 257 func (m *MuxConn) loop() { 258 // Force close every stream that we know about when we exit so 259 // that they all read EOF and don't block forever. 260 defer func() { 261 log.Printf("[INFO] Mux connection loop exiting") 262 close(m.doneCh) 263 }() 264 265 var from muxPacketFrom 266 var id uint32 267 var packetType muxPacketType 268 var length int32 269 for { 270 if err := binary.Read(m.rwc, binary.BigEndian, &from); err != nil { 271 log.Printf("[ERR] Error reading stream direction: %s", err) 272 return 273 } 274 if err := binary.Read(m.rwc, binary.BigEndian, &id); err != nil { 275 log.Printf("[ERR] Error reading stream ID: %s", err) 276 return 277 } 278 if err := binary.Read(m.rwc, binary.BigEndian, &packetType); err != nil { 279 log.Printf("[ERR] Error reading packet type: %s", err) 280 return 281 } 282 if err := binary.Read(m.rwc, binary.BigEndian, &length); err != nil { 283 log.Printf("[ERR] Error reading length: %s", err) 284 return 285 } 286 287 // TODO(mitchellh): probably would be better to re-use a buffer... 288 data := make([]byte, length) 289 n := 0 290 for n < int(length) { 291 if n2, err := m.rwc.Read(data[n:]); err != nil { 292 log.Printf("[ERR] Error reading data: %s", err) 293 return 294 } else { 295 n += n2 296 } 297 } 298 299 // Get the proper stream. Note that the map we look into is 300 // opposite the "from" because if the dial side is talking to 301 // us, we need to look into the accept map, and so on. 302 // 303 // Note: we also switch the "from" value so that logging 304 // below is correct. 305 var stream *Stream 306 switch from { 307 case muxPacketFromDial: 308 m.muAccept.Lock() 309 stream = m.streamsAccept[id] 310 m.muAccept.Unlock() 311 312 from = muxPacketFromAccept 313 case muxPacketFromAccept: 314 m.muDial.Lock() 315 stream = m.streamsDial[id] 316 m.muDial.Unlock() 317 318 from = muxPacketFromDial 319 default: 320 panic(fmt.Sprintf("Unknown stream direction: %d", from)) 321 } 322 323 if stream == nil && packetType != muxPacketSyn { 324 log.Printf( 325 "[WARN] %p: Non-existent stream %d (%s) received packer %d", 326 m, id, from, packetType) 327 continue 328 } 329 330 //log.Printf("[TRACE] %p: Stream %d (%s) received packet %d", m, id, from, packetType) 331 switch packetType { 332 case muxPacketSyn: 333 // If the stream is nil, this is the only case where we'll 334 // automatically create the stream struct. 335 if stream == nil { 336 var ok bool 337 338 m.muAccept.Lock() 339 stream, ok = m.streamsAccept[id] 340 if !ok { 341 stream = newStream(muxPacketFromAccept, id, m) 342 m.streamsAccept[id] = stream 343 } 344 m.muAccept.Unlock() 345 } 346 347 stream.mu.Lock() 348 switch stream.state { 349 case streamStateClosed: 350 fallthrough 351 case streamStateListen: 352 stream.setState(streamStateSynRecv) 353 default: 354 log.Printf("[ERR] Syn received for stream in state: %d", stream.state) 355 } 356 stream.mu.Unlock() 357 case muxPacketAck: 358 stream.mu.Lock() 359 switch stream.state { 360 case streamStateSynRecv: 361 stream.setState(streamStateEstablished) 362 case streamStateFinWait1: 363 stream.setState(streamStateFinWait2) 364 case streamStateLastAck: 365 stream.closeWriter() 366 fallthrough 367 case streamStateClosing: 368 stream.setState(streamStateClosed) 369 default: 370 log.Printf("[ERR] Ack received for stream in state: %d", stream.state) 371 } 372 stream.mu.Unlock() 373 case muxPacketSynAck: 374 stream.mu.Lock() 375 switch stream.state { 376 case streamStateSynSent: 377 stream.setState(streamStateEstablished) 378 default: 379 log.Printf("[ERR] SynAck received for stream in state: %d", stream.state) 380 } 381 stream.mu.Unlock() 382 case muxPacketFin: 383 stream.mu.Lock() 384 switch stream.state { 385 case streamStateEstablished: 386 stream.closeWriter() 387 stream.setState(streamStateCloseWait) 388 stream.write(muxPacketAck, nil) 389 case streamStateFinWait2: 390 stream.closeWriter() 391 stream.setState(streamStateClosed) 392 stream.write(muxPacketAck, nil) 393 case streamStateFinWait1: 394 stream.closeWriter() 395 stream.setState(streamStateClosing) 396 stream.write(muxPacketAck, nil) 397 default: 398 log.Printf("[ERR] Fin received for stream %d in state: %d", id, stream.state) 399 } 400 stream.mu.Unlock() 401 402 case muxPacketData: 403 stream.mu.Lock() 404 switch stream.state { 405 case streamStateFinWait1: 406 fallthrough 407 case streamStateFinWait2: 408 fallthrough 409 case streamStateEstablished: 410 if len(data) > 0 { 411 select { 412 case stream.writeCh <- data: 413 default: 414 panic(fmt.Sprintf( 415 "Failed to write data, buffer full for stream %d", id)) 416 } 417 } 418 default: 419 log.Printf("[ERR] Data received for stream in state: %d", stream.state) 420 } 421 stream.mu.Unlock() 422 } 423 } 424 } 425 426 func (m *MuxConn) write(from muxPacketFrom, id uint32, dataType muxPacketType, p []byte) (int, error) { 427 m.wlock.Lock() 428 defer m.wlock.Unlock() 429 430 if err := binary.Write(m.rwc, binary.BigEndian, from); err != nil { 431 return 0, err 432 } 433 if err := binary.Write(m.rwc, binary.BigEndian, id); err != nil { 434 return 0, err 435 } 436 if err := binary.Write(m.rwc, binary.BigEndian, byte(dataType)); err != nil { 437 return 0, err 438 } 439 if err := binary.Write(m.rwc, binary.BigEndian, int32(len(p))); err != nil { 440 return 0, err 441 } 442 443 // Write all the bytes. If we don't write all the bytes, report an error 444 var err error = nil 445 n := 0 446 for n < len(p) { 447 var n2 int 448 n2, err = m.rwc.Write(p[n:]) 449 n += n2 450 if err != nil { 451 log.Printf("[ERR] %p: Stream %d (%s) write error: %s", m, id, from, err) 452 break 453 } 454 } 455 456 return n, err 457 } 458 459 // Stream is a single stream of data and implements io.ReadWriteCloser. 460 // A Stream is full-duplex so you can write data as well as read data. 461 type Stream struct { 462 from muxPacketFrom 463 id uint32 464 mux *MuxConn 465 reader io.Reader 466 state streamState 467 stateChange map[chan<- streamState]struct{} 468 stateUpdated time.Time 469 mu sync.Mutex 470 writeCh chan<- []byte 471 } 472 473 type streamState byte 474 475 const ( 476 streamStateClosed streamState = iota 477 streamStateListen 478 streamStateSynRecv 479 streamStateSynSent 480 streamStateEstablished 481 streamStateFinWait1 482 streamStateFinWait2 483 streamStateCloseWait 484 streamStateClosing 485 streamStateLastAck 486 ) 487 488 func newStream(from muxPacketFrom, id uint32, m *MuxConn) *Stream { 489 // Create the stream object and channel where data will be sent to 490 dataR, dataW := io.Pipe() 491 writeCh := make(chan []byte, 4096) 492 493 // Set the data channel so we can write to it. 494 stream := &Stream{ 495 from: from, 496 id: id, 497 mux: m, 498 reader: dataR, 499 writeCh: writeCh, 500 stateChange: make(map[chan<- streamState]struct{}), 501 } 502 stream.setState(streamStateClosed) 503 504 // Start the goroutine that will read from the queue and write 505 // data out. 506 go func() { 507 defer dataW.Close() 508 509 for { 510 data := <-writeCh 511 if data == nil { 512 // A nil is a tombstone letting us know we're done 513 // accepting data. 514 return 515 } 516 517 if _, err := dataW.Write(data); err != nil { 518 return 519 } 520 } 521 }() 522 523 return stream 524 } 525 526 func (s *Stream) Close() error { 527 s.mu.Lock() 528 defer s.mu.Unlock() 529 530 if s.state != streamStateEstablished && s.state != streamStateCloseWait { 531 return fmt.Errorf("Stream in bad state: %d", s.state) 532 } 533 534 if s.state == streamStateEstablished { 535 s.setState(streamStateFinWait1) 536 } else { 537 s.setState(streamStateLastAck) 538 } 539 540 s.write(muxPacketFin, nil) 541 return nil 542 } 543 544 func (s *Stream) Read(p []byte) (int, error) { 545 return s.reader.Read(p) 546 } 547 548 func (s *Stream) Write(p []byte) (int, error) { 549 s.mu.Lock() 550 state := s.state 551 s.mu.Unlock() 552 553 if state != streamStateEstablished && state != streamStateCloseWait { 554 return 0, fmt.Errorf("Stream %d in bad state to send: %d", s.id, state) 555 } 556 557 return s.write(muxPacketData, p) 558 } 559 560 func (s *Stream) closeWriter() { 561 s.writeCh <- nil 562 } 563 564 func (s *Stream) setState(state streamState) { 565 //log.Printf("[TRACE] %p: Stream %d (%s) went to state %d", s.mux, s.id, s.from, state) 566 s.state = state 567 s.stateUpdated = time.Now().UTC() 568 for ch, _ := range s.stateChange { 569 select { 570 case ch <- state: 571 default: 572 } 573 } 574 } 575 576 func (s *Stream) waitState(target streamState) error { 577 // Register a state change listener to wait for changes 578 stateCh := make(chan streamState, 10) 579 s.stateChange[stateCh] = struct{}{} 580 s.mu.Unlock() 581 582 defer func() { 583 s.mu.Lock() 584 delete(s.stateChange, stateCh) 585 }() 586 587 state := <-stateCh 588 if state == target { 589 return nil 590 } else { 591 return fmt.Errorf("Stream %d went to bad state: %d", s.id, state) 592 } 593 } 594 595 func (s *Stream) write(dataType muxPacketType, p []byte) (int, error) { 596 return s.mux.write(s.from, s.id, dataType, p) 597 }