github.com/homburg/packer@v0.6.1-0.20140528012651-1dcaf1716848/packer/rpc/muxconn.go (about)

     1  package rpc
     2  
     3  import (
     4  	"encoding/binary"
     5  	"fmt"
     6  	"io"
     7  	"log"
     8  	"sync"
     9  	"time"
    10  )
    11  
    12  // MuxConn is able to multiplex multiple streams on top of any
    13  // io.ReadWriteCloser. These streams act like TCP connections (Dial, Accept,
    14  // Close, full duplex, etc.).
    15  //
    16  // The underlying io.ReadWriteCloser is expected to guarantee delivery
    17  // and ordering, such as TCP. Congestion control and such aren't implemented
    18  // by the streams, so that is also up to the underlying connection.
    19  //
    20  // MuxConn works using a fairly dumb multiplexing technique of simply
    21  // framing every piece of data sent into a prefix + data format. Streams
    22  // are established using a subset of the TCP protocol. Only a subset is
    23  // necessary since we assume ordering on the underlying RWC.
    24  type MuxConn struct {
    25  	curId         uint32
    26  	rwc           io.ReadWriteCloser
    27  	streamsAccept map[uint32]*Stream
    28  	streamsDial   map[uint32]*Stream
    29  	muAccept      sync.RWMutex
    30  	muDial        sync.RWMutex
    31  	wlock         sync.Mutex
    32  	doneCh        chan struct{}
    33  }
    34  
    35  type muxPacketFrom byte
    36  type muxPacketType byte
    37  
    38  const (
    39  	muxPacketFromAccept muxPacketFrom = iota
    40  	muxPacketFromDial
    41  )
    42  
    43  const (
    44  	muxPacketSyn muxPacketType = iota
    45  	muxPacketSynAck
    46  	muxPacketAck
    47  	muxPacketFin
    48  	muxPacketData
    49  )
    50  
    51  func (f muxPacketFrom) String() string {
    52  	switch f {
    53  	case muxPacketFromAccept:
    54  		return "accept"
    55  	case muxPacketFromDial:
    56  		return "dial"
    57  	default:
    58  		panic("unknown from type")
    59  	}
    60  }
    61  
    62  // Create a new MuxConn around any io.ReadWriteCloser.
    63  func NewMuxConn(rwc io.ReadWriteCloser) *MuxConn {
    64  	m := &MuxConn{
    65  		rwc:           rwc,
    66  		streamsAccept: make(map[uint32]*Stream),
    67  		streamsDial:   make(map[uint32]*Stream),
    68  		doneCh:        make(chan struct{}),
    69  	}
    70  
    71  	go m.cleaner()
    72  	go m.loop()
    73  
    74  	return m
    75  }
    76  
    77  // Close closes the underlying io.ReadWriteCloser. This will also close
    78  // all streams that are open.
    79  func (m *MuxConn) Close() error {
    80  	m.muAccept.Lock()
    81  	m.muDial.Lock()
    82  	defer m.muAccept.Unlock()
    83  	defer m.muDial.Unlock()
    84  
    85  	// Close all the streams
    86  	for _, w := range m.streamsAccept {
    87  		w.Close()
    88  	}
    89  	for _, w := range m.streamsDial {
    90  		w.Close()
    91  	}
    92  	m.streamsAccept = make(map[uint32]*Stream)
    93  	m.streamsDial = make(map[uint32]*Stream)
    94  
    95  	// Close the actual connection. This will also force the loop
    96  	// to end since it'll read EOF or closed connection.
    97  	return m.rwc.Close()
    98  }
    99  
   100  // Accept accepts a multiplexed connection with the given ID. This
   101  // will block until a request is made to connect.
   102  func (m *MuxConn) Accept(id uint32) (io.ReadWriteCloser, error) {
   103  	//log.Printf("[TRACE] %p: Accept on stream ID: %d", m, id)
   104  
   105  	// Get the stream. It is okay if it is already in the list of streams
   106  	// because we may have prematurely received a syn for it.
   107  	m.muAccept.Lock()
   108  	stream, ok := m.streamsAccept[id]
   109  	if !ok {
   110  		stream = newStream(muxPacketFromAccept, id, m)
   111  		m.streamsAccept[id] = stream
   112  	}
   113  	m.muAccept.Unlock()
   114  
   115  	stream.mu.Lock()
   116  	defer stream.mu.Unlock()
   117  
   118  	// If the stream isn't closed, then it is already open somehow
   119  	if stream.state != streamStateSynRecv && stream.state != streamStateClosed {
   120  		panic(fmt.Sprintf(
   121  			"Stream %d already open in bad state: %d", id, stream.state))
   122  	}
   123  
   124  	if stream.state == streamStateClosed {
   125  		// Go into the listening state and wait for a syn
   126  		stream.setState(streamStateListen)
   127  		if err := stream.waitState(streamStateSynRecv); err != nil {
   128  			return nil, err
   129  		}
   130  	}
   131  
   132  	if stream.state == streamStateSynRecv {
   133  		// Send a syn-ack
   134  		if _, err := stream.write(muxPacketSynAck, nil); err != nil {
   135  			return nil, err
   136  		}
   137  	}
   138  
   139  	if err := stream.waitState(streamStateEstablished); err != nil {
   140  		return nil, err
   141  	}
   142  
   143  	return stream, nil
   144  }
   145  
   146  // Dial opens a connection to the remote end using the given stream ID.
   147  // An Accept on the remote end will only work with if the IDs match.
   148  func (m *MuxConn) Dial(id uint32) (io.ReadWriteCloser, error) {
   149  	//log.Printf("[TRACE] %p: Dial on stream ID: %d", m, id)
   150  
   151  	m.muDial.Lock()
   152  
   153  	// If we have any streams with this ID, then it is a failure. The
   154  	// reaper should clear out old streams once in awhile.
   155  	if stream, ok := m.streamsDial[id]; ok {
   156  		m.muDial.Unlock()
   157  		panic(fmt.Sprintf(
   158  			"Stream %d already open for dial. State: %d",
   159  			id, stream.state))
   160  	}
   161  
   162  	// Create the new stream and put it in our list. We can then
   163  	// unlock because dialing will no longer be allowed on that ID.
   164  	stream := newStream(muxPacketFromDial, id, m)
   165  	m.streamsDial[id] = stream
   166  
   167  	// Don't let anyone else mess with this stream
   168  	stream.mu.Lock()
   169  	defer stream.mu.Unlock()
   170  
   171  	m.muDial.Unlock()
   172  
   173  	// Open a connection
   174  	if _, err := stream.write(muxPacketSyn, nil); err != nil {
   175  		return nil, err
   176  	}
   177  
   178  	// It is safe to set the state after the write above because
   179  	// we hold the stream lock.
   180  	stream.setState(streamStateSynSent)
   181  
   182  	if err := stream.waitState(streamStateEstablished); err != nil {
   183  		return nil, err
   184  	}
   185  
   186  	stream.write(muxPacketAck, nil)
   187  	return stream, nil
   188  }
   189  
   190  // NextId returns the next available listen stream ID that isn't currently
   191  // taken.
   192  func (m *MuxConn) NextId() uint32 {
   193  	m.muAccept.Lock()
   194  	defer m.muAccept.Unlock()
   195  
   196  	for {
   197  		// We never use stream ID 0 because 0 is the zero value of a uint32
   198  		// and we want to reserve that for "not in use"
   199  		if m.curId == 0 {
   200  			m.curId = 1
   201  		}
   202  
   203  		result := m.curId
   204  		m.curId += 1
   205  		if _, ok := m.streamsAccept[result]; !ok {
   206  			return result
   207  		}
   208  	}
   209  }
   210  
   211  func (m *MuxConn) cleaner() {
   212  	checks := []struct {
   213  		Map  *map[uint32]*Stream
   214  		Lock *sync.RWMutex
   215  	}{
   216  		{&m.streamsAccept, &m.muAccept},
   217  		{&m.streamsDial, &m.muDial},
   218  	}
   219  
   220  	for {
   221  		done := false
   222  		select {
   223  		case <-time.After(500 * time.Millisecond):
   224  		case <-m.doneCh:
   225  			done = true
   226  		}
   227  
   228  		for _, check := range checks {
   229  			check.Lock.Lock()
   230  			for id, s := range *check.Map {
   231  				s.mu.Lock()
   232  
   233  				if done && s.state != streamStateClosed {
   234  					s.closeWriter()
   235  				}
   236  
   237  				if s.state == streamStateClosed {
   238  					// Only clean up the streams that have been closed
   239  					// for a certain amount of time.
   240  					since := time.Now().UTC().Sub(s.stateUpdated)
   241  					if since > 2*time.Second {
   242  						delete(*check.Map, id)
   243  					}
   244  				}
   245  
   246  				s.mu.Unlock()
   247  			}
   248  			check.Lock.Unlock()
   249  		}
   250  
   251  		if done {
   252  			return
   253  		}
   254  	}
   255  }
   256  
   257  func (m *MuxConn) loop() {
   258  	// Force close every stream that we know about when we exit so
   259  	// that they all read EOF and don't block forever.
   260  	defer func() {
   261  		log.Printf("[INFO] Mux connection loop exiting")
   262  		close(m.doneCh)
   263  	}()
   264  
   265  	var from muxPacketFrom
   266  	var id uint32
   267  	var packetType muxPacketType
   268  	var length int32
   269  	for {
   270  		if err := binary.Read(m.rwc, binary.BigEndian, &from); err != nil {
   271  			log.Printf("[ERR] Error reading stream direction: %s", err)
   272  			return
   273  		}
   274  		if err := binary.Read(m.rwc, binary.BigEndian, &id); err != nil {
   275  			log.Printf("[ERR] Error reading stream ID: %s", err)
   276  			return
   277  		}
   278  		if err := binary.Read(m.rwc, binary.BigEndian, &packetType); err != nil {
   279  			log.Printf("[ERR] Error reading packet type: %s", err)
   280  			return
   281  		}
   282  		if err := binary.Read(m.rwc, binary.BigEndian, &length); err != nil {
   283  			log.Printf("[ERR] Error reading length: %s", err)
   284  			return
   285  		}
   286  
   287  		// TODO(mitchellh): probably would be better to re-use a buffer...
   288  		data := make([]byte, length)
   289  		n := 0
   290  		for n < int(length) {
   291  			if n2, err := m.rwc.Read(data[n:]); err != nil {
   292  				log.Printf("[ERR] Error reading data: %s", err)
   293  				return
   294  			} else {
   295  				n += n2
   296  			}
   297  		}
   298  
   299  		// Get the proper stream. Note that the map we look into is
   300  		// opposite the "from" because if the dial side is talking to
   301  		// us, we need to look into the accept map, and so on.
   302  		//
   303  		// Note: we also switch the "from" value so that logging
   304  		// below is correct.
   305  		var stream *Stream
   306  		switch from {
   307  		case muxPacketFromDial:
   308  			m.muAccept.Lock()
   309  			stream = m.streamsAccept[id]
   310  			m.muAccept.Unlock()
   311  
   312  			from = muxPacketFromAccept
   313  		case muxPacketFromAccept:
   314  			m.muDial.Lock()
   315  			stream = m.streamsDial[id]
   316  			m.muDial.Unlock()
   317  
   318  			from = muxPacketFromDial
   319  		default:
   320  			panic(fmt.Sprintf("Unknown stream direction: %d", from))
   321  		}
   322  
   323  		if stream == nil && packetType != muxPacketSyn {
   324  			log.Printf(
   325  				"[WARN] %p: Non-existent stream %d (%s) received packer %d",
   326  				m, id, from, packetType)
   327  			continue
   328  		}
   329  
   330  		//log.Printf("[TRACE] %p: Stream %d (%s) received packet %d", m, id, from, packetType)
   331  		switch packetType {
   332  		case muxPacketSyn:
   333  			// If the stream is nil, this is the only case where we'll
   334  			// automatically create the stream struct.
   335  			if stream == nil {
   336  				var ok bool
   337  
   338  				m.muAccept.Lock()
   339  				stream, ok = m.streamsAccept[id]
   340  				if !ok {
   341  					stream = newStream(muxPacketFromAccept, id, m)
   342  					m.streamsAccept[id] = stream
   343  				}
   344  				m.muAccept.Unlock()
   345  			}
   346  
   347  			stream.mu.Lock()
   348  			switch stream.state {
   349  			case streamStateClosed:
   350  				fallthrough
   351  			case streamStateListen:
   352  				stream.setState(streamStateSynRecv)
   353  			default:
   354  				log.Printf("[ERR] Syn received for stream in state: %d", stream.state)
   355  			}
   356  			stream.mu.Unlock()
   357  		case muxPacketAck:
   358  			stream.mu.Lock()
   359  			switch stream.state {
   360  			case streamStateSynRecv:
   361  				stream.setState(streamStateEstablished)
   362  			case streamStateFinWait1:
   363  				stream.setState(streamStateFinWait2)
   364  			case streamStateLastAck:
   365  				stream.closeWriter()
   366  				fallthrough
   367  			case streamStateClosing:
   368  				stream.setState(streamStateClosed)
   369  			default:
   370  				log.Printf("[ERR] Ack received for stream in state: %d", stream.state)
   371  			}
   372  			stream.mu.Unlock()
   373  		case muxPacketSynAck:
   374  			stream.mu.Lock()
   375  			switch stream.state {
   376  			case streamStateSynSent:
   377  				stream.setState(streamStateEstablished)
   378  			default:
   379  				log.Printf("[ERR] SynAck received for stream in state: %d", stream.state)
   380  			}
   381  			stream.mu.Unlock()
   382  		case muxPacketFin:
   383  			stream.mu.Lock()
   384  			switch stream.state {
   385  			case streamStateEstablished:
   386  				stream.closeWriter()
   387  				stream.setState(streamStateCloseWait)
   388  				stream.write(muxPacketAck, nil)
   389  			case streamStateFinWait2:
   390  				stream.closeWriter()
   391  				stream.setState(streamStateClosed)
   392  				stream.write(muxPacketAck, nil)
   393  			case streamStateFinWait1:
   394  				stream.closeWriter()
   395  				stream.setState(streamStateClosing)
   396  				stream.write(muxPacketAck, nil)
   397  			default:
   398  				log.Printf("[ERR] Fin received for stream %d in state: %d", id, stream.state)
   399  			}
   400  			stream.mu.Unlock()
   401  
   402  		case muxPacketData:
   403  			stream.mu.Lock()
   404  			switch stream.state {
   405  			case streamStateFinWait1:
   406  				fallthrough
   407  			case streamStateFinWait2:
   408  				fallthrough
   409  			case streamStateEstablished:
   410  				if len(data) > 0 && stream.writeCh != nil {
   411  					//log.Printf("[TRACE] %p: Stream %d (%s) WRITE-START", m, id, from)
   412  					stream.writeCh <- data
   413  					//log.Printf("[TRACE] %p: Stream %d (%s) WRITE-END", m, id, from)
   414  				}
   415  			default:
   416  				log.Printf("[ERR] Data received for stream in state: %d", stream.state)
   417  			}
   418  			stream.mu.Unlock()
   419  		}
   420  	}
   421  }
   422  
   423  func (m *MuxConn) write(from muxPacketFrom, id uint32, dataType muxPacketType, p []byte) (int, error) {
   424  	m.wlock.Lock()
   425  	defer m.wlock.Unlock()
   426  
   427  	if err := binary.Write(m.rwc, binary.BigEndian, from); err != nil {
   428  		return 0, err
   429  	}
   430  	if err := binary.Write(m.rwc, binary.BigEndian, id); err != nil {
   431  		return 0, err
   432  	}
   433  	if err := binary.Write(m.rwc, binary.BigEndian, byte(dataType)); err != nil {
   434  		return 0, err
   435  	}
   436  	if err := binary.Write(m.rwc, binary.BigEndian, int32(len(p))); err != nil {
   437  		return 0, err
   438  	}
   439  
   440  	// Write all the bytes. If we don't write all the bytes, report an error
   441  	var err error = nil
   442  	n := 0
   443  	for n < len(p) {
   444  		var n2 int
   445  		n2, err = m.rwc.Write(p[n:])
   446  		n += n2
   447  		if err != nil {
   448  			log.Printf("[ERR] %p: Stream %d (%s) write error: %s", m, id, from, err)
   449  			break
   450  		}
   451  	}
   452  
   453  	return n, err
   454  }
   455  
   456  // Stream is a single stream of data and implements io.ReadWriteCloser.
   457  // A Stream is full-duplex so you can write data as well as read data.
   458  type Stream struct {
   459  	from         muxPacketFrom
   460  	id           uint32
   461  	mux          *MuxConn
   462  	reader       io.Reader
   463  	state        streamState
   464  	stateChange  map[chan<- streamState]struct{}
   465  	stateUpdated time.Time
   466  	mu           sync.Mutex
   467  	writeCh      chan<- []byte
   468  }
   469  
   470  type streamState byte
   471  
   472  const (
   473  	streamStateClosed streamState = iota
   474  	streamStateListen
   475  	streamStateSynRecv
   476  	streamStateSynSent
   477  	streamStateEstablished
   478  	streamStateFinWait1
   479  	streamStateFinWait2
   480  	streamStateCloseWait
   481  	streamStateClosing
   482  	streamStateLastAck
   483  )
   484  
   485  func newStream(from muxPacketFrom, id uint32, m *MuxConn) *Stream {
   486  	// Create the stream object and channel where data will be sent to
   487  	dataR, dataW := io.Pipe()
   488  	writeCh := make(chan []byte, 4096)
   489  
   490  	// Set the data channel so we can write to it.
   491  	stream := &Stream{
   492  		from:        from,
   493  		id:          id,
   494  		mux:         m,
   495  		reader:      dataR,
   496  		writeCh:     writeCh,
   497  		stateChange: make(map[chan<- streamState]struct{}),
   498  	}
   499  	stream.setState(streamStateClosed)
   500  
   501  	// Start the goroutine that will read from the queue and write
   502  	// data out.
   503  	go func() {
   504  		defer dataW.Close()
   505  
   506  		drain := false
   507  		for {
   508  			data := <-writeCh
   509  			if data == nil {
   510  				// A nil is a tombstone letting us know we're done
   511  				// accepting data.
   512  				return
   513  			}
   514  
   515  			if drain {
   516  				// We're draining, meaning we're just waiting for the
   517  				// write channel to close.
   518  				continue
   519  			}
   520  
   521  			if _, err := dataW.Write(data); err != nil {
   522  				drain = true
   523  			}
   524  		}
   525  	}()
   526  
   527  	return stream
   528  }
   529  
   530  func (s *Stream) Close() error {
   531  	s.mu.Lock()
   532  	defer s.mu.Unlock()
   533  
   534  	if s.state != streamStateEstablished && s.state != streamStateCloseWait {
   535  		return fmt.Errorf("Stream in bad state: %d", s.state)
   536  	}
   537  
   538  	if s.state == streamStateEstablished {
   539  		s.setState(streamStateFinWait1)
   540  	} else {
   541  		s.setState(streamStateLastAck)
   542  	}
   543  
   544  	s.write(muxPacketFin, nil)
   545  	return nil
   546  }
   547  
   548  func (s *Stream) Read(p []byte) (int, error) {
   549  	return s.reader.Read(p)
   550  }
   551  
   552  func (s *Stream) Write(p []byte) (int, error) {
   553  	s.mu.Lock()
   554  	state := s.state
   555  	s.mu.Unlock()
   556  
   557  	if state != streamStateEstablished && state != streamStateCloseWait {
   558  		return 0, fmt.Errorf("Stream %d in bad state to send: %d", s.id, state)
   559  	}
   560  
   561  	return s.write(muxPacketData, p)
   562  }
   563  
   564  func (s *Stream) closeWriter() {
   565  	if s.writeCh != nil {
   566  		s.writeCh <- nil
   567  		s.writeCh = nil
   568  	}
   569  }
   570  
   571  func (s *Stream) setState(state streamState) {
   572  	//log.Printf("[TRACE] %p: Stream %d (%s) went to state %d", s.mux, s.id, s.from, state)
   573  	s.state = state
   574  	s.stateUpdated = time.Now().UTC()
   575  	for ch, _ := range s.stateChange {
   576  		select {
   577  		case ch <- state:
   578  		default:
   579  		}
   580  	}
   581  }
   582  
   583  func (s *Stream) waitState(target streamState) error {
   584  	// Register a state change listener to wait for changes
   585  	stateCh := make(chan streamState, 10)
   586  	s.stateChange[stateCh] = struct{}{}
   587  	s.mu.Unlock()
   588  
   589  	defer func() {
   590  		s.mu.Lock()
   591  		delete(s.stateChange, stateCh)
   592  	}()
   593  
   594  	//log.Printf("[TRACE] %p: Stream %d (%s) waiting for state: %d", s.mux, s.id, s.from, target)
   595  	state := <-stateCh
   596  	if state == target {
   597  		return nil
   598  	} else {
   599  		return fmt.Errorf("Stream %d went to bad state: %d", s.id, state)
   600  	}
   601  }
   602  
   603  func (s *Stream) write(dataType muxPacketType, p []byte) (int, error) {
   604  	return s.mux.write(s.from, s.id, dataType, p)
   605  }