github.com/zignig/go-ipfs@v0.0.0-20141111235910-c9e5fdf55a52/net/mux/mux.go (about)

// Package mux implements a protocol muxer.
     2  package mux
     3  
     4  import (
     5  	"errors"
     6  	"sync"
     7  
     8  	conn "github.com/jbenet/go-ipfs/net/conn"
     9  	msg "github.com/jbenet/go-ipfs/net/message"
    10  	pb "github.com/jbenet/go-ipfs/net/mux/internal/pb"
    11  	u "github.com/jbenet/go-ipfs/util"
    12  	ctxc "github.com/jbenet/go-ipfs/util/ctxcloser"
    13  
    14  	context "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/go.net/context"
    15  	proto "github.com/jbenet/go-ipfs/Godeps/_workspace/src/code.google.com/p/goprotobuf/proto"
    16  )
    17  
    18  var log = u.Logger("muxer")
    19  
    20  // ProtocolIDs used to identify each protocol.
    21  // These should probably be defined elsewhere.
    22  var (
    23  	ProtocolID_Routing    = pb.ProtocolID_Routing
    24  	ProtocolID_Exchange   = pb.ProtocolID_Exchange
    25  	ProtocolID_Diagnostic = pb.ProtocolID_Diagnostic
    26  )
    27  
    28  // Protocol objects produce + consume raw data. They are added to the Muxer
    29  // with a ProtocolID, which is added to outgoing payloads. Muxer properly
    30  // encapsulates and decapsulates when interfacing with its Protocols. The
    31  // Protocols do not encounter their ProtocolID.
    32  type Protocol interface {
    33  	GetPipe() *msg.Pipe
    34  }
    35  
    36  // ProtocolMap maps ProtocolIDs to Protocols.
    37  type ProtocolMap map[pb.ProtocolID]Protocol
    38  
    39  // Muxer is a simple multiplexor that reads + writes to Incoming and Outgoing
    40  // channels. It multiplexes various protocols, wrapping and unwrapping data
    41  // with a ProtocolID.
    42  type Muxer struct {
    43  	// Protocols are the multiplexed services.
    44  	Protocols ProtocolMap
    45  
    46  	bwiLock sync.Mutex
    47  	bwIn    uint64
    48  
    49  	bwoLock sync.Mutex
    50  	bwOut   uint64
    51  
    52  	*msg.Pipe
    53  	ctxc.ContextCloser
    54  }
    55  
    56  // NewMuxer constructs a muxer given a protocol map.
    57  func NewMuxer(ctx context.Context, mp ProtocolMap) *Muxer {
    58  	m := &Muxer{
    59  		Protocols:     mp,
    60  		Pipe:          msg.NewPipe(10),
    61  		ContextCloser: ctxc.NewContextCloser(ctx, nil),
    62  	}
    63  
    64  	m.Children().Add(1)
    65  	go m.handleIncomingMessages()
    66  	for pid, proto := range m.Protocols {
    67  		m.Children().Add(1)
    68  		go m.handleOutgoingMessages(pid, proto)
    69  	}
    70  
    71  	return m
    72  }
    73  
    74  // GetPipe implements the Protocol interface
    75  func (m *Muxer) GetPipe() *msg.Pipe {
    76  	return m.Pipe
    77  }
    78  
    79  // GetBandwidthTotals return the in/out bandwidth measured over this muxer.
    80  func (m *Muxer) GetBandwidthTotals() (in uint64, out uint64) {
    81  	m.bwiLock.Lock()
    82  	in = m.bwIn
    83  	m.bwiLock.Unlock()
    84  
    85  	m.bwoLock.Lock()
    86  	out = m.bwOut
    87  	m.bwoLock.Unlock()
    88  	return
    89  }
    90  
    91  // AddProtocol adds a Protocol with given ProtocolID to the Muxer.
    92  func (m *Muxer) AddProtocol(p Protocol, pid pb.ProtocolID) error {
    93  	if _, found := m.Protocols[pid]; found {
    94  		return errors.New("Another protocol already using this ProtocolID")
    95  	}
    96  
    97  	m.Protocols[pid] = p
    98  	return nil
    99  }
   100  
   101  // handleIncoming consumes the messages on the m.Incoming channel and
   102  // routes them appropriately (to the protocols).
   103  func (m *Muxer) handleIncomingMessages() {
   104  	defer m.Children().Done()
   105  
   106  	for {
   107  		select {
   108  		case <-m.Closing():
   109  			return
   110  
   111  		case msg, more := <-m.Incoming:
   112  			if !more {
   113  				return
   114  			}
   115  			m.Children().Add(1)
   116  			go m.handleIncomingMessage(msg)
   117  		}
   118  	}
   119  }
   120  
   121  // handleIncomingMessage routes message to the appropriate protocol.
   122  func (m *Muxer) handleIncomingMessage(m1 msg.NetMessage) {
   123  	defer m.Children().Done()
   124  
   125  	m.bwiLock.Lock()
   126  	// TODO: compensate for overhead
   127  	m.bwIn += uint64(len(m1.Data()))
   128  	m.bwiLock.Unlock()
   129  
   130  	data, pid, err := unwrapData(m1.Data())
   131  	if err != nil {
   132  		log.Errorf("muxer de-serializing error: %v", err)
   133  		return
   134  	}
   135  	conn.ReleaseBuffer(m1.Data())
   136  
   137  	m2 := msg.New(m1.Peer(), data)
   138  	proto, found := m.Protocols[pid]
   139  	if !found {
   140  		log.Errorf("muxer unknown protocol %v", pid)
   141  		return
   142  	}
   143  
   144  	select {
   145  	case proto.GetPipe().Incoming <- m2:
   146  	case <-m.Closing():
   147  		return
   148  	}
   149  }
   150  
   151  // handleOutgoingMessages consumes the messages on the proto.Outgoing channel,
   152  // wraps them and sends them out.
   153  func (m *Muxer) handleOutgoingMessages(pid pb.ProtocolID, proto Protocol) {
   154  	defer m.Children().Done()
   155  
   156  	for {
   157  		select {
   158  		case msg, more := <-proto.GetPipe().Outgoing:
   159  			if !more {
   160  				return
   161  			}
   162  			m.Children().Add(1)
   163  			go m.handleOutgoingMessage(pid, msg)
   164  
   165  		case <-m.Closing():
   166  			return
   167  		}
   168  	}
   169  }
   170  
   171  // handleOutgoingMessage wraps out a message and sends it out the
   172  func (m *Muxer) handleOutgoingMessage(pid pb.ProtocolID, m1 msg.NetMessage) {
   173  	defer m.Children().Done()
   174  
   175  	data, err := wrapData(m1.Data(), pid)
   176  	if err != nil {
   177  		log.Errorf("muxer serializing error: %v", err)
   178  		return
   179  	}
   180  
   181  	m.bwoLock.Lock()
   182  	// TODO: compensate for overhead
   183  	// TODO(jbenet): switch this to a goroutine to prevent sync waiting.
   184  	m.bwOut += uint64(len(data))
   185  	m.bwoLock.Unlock()
   186  
   187  	m2 := msg.New(m1.Peer(), data)
   188  	select {
   189  	case m.GetPipe().Outgoing <- m2:
   190  	case <-m.Closing():
   191  		return
   192  	}
   193  }
   194  
   195  func wrapData(data []byte, pid pb.ProtocolID) ([]byte, error) {
   196  	// Marshal
   197  	pbm := new(pb.PBProtocolMessage)
   198  	pbm.ProtocolID = &pid
   199  	pbm.Data = data
   200  	b, err := proto.Marshal(pbm)
   201  	if err != nil {
   202  		return nil, err
   203  	}
   204  
   205  	return b, nil
   206  }
   207  
   208  func unwrapData(data []byte) ([]byte, pb.ProtocolID, error) {
   209  	// Unmarshal
   210  	pbm := new(pb.PBProtocolMessage)
   211  	err := proto.Unmarshal(data, pbm)
   212  	if err != nil {
   213  		return nil, 0, err
   214  	}
   215  
   216  	return pbm.GetData(), pbm.GetProtocolID(), nil
   217  }