github.com/qichengzx/mattermost-server@v4.5.1-0.20180604164826-2c75247c97d0+incompatible/plugin/rpcplugin/muxer.go (about)

     1  // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
     2  // See License.txt for license information.
     3  
     4  package rpcplugin
     5  
     6  import (
     7  	"bufio"
     8  	"bytes"
     9  	"encoding/binary"
    10  	"fmt"
    11  	"io"
    12  	"sync"
    13  	"sync/atomic"
    14  )
    15  
    16  // Muxer allows multiple bidirectional streams to be transmitted over a single connection.
    17  //
    18  // Muxer is safe for use by multiple goroutines.
    19  //
    20  // Streams opened on the muxer must be periodically drained in order to reclaim read buffer memory.
    21  // In other words, readers must consume incoming data as it comes in.
    22  type Muxer struct {
    23  	// writeMutex guards conn writes
    24  	writeMutex sync.Mutex
    25  	conn       io.ReadWriteCloser
    26  
    27  	// didCloseConn is a boolean (0 or 1) used from multiple goroutines via atomic operations
    28  	didCloseConn int32
    29  
    30  	// streamsMutex guards streams and nextId
    31  	streamsMutex sync.Mutex
    32  	nextId       int64
    33  	streams      map[int64]*muxerStream
    34  
    35  	stream0Reader *io.PipeReader
    36  	stream0Writer *io.PipeWriter
    37  	result        chan error
    38  }
    39  
    40  // Creates a new Muxer.
    41  //
    42  // conn must be safe for simultaneous reads by one goroutine and writes by another.
    43  //
    44  // For two muxers communicating with each other via a connection, parity must be true for exactly
    45  // one of them.
    46  func NewMuxer(conn io.ReadWriteCloser, parity bool) *Muxer {
    47  	s0r, s0w := io.Pipe()
    48  	muxer := &Muxer{
    49  		conn:          conn,
    50  		streams:       make(map[int64]*muxerStream),
    51  		result:        make(chan error, 1),
    52  		nextId:        1,
    53  		stream0Reader: s0r,
    54  		stream0Writer: s0w,
    55  	}
    56  	if parity {
    57  		muxer.nextId = 2
    58  	}
    59  	go muxer.run()
    60  	return muxer
    61  }
    62  
    63  // Opens a new stream with a unique id.
    64  //
    65  // Writes made to the stream before the other end calls Connect will be discarded.
    66  func (m *Muxer) Serve() (int64, io.ReadWriteCloser) {
    67  	m.streamsMutex.Lock()
    68  	id := m.nextId
    69  	m.nextId += 2
    70  	m.streamsMutex.Unlock()
    71  	return id, m.Connect(id)
    72  }
    73  
    74  // Opens a remotely opened stream.
    75  func (m *Muxer) Connect(id int64) io.ReadWriteCloser {
    76  	m.streamsMutex.Lock()
    77  	defer m.streamsMutex.Unlock()
    78  	mutex := &sync.Mutex{}
    79  	stream := &muxerStream{
    80  		id:       id,
    81  		muxer:    m,
    82  		mutex:    mutex,
    83  		readWake: sync.NewCond(mutex),
    84  	}
    85  	m.streams[id] = stream
    86  	return stream
    87  }
    88  
    89  // Calling Read on the muxer directly performs a read on a dedicated, always-open channel.
    90  func (m *Muxer) Read(p []byte) (int, error) {
    91  	return m.stream0Reader.Read(p)
    92  }
    93  
    94  // Calling Write on the muxer directly performs a write on a dedicated, always-open channel.
    95  func (m *Muxer) Write(p []byte) (int, error) {
    96  	return m.write(p, 0)
    97  }
    98  
    99  // Closes the muxer.
   100  func (m *Muxer) Close() error {
   101  	if atomic.CompareAndSwapInt32(&m.didCloseConn, 0, 1) {
   102  		m.conn.Close()
   103  	}
   104  	m.stream0Reader.Close()
   105  	m.stream0Writer.Close()
   106  	<-m.result
   107  	return nil
   108  }
   109  
   110  func (m *Muxer) IsClosed() bool {
   111  	return atomic.LoadInt32(&m.didCloseConn) > 0
   112  }
   113  
   114  func (m *Muxer) write(p []byte, sid int64) (int, error) {
   115  	m.writeMutex.Lock()
   116  	defer m.writeMutex.Unlock()
   117  	if m.IsClosed() {
   118  		return 0, fmt.Errorf("muxer closed")
   119  	}
   120  	var buf [10]byte
   121  	n := binary.PutVarint(buf[:], sid)
   122  	if _, err := m.conn.Write(buf[:n]); err != nil {
   123  		m.shutdown(err)
   124  		return 0, err
   125  	}
   126  	n = binary.PutVarint(buf[:], int64(len(p)))
   127  	if _, err := m.conn.Write(buf[:n]); err != nil {
   128  		m.shutdown(err)
   129  		return 0, err
   130  	}
   131  	if len(p) > 0 {
   132  		if _, err := m.conn.Write(p); err != nil {
   133  			m.shutdown(err)
   134  			return 0, err
   135  		}
   136  	}
   137  	return len(p), nil
   138  }
   139  
   140  func (m *Muxer) rm(sid int64) {
   141  	m.streamsMutex.Lock()
   142  	defer m.streamsMutex.Unlock()
   143  	delete(m.streams, sid)
   144  }
   145  
   146  func (m *Muxer) run() {
   147  	m.shutdown(m.loop())
   148  }
   149  
   150  func (m *Muxer) loop() error {
   151  	reader := bufio.NewReader(m.conn)
   152  
   153  	for {
   154  		sid, err := binary.ReadVarint(reader)
   155  		if err != nil {
   156  			return err
   157  		}
   158  		len, err := binary.ReadVarint(reader)
   159  		if err != nil {
   160  			return err
   161  		}
   162  
   163  		if sid == 0 {
   164  			if _, err := io.CopyN(m.stream0Writer, reader, len); err != nil {
   165  				return err
   166  			}
   167  			continue
   168  		}
   169  
   170  		m.streamsMutex.Lock()
   171  		stream, ok := m.streams[sid]
   172  		m.streamsMutex.Unlock()
   173  		if !ok {
   174  			if _, err := reader.Discard(int(len)); err != nil {
   175  				return err
   176  			}
   177  			continue
   178  		}
   179  
   180  		stream.mutex.Lock()
   181  		if stream.isClosed {
   182  			stream.mutex.Unlock()
   183  			if _, err := reader.Discard(int(len)); err != nil {
   184  				return err
   185  			}
   186  			continue
   187  		}
   188  		if len == 0 {
   189  			stream.remoteClosed = true
   190  		} else {
   191  			_, err = io.CopyN(&stream.readBuf, reader, len)
   192  		}
   193  		stream.mutex.Unlock()
   194  		if err != nil {
   195  			return err
   196  		}
   197  		stream.readWake.Signal()
   198  	}
   199  }
   200  
   201  func (m *Muxer) shutdown(err error) {
   202  	if atomic.CompareAndSwapInt32(&m.didCloseConn, 0, 1) {
   203  		m.conn.Close()
   204  	}
   205  	go func() {
   206  		m.streamsMutex.Lock()
   207  		for _, stream := range m.streams {
   208  			stream.mutex.Lock()
   209  			stream.readWake.Signal()
   210  			stream.mutex.Unlock()
   211  		}
   212  		m.streams = make(map[int64]*muxerStream)
   213  		m.streamsMutex.Unlock()
   214  	}()
   215  	m.result <- err
   216  }
   217  
   218  type muxerStream struct {
   219  	id           int64
   220  	muxer        *Muxer
   221  	readBuf      bytes.Buffer
   222  	mutex        *sync.Mutex
   223  	readWake     *sync.Cond
   224  	isClosed     bool
   225  	remoteClosed bool
   226  }
   227  
   228  func (s *muxerStream) Read(p []byte) (int, error) {
   229  	s.mutex.Lock()
   230  	defer s.mutex.Unlock()
   231  	for {
   232  		if s.muxer.IsClosed() {
   233  			return 0, fmt.Errorf("muxer closed")
   234  		} else if s.isClosed {
   235  			return 0, io.EOF
   236  		} else if s.readBuf.Len() > 0 {
   237  			return s.readBuf.Read(p)
   238  		} else if s.remoteClosed {
   239  			return 0, io.EOF
   240  		}
   241  		s.readWake.Wait()
   242  	}
   243  }
   244  
   245  func (s *muxerStream) Write(p []byte) (int, error) {
   246  	s.mutex.Lock()
   247  	defer s.mutex.Unlock()
   248  	if s.isClosed {
   249  		return 0, fmt.Errorf("stream closed")
   250  	}
   251  	return s.muxer.write(p, s.id)
   252  }
   253  
   254  func (s *muxerStream) Close() error {
   255  	s.mutex.Lock()
   256  	defer s.mutex.Unlock()
   257  	if !s.isClosed {
   258  		s.muxer.write(nil, s.id)
   259  		s.isClosed = true
   260  		s.muxer.rm(s.id)
   261  	}
   262  	s.readWake.Signal()
   263  	return nil
   264  }