github.com/kaixiang/packer@v0.5.2-0.20140114230416-1f5786b0d7f1/packer/rpc/muxconn.go (about)

     1  package rpc
     2  
     3  import (
     4  	"encoding/binary"
     5  	"fmt"
     6  	"io"
     7  	"log"
     8  	"sync"
     9  	"time"
    10  )
    11  
    12  // MuxConn is able to multiplex multiple streams on top of any
    13  // io.ReadWriteCloser. These streams act like TCP connections (Dial, Accept,
    14  // Close, full duplex, etc.).
    15  //
    16  // The underlying io.ReadWriteCloser is expected to guarantee delivery
    17  // and ordering, such as TCP. Congestion control and such aren't implemented
    18  // by the streams, so that is also up to the underlying connection.
    19  //
    20  // MuxConn works using a fairly dumb multiplexing technique of simply
    21  // framing every piece of data sent into a prefix + data format. Streams
    22  // are established using a subset of the TCP protocol. Only a subset is
    23  // necessary since we assume ordering on the underlying RWC.
    24  type MuxConn struct {
    25  	curId         uint32
    26  	rwc           io.ReadWriteCloser
    27  	streamsAccept map[uint32]*Stream
    28  	streamsDial   map[uint32]*Stream
    29  	muAccept      sync.RWMutex
    30  	muDial        sync.RWMutex
    31  	wlock         sync.Mutex
    32  	doneCh        chan struct{}
    33  }
    34  
    35  type muxPacketFrom byte
    36  type muxPacketType byte
    37  
    38  const (
    39  	muxPacketFromAccept muxPacketFrom = iota
    40  	muxPacketFromDial
    41  )
    42  
    43  const (
    44  	muxPacketSyn muxPacketType = iota
    45  	muxPacketSynAck
    46  	muxPacketAck
    47  	muxPacketFin
    48  	muxPacketData
    49  )
    50  
    51  func (f muxPacketFrom) String() string {
    52  	switch f {
    53  	case muxPacketFromAccept:
    54  		return "accept"
    55  	case muxPacketFromDial:
    56  		return "dial"
    57  	default:
    58  		panic("unknown from type")
    59  	}
    60  }
    61  
    62  // Create a new MuxConn around any io.ReadWriteCloser.
    63  func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn {
    64  	m := &MuxConn{
    65  		rwc:           rwc,
    66  		streamsAccept: make(map[uint32]*Stream),
    67  		streamsDial:   make(map[uint32]*Stream),
    68  		doneCh:        make(chan struct{}),
    69  	}
    70  
    71  	go m.cleaner()
    72  	go m.loop()
    73  
    74  	return m
    75  }
    76  
    77  // Close closes the underlying io.ReadWriteCloser. This will also close
    78  // all streams that are open.
    79  func (m *MuxConn) Close() error {
    80  	m.muAccept.Lock()
    81  	m.muDial.Lock()
    82  	defer m.muAccept.Unlock()
    83  	defer m.muDial.Unlock()
    84  
    85  	// Close all the streams
    86  	for _, w := range m.streamsAccept {
    87  		w.Close()
    88  	}
    89  	for _, w := range m.streamsDial {
    90  		w.Close()
    91  	}
    92  	m.streamsAccept = make(map[uint32]*Stream)
    93  	m.streamsDial = make(map[uint32]*Stream)
    94  
    95  	// Close the actual connection. This will also force the loop
    96  	// to end since it'll read EOF or closed connection.
    97  	return m.rwc.Close()
    98  }
    99  
   100  // Accept accepts a multiplexed connection with the given ID. This
   101  // will block until a request is made to connect.
   102  func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) {
   103  	//log.Printf("[TRACE] %p: Accept on stream ID: %d", m, id)
   104  
   105  	// Get the stream. It is okay if it is already in the list of streams
   106  	// because we may have prematurely received a syn for it.
   107  	m.muAccept.Lock()
   108  	stream, ok := m.streamsAccept[id]
   109  	if !ok {
   110  		stream = newStream(muxPacketFromAccept, id, m)
   111  		m.streamsAccept[id] = stream
   112  	}
   113  	m.muAccept.Unlock()
   114  
   115  	stream.mu.Lock()
   116  	defer stream.mu.Unlock()
   117  
   118  	// If the stream isn't closed, then it is already open somehow
   119  	if stream.state != streamStateSynRecv && stream.state != streamStateClosed {
   120  		panic(fmt.Sprintf(
   121  			"Stream %d already open in bad state: %d", id, stream.state))
   122  	}
   123  
   124  	if stream.state == streamStateClosed {
   125  		// Go into the listening state and wait for a syn
   126  		stream.setState(streamStateListen)
   127  		if err := stream.waitState(streamStateSynRecv); err != nil {
   128  			return nil, err
   129  		}
   130  	}
   131  
   132  	if stream.state == streamStateSynRecv {
   133  		// Send a syn-ack
   134  		if _, err := stream.write(muxPacketSynAck, nil); err != nil {
   135  			return nil, err
   136  		}
   137  	}
   138  
   139  	if err := stream.waitState(streamStateEstablished); err != nil {
   140  		return nil, err
   141  	}
   142  
   143  	return stream, nil
   144  }
   145  
   146  // Dial opens a connection to the remote end using the given stream ID.
   147  // An Accept on the remote end will only work with if the IDs match.
   148  func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) {
   149  	//log.Printf("[TRACE] %p: Dial on stream ID: %d", m, id)
   150  
   151  	m.muDial.Lock()
   152  
   153  	// If we have any streams with this ID, then it is a failure. The
   154  	// reaper should clear out old streams once in awhile.
   155  	if stream, ok := m.streamsDial[id]; ok {
   156  		m.muDial.Unlock()
   157  		panic(fmt.Sprintf(
   158  			"Stream %d already open for dial. State: %d",
   159  			id, stream.state))
   160  	}
   161  
   162  	// Create the new stream and put it in our list. We can then
   163  	// unlock because dialing will no longer be allowed on that ID.
   164  	stream := newStream(muxPacketFromDial, id, m)
   165  	m.streamsDial[id] = stream
   166  
   167  	// Don't let anyone else mess with this stream
   168  	stream.mu.Lock()
   169  	defer stream.mu.Unlock()
   170  
   171  	m.muDial.Unlock()
   172  
   173  	// Open a connection
   174  	if _, err := stream.write(muxPacketSyn, nil); err != nil {
   175  		return nil, err
   176  	}
   177  
   178  	// It is safe to set the state after the write above because
   179  	// we hold the stream lock.
   180  	stream.setState(streamStateSynSent)
   181  
   182  	if err := stream.waitState(streamStateEstablished); err != nil {
   183  		return nil, err
   184  	}
   185  
   186  	stream.write(muxPacketAck, nil)
   187  	return stream, nil
   188  }
   189  
   190  // NextId returns the next available listen stream ID that isn't currently
   191  // taken.
   192  func (m *MuxConn) NextId() uint32 {
   193  	m.muAccept.Lock()
   194  	defer m.muAccept.Unlock()
   195  
   196  	for {
   197  		// We never use stream ID 0 because 0 is the zero value of a uint32
   198  		// and we want to reserve that for "not in use"
   199  		if m.curId == 0 {
   200  			m.curId = 1
   201  		}
   202  
   203  		result := m.curId
   204  		m.curId += 1
   205  		if _, ok := m.streamsAccept[result]; !ok {
   206  			return result
   207  		}
   208  	}
   209  }
   210  
   211  func (m *MuxConn) cleaner() {
   212  	checks := []struct {
   213  		Map  *map[uint32]*Stream
   214  		Lock *sync.RWMutex
   215  	}{
   216  		{&m.streamsAccept, &m.muAccept},
   217  		{&m.streamsDial, &m.muDial},
   218  	}
   219  
   220  	for {
   221  		done := false
   222  		select {
   223  		case <-time.After(500 * time.Millisecond):
   224  		case <-m.doneCh:
   225  			done = true
   226  		}
   227  
   228  		for _, check := range checks {
   229  			check.Lock.Lock()
   230  			for id, s := range *check.Map {
   231  				s.mu.Lock()
   232  
   233  				if done && s.state != streamStateClosed {
   234  					s.closeWriter()
   235  				}
   236  
   237  				if s.state == streamStateClosed {
   238  					// Only clean up the streams that have been closed
   239  					// for a certain amount of time.
   240  					since := time.Now().UTC().Sub(s.stateUpdated)
   241  					if since > 2*time.Second {
   242  						delete(*check.Map, id)
   243  					}
   244  				}
   245  
   246  				s.mu.Unlock()
   247  			}
   248  			check.Lock.Unlock()
   249  		}
   250  
   251  		if done {
   252  			return
   253  		}
   254  	}
   255  }
   256  
   257  func (m *MuxConn) loop() {
   258  	// Force close every stream that we know about when we exit so
   259  	// that they all read EOF and don't block forever.
   260  	defer func() {
   261  		log.Printf("[INFO] Mux connection loop exiting")
   262  		close(m.doneCh)
   263  	}()
   264  
   265  	var from muxPacketFrom
   266  	var id uint32
   267  	var packetType muxPacketType
   268  	var length int32
   269  	for {
   270  		if err := binary.Read(m.rwc, binary.BigEndian, &from); err != nil {
   271  			log.Printf("[ERR] Error reading stream direction: %s", err)
   272  			return
   273  		}
   274  		if err := binary.Read(m.rwc, binary.BigEndian, &id); err != nil {
   275  			log.Printf("[ERR] Error reading stream ID: %s", err)
   276  			return
   277  		}
   278  		if err := binary.Read(m.rwc, binary.BigEndian, &packetType); err != nil {
   279  			log.Printf("[ERR] Error reading packet type: %s", err)
   280  			return
   281  		}
   282  		if err := binary.Read(m.rwc, binary.BigEndian, &length); err != nil {
   283  			log.Printf("[ERR] Error reading length: %s", err)
   284  			return
   285  		}
   286  
   287  		// TODO(mitchellh): probably would be better to re-use a buffer...
   288  		data := make([]byte, length)
   289  		n := 0
   290  		for n < int(length) {
   291  			if n2, err := m.rwc.Read(data[n:]); err != nil {
   292  				log.Printf("[ERR] Error reading data: %s", err)
   293  				return
   294  			} else {
   295  				n += n2
   296  			}
   297  		}
   298  
   299  		// Get the proper stream. Note that the map we look into is
   300  		// opposite the "from" because if the dial side is talking to
   301  		// us, we need to look into the accept map, and so on.
   302  		//
   303  		// Note: we also switch the "from" value so that logging
   304  		// below is correct.
   305  		var stream *Stream
   306  		switch from {
   307  		case muxPacketFromDial:
   308  			m.muAccept.Lock()
   309  			stream = m.streamsAccept[id]
   310  			m.muAccept.Unlock()
   311  
   312  			from = muxPacketFromAccept
   313  		case muxPacketFromAccept:
   314  			m.muDial.Lock()
   315  			stream = m.streamsDial[id]
   316  			m.muDial.Unlock()
   317  
   318  			from = muxPacketFromDial
   319  		default:
   320  			panic(fmt.Sprintf("Unknown stream direction: %d", from))
   321  		}
   322  
   323  		if stream == nil && packetType != muxPacketSyn {
   324  			log.Printf(
   325  				"[WARN] %p: Non-existent stream %d (%s) received packer %d",
   326  				m, id, from, packetType)
   327  			continue
   328  		}
   329  
   330  		//log.Printf("[TRACE] %p: Stream %d (%s) received packet %d", m, id, from, packetType)
   331  		switch packetType {
   332  		case muxPacketSyn:
   333  			// If the stream is nil, this is the only case where we'll
   334  			// automatically create the stream struct.
   335  			if stream == nil {
   336  				var ok bool
   337  
   338  				m.muAccept.Lock()
   339  				stream, ok = m.streamsAccept[id]
   340  				if !ok {
   341  					stream = newStream(muxPacketFromAccept, id, m)
   342  					m.streamsAccept[id] = stream
   343  				}
   344  				m.muAccept.Unlock()
   345  			}
   346  
   347  			stream.mu.Lock()
   348  			switch stream.state {
   349  			case streamStateClosed:
   350  				fallthrough
   351  			case streamStateListen:
   352  				stream.setState(streamStateSynRecv)
   353  			default:
   354  				log.Printf("[ERR] Syn received for stream in state: %d", stream.state)
   355  			}
   356  			stream.mu.Unlock()
   357  		case muxPacketAck:
   358  			stream.mu.Lock()
   359  			switch stream.state {
   360  			case streamStateSynRecv:
   361  				stream.setState(streamStateEstablished)
   362  			case streamStateFinWait1:
   363  				stream.setState(streamStateFinWait2)
   364  			case streamStateLastAck:
   365  				stream.closeWriter()
   366  				fallthrough
   367  			case streamStateClosing:
   368  				stream.setState(streamStateClosed)
   369  			default:
   370  				log.Printf("[ERR] Ack received for stream in state: %d", stream.state)
   371  			}
   372  			stream.mu.Unlock()
   373  		case muxPacketSynAck:
   374  			stream.mu.Lock()
   375  			switch stream.state {
   376  			case streamStateSynSent:
   377  				stream.setState(streamStateEstablished)
   378  			default:
   379  				log.Printf("[ERR] SynAck received for stream in state: %d", stream.state)
   380  			}
   381  			stream.mu.Unlock()
   382  		case muxPacketFin:
   383  			stream.mu.Lock()
   384  			switch stream.state {
   385  			case streamStateEstablished:
   386  				stream.closeWriter()
   387  				stream.setState(streamStateCloseWait)
   388  				stream.write(muxPacketAck, nil)
   389  			case streamStateFinWait2:
   390  				stream.closeWriter()
   391  				stream.setState(streamStateClosed)
   392  				stream.write(muxPacketAck, nil)
   393  			case streamStateFinWait1:
   394  				stream.closeWriter()
   395  				stream.setState(streamStateClosing)
   396  				stream.write(muxPacketAck, nil)
   397  			default:
   398  				log.Printf("[ERR] Fin received for stream %d in state: %d", id, stream.state)
   399  			}
   400  			stream.mu.Unlock()
   401  
   402  		case muxPacketData:
   403  			stream.mu.Lock()
   404  			switch stream.state {
   405  			case streamStateFinWait1:
   406  				fallthrough
   407  			case streamStateFinWait2:
   408  				fallthrough
   409  			case streamStateEstablished:
   410  				if len(data) > 0 {
   411  					select {
   412  					case stream.writeCh <- data:
   413  					default:
   414  						panic(fmt.Sprintf(
   415  							"Failed to write data, buffer full for stream %d", id))
   416  					}
   417  				}
   418  			default:
   419  				log.Printf("[ERR] Data received for stream in state: %d", stream.state)
   420  			}
   421  			stream.mu.Unlock()
   422  		}
   423  	}
   424  }
   425  
   426  func (m *MuxConn) write(from muxPacketFrom, id uint32, dataType muxPacketType, p []byte) (int, error) {
   427  	m.wlock.Lock()
   428  	defer m.wlock.Unlock()
   429  
   430  	if err := binary.Write(m.rwc, binary.BigEndian, from); err != nil {
   431  		return 0, err
   432  	}
   433  	if err := binary.Write(m.rwc, binary.BigEndian, id); err != nil {
   434  		return 0, err
   435  	}
   436  	if err := binary.Write(m.rwc, binary.BigEndian, byte(dataType)); err != nil {
   437  		return 0, err
   438  	}
   439  	if err := binary.Write(m.rwc, binary.BigEndian, int32(len(p))); err != nil {
   440  		return 0, err
   441  	}
   442  
   443  	// Write all the bytes. If we don't write all the bytes, report an error
   444  	var err error = nil
   445  	n := 0
   446  	for n < len(p) {
   447  		var n2 int
   448  		n2, err = m.rwc.Write(p[n:])
   449  		n += n2
   450  		if err != nil {
   451  			log.Printf("[ERR] %p: Stream %d (%s) write error: %s", m, id, from, err)
   452  			break
   453  		}
   454  	}
   455  
   456  	return n, err
   457  }
   458  
   459  // Stream is a single stream of data and implements io.ReadWriteCloser.
   460  // A Stream is full-duplex so you can write data as well as read data.
   461  type Stream struct {
   462  	from         muxPacketFrom
   463  	id           uint32
   464  	mux          *MuxConn
   465  	reader       io.Reader
   466  	state        streamState
   467  	stateChange  map[chan<- streamState]struct{}
   468  	stateUpdated time.Time
   469  	mu           sync.Mutex
   470  	writeCh      chan<- []byte
   471  }
   472  
   473  type streamState byte
   474  
   475  const (
   476  	streamStateClosed streamState = iota
   477  	streamStateListen
   478  	streamStateSynRecv
   479  	streamStateSynSent
   480  	streamStateEstablished
   481  	streamStateFinWait1
   482  	streamStateFinWait2
   483  	streamStateCloseWait
   484  	streamStateClosing
   485  	streamStateLastAck
   486  )
   487  
   488  func newStream(from muxPacketFrom, id uint32, m *MuxConn) *Stream {
   489  	// Create the stream object and channel where data will be sent to
   490  	dataR, dataW := io.Pipe()
   491  	writeCh := make(chan []byte, 4096)
   492  
   493  	// Set the data channel so we can write to it.
   494  	stream := &Stream{
   495  		from:        from,
   496  		id:          id,
   497  		mux:         m,
   498  		reader:      dataR,
   499  		writeCh:     writeCh,
   500  		stateChange: make(map[chan<- streamState]struct{}),
   501  	}
   502  	stream.setState(streamStateClosed)
   503  
   504  	// Start the goroutine that will read from the queue and write
   505  	// data out.
   506  	go func() {
   507  		defer dataW.Close()
   508  
   509  		for {
   510  			data := <-writeCh
   511  			if data == nil {
   512  				// A nil is a tombstone letting us know we're done
   513  				// accepting data.
   514  				return
   515  			}
   516  
   517  			if _, err := dataW.Write(data); err != nil {
   518  				return
   519  			}
   520  		}
   521  	}()
   522  
   523  	return stream
   524  }
   525  
   526  func (s *Stream) Close() error {
   527  	s.mu.Lock()
   528  	defer s.mu.Unlock()
   529  
   530  	if s.state != streamStateEstablished && s.state != streamStateCloseWait {
   531  		return fmt.Errorf("Stream in bad state: %d", s.state)
   532  	}
   533  
   534  	if s.state == streamStateEstablished {
   535  		s.setState(streamStateFinWait1)
   536  	} else {
   537  		s.setState(streamStateLastAck)
   538  	}
   539  
   540  	s.write(muxPacketFin, nil)
   541  	return nil
   542  }
   543  
   544  func (s *Stream) Read(p []byte) (int, error) {
   545  	return s.reader.Read(p)
   546  }
   547  
   548  func (s *Stream) Write(p []byte) (int, error) {
   549  	s.mu.Lock()
   550  	state := s.state
   551  	s.mu.Unlock()
   552  
   553  	if state != streamStateEstablished && state != streamStateCloseWait {
   554  		return 0, fmt.Errorf("Stream %d in bad state to send: %d", s.id, state)
   555  	}
   556  
   557  	return s.write(muxPacketData, p)
   558  }
   559  
   560  func (s *Stream) closeWriter() {
   561  	s.writeCh <- nil
   562  }
   563  
   564  func (s *Stream) setState(state streamState) {
   565  	//log.Printf("[TRACE] %p: Stream %d (%s) went to state %d", s.mux, s.id, s.from, state)
   566  	s.state = state
   567  	s.stateUpdated = time.Now().UTC()
   568  	for ch, _ := range s.stateChange {
   569  		select {
   570  		case ch <- state:
   571  		default:
   572  		}
   573  	}
   574  }
   575  
   576  func (s *Stream) waitState(target streamState) error {
   577  	// Register a state change listener to wait for changes
   578  	stateCh := make(chan streamState, 10)
   579  	s.stateChange[stateCh] = struct{}{}
   580  	s.mu.Unlock()
   581  
   582  	defer func() {
   583  		s.mu.Lock()
   584  		delete(s.stateChange, stateCh)
   585  	}()
   586  
   587  	state := <-stateCh
   588  	if state == target {
   589  		return nil
   590  	} else {
   591  		return fmt.Errorf("Stream %d went to bad state: %d", s.id, state)
   592  	}
   593  }
   594  
   595  func (s *Stream) write(dataType muxPacketType, p []byte) (int, error) {
   596  	return s.mux.write(s.from, s.id, dataType, p)
   597  }