github.com/gdamore/mangos@v1.4.0/core.go (about)

     1  // Copyright 2017 The Mangos Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //    http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package mangos
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  	"sync"
    21  	"time"
    22  )
    23  
    // Defaults applied by newSocket to every new socket.
const (
	// defaultQLen is the default buffer length of the upper read and
	// write queues (see OptionReadQLen / OptionWriteQLen).
	defaultQLen = 128

	// defaultMaxRxSize is the default maximum receive size
	// (see OptionMaxRecvSize).
	defaultMaxRxSize = 1 << 20
)
    29  
    30  // socket is the meaty part of the core information.
    31  type socket struct {
    32  	proto Protocol
    33  
    34  	sync.Mutex
    35  
    36  	uwq      chan *Message // upper write queue
    37  	uwqLen   int           // upper write queue buffer length
    38  	urq      chan *Message // upper read queue
    39  	urqLen   int           // upper read queue buffer length
    40  	closeq   chan struct{} // closed when user requests close
    41  	recverrq chan struct{} // signaled when an error is pending
    42  
    43  	closing    bool  // true if Socket was closed at API level
    44  	active     bool  // true if either Dial or Listen has been successfully called
    45  	bestEffort bool  // true if OptionBestEffort is set
    46  	recverr    error // error to return on attempts to Recv()
    47  	senderr    error // error to return on attempts to Send()
    48  
    49  	rdeadline  time.Duration
    50  	wdeadline  time.Duration
    51  	reconntime time.Duration // reconnect time after error or disconnect
    52  	reconnmax  time.Duration // max reconnect interval
    53  	linger     time.Duration
    54  	maxRxSize  int // max recv size
    55  
    56  	pipes []*pipe
    57  
    58  	listeners []*listener
    59  
    60  	transports map[string]Transport
    61  
    62  	// These are conditional "type aliases" for our self
    63  	sendhook ProtocolSendHook
    64  	recvhook ProtocolRecvHook
    65  
    66  	// Port hook -- called when a port is added or removed
    67  	porthook PortHook
    68  }
    69  
    70  func (sock *socket) addPipe(tranpipe Pipe, d *dialer, l *listener) *pipe {
    71  	p := newPipe(tranpipe)
    72  	p.d = d
    73  	p.l = l
    74  
    75  	// Either listener or dialer is non-nil -- this could be an assert
    76  	if l == nil && d == nil {
    77  		p.Close()
    78  		return nil
    79  	}
    80  
    81  	sock.Lock()
    82  	if fn := sock.porthook; fn != nil {
    83  		sock.Unlock()
    84  		if !fn(PortActionAdd, p) {
    85  			p.Close()
    86  			return nil
    87  		}
    88  		sock.Lock()
    89  	}
    90  	p.sock = sock
    91  	p.index = len(sock.pipes)
    92  	sock.pipes = append(sock.pipes, p)
    93  	sock.Unlock()
    94  	sock.proto.AddEndpoint(p)
    95  	return p
    96  }
    97  
    98  func (sock *socket) remPipe(p *pipe) {
    99  
   100  	sock.proto.RemoveEndpoint(p)
   101  
   102  	sock.Lock()
   103  	if p.index >= 0 {
   104  		sock.pipes[p.index] = sock.pipes[len(sock.pipes)-1]
   105  		sock.pipes[p.index].index = p.index
   106  		sock.pipes = sock.pipes[:len(sock.pipes)-1]
   107  		p.index = -1
   108  	}
   109  	sock.Unlock()
   110  }
   111  
   112  func newSocket(proto Protocol) *socket {
   113  	sock := new(socket)
   114  	sock.uwqLen = defaultQLen
   115  	sock.urqLen = defaultQLen
   116  	sock.uwq = make(chan *Message, sock.uwqLen)
   117  	sock.urq = make(chan *Message, sock.urqLen)
   118  	sock.closeq = make(chan struct{})
   119  	sock.recverrq = make(chan struct{})
   120  	sock.reconntime = time.Millisecond * 100
   121  	sock.reconnmax = time.Duration(0)
   122  	sock.proto = proto
   123  	sock.transports = make(map[string]Transport)
   124  	sock.linger = time.Second
   125  	sock.maxRxSize = defaultMaxRxSize
   126  
   127  	// Add some conditionals now -- saves checks later
   128  	if i, ok := interface{}(proto).(ProtocolRecvHook); ok {
   129  		sock.recvhook = i.(ProtocolRecvHook)
   130  	}
   131  	if i, ok := interface{}(proto).(ProtocolSendHook); ok {
   132  		sock.sendhook = i.(ProtocolSendHook)
   133  	}
   134  
   135  	proto.Init(sock)
   136  
   137  	return sock
   138  }
   139  
   140  // MakeSocket is intended for use by Protocol implementations.  The intention
   141  // is that they can wrap this to provide a "proto.NewSocket()" implementation.
   142  func MakeSocket(proto Protocol) Socket {
   143  	return newSocket(proto)
   144  }
   145  
   146  // Implementation of ProtocolSocket bits on socket.  This is the middle
   147  // API presented to Protocol implementations.
   148  
   149  func (sock *socket) SendChannel() <-chan *Message {
   150  	sock.Lock()
   151  	defer sock.Unlock()
   152  	return sock.uwq
   153  }
   154  
   155  func (sock *socket) RecvChannel() chan<- *Message {
   156  	sock.Lock()
   157  	defer sock.Unlock()
   158  	return sock.urq
   159  }
   160  
   161  func (sock *socket) CloseChannel() <-chan struct{} {
   162  	return sock.closeq
   163  }
   164  
   165  func (sock *socket) SetSendError(err error) {
   166  	sock.Lock()
   167  	sock.senderr = err
   168  	sock.Unlock()
   169  }
   170  
   171  func (sock *socket) SetRecvError(err error) {
   172  	sock.Lock()
   173  	sock.recverr = err
   174  	select {
   175  	case sock.recverrq <- struct{}{}:
   176  	default:
   177  	}
   178  	sock.Unlock()
   179  }
   180  
   181  //
   182  // Implementation of Socket bits on socket.  This is the upper API
   183  // presented to applications.
   184  //
   185  
   186  func (sock *socket) Close() error {
   187  
   188  	fin := time.Now().Add(sock.linger)
   189  
   190  	DrainChannel(sock.uwq, fin)
   191  
   192  	sock.Lock()
   193  	if sock.closing {
   194  		sock.Unlock()
   195  		return ErrClosed
   196  	}
   197  	sock.closing = true
   198  	close(sock.closeq)
   199  
   200  	for _, l := range sock.listeners {
   201  		l.l.Close()
   202  	}
   203  	pipes := append([]*pipe{}, sock.pipes...)
   204  	sock.Unlock()
   205  
   206  	// A second drain, just to be sure.  (We could have had device or
   207  	// forwarded messages arrive since the last one.)
   208  	DrainChannel(sock.uwq, fin)
   209  
   210  	// And tell the protocol to shutdown and drain its pipes too.
   211  	sock.proto.Shutdown(fin)
   212  
   213  	for _, p := range pipes {
   214  		p.Close()
   215  	}
   216  
   217  	return nil
   218  }
   219  
   220  func (sock *socket) SendMsg(msg *Message) error {
   221  
   222  	sock.Lock()
   223  	e := sock.senderr
   224  	if e != nil {
   225  		sock.Unlock()
   226  		return e
   227  	}
   228  	sock.Unlock()
   229  	if sock.sendhook != nil {
   230  		if ok := sock.sendhook.SendHook(msg); !ok {
   231  			// just drop it silently
   232  			msg.Free()
   233  			return nil
   234  		}
   235  	}
   236  	sock.Lock()
   237  	useBestEffort := sock.bestEffort
   238  	wdeadline := sock.wdeadline
   239  
   240  	if wdeadline != 0 {
   241  		msg.expire = time.Now().Add(wdeadline)
   242  	} else {
   243  		msg.expire = time.Time{}
   244  	}
   245  	sock.Unlock()
   246  
   247  	if !useBestEffort {
   248  		timeout := mkTimer(wdeadline)
   249  		select {
   250  		case <-timeout:
   251  			return ErrSendTimeout
   252  		case <-sock.closeq:
   253  			return ErrClosed
   254  		case sock.uwq <- msg:
   255  			return nil
   256  		}
   257  	} else {
   258  		select {
   259  		case <-sock.closeq:
   260  			return ErrClosed
   261  		case sock.uwq <- msg:
   262  			return nil
   263  		default:
   264  			msg.Free()
   265  			return nil
   266  		}
   267  	}
   268  }
   269  
   270  func (sock *socket) Send(b []byte) error {
   271  	msg := NewMessage(len(b))
   272  	msg.Body = append(msg.Body, b...)
   273  	return sock.SendMsg(msg)
   274  }
   275  
   276  // String just emits a very high level debug.  This avoids
   277  // triggering race conditions from trying to print %v without
   278  // holding locks on structure members.
   279  func (sock *socket) String() string {
   280  	return fmt.Sprintf("SOCKET[%s](%p)", sock.proto.Name(), sock)
   281  }
   282  
   283  func (sock *socket) RecvMsg() (*Message, error) {
   284  	sock.Lock()
   285  	timeout := mkTimer(sock.rdeadline)
   286  	sock.Unlock()
   287  
   288  	for {
   289  		sock.Lock()
   290  		if e := sock.recverr; e != nil {
   291  			sock.Unlock()
   292  			return nil, e
   293  		}
   294  		sock.Unlock()
   295  		select {
   296  		case <-timeout:
   297  			return nil, ErrRecvTimeout
   298  		case msg := <-sock.urq:
   299  			if sock.recvhook != nil {
   300  				if ok := sock.recvhook.RecvHook(msg); ok {
   301  					return msg, nil
   302  				} // else loop
   303  				msg.Free()
   304  			} else {
   305  				return msg, nil
   306  			}
   307  		case <-sock.closeq:
   308  			return nil, ErrClosed
   309  		case <-sock.recverrq:
   310  		}
   311  	}
   312  }
   313  
   314  func (sock *socket) Recv() ([]byte, error) {
   315  	msg, err := sock.RecvMsg()
   316  	if err != nil {
   317  		return nil, err
   318  	}
   319  	b := make([]byte, 0, len(msg.Body))
   320  	b = append(b, msg.Body...)
   321  	msg.Free()
   322  	return b, nil
   323  }
   324  
   325  func (sock *socket) getTransport(addr string) Transport {
   326  	var i int
   327  
   328  	if i = strings.Index(addr, "://"); i < 0 {
   329  		return nil
   330  	}
   331  	scheme := addr[:i]
   332  
   333  	sock.Lock()
   334  	defer sock.Unlock()
   335  
   336  	t, ok := sock.transports[scheme]
   337  	if t != nil && ok {
   338  		return t
   339  	}
   340  	return nil
   341  }
   342  
   343  func (sock *socket) AddTransport(t Transport) {
   344  	sock.Lock()
   345  	sock.transports[t.Scheme()] = t
   346  	sock.Unlock()
   347  }
   348  
   349  func (sock *socket) DialOptions(addr string, opts map[string]interface{}) error {
   350  
   351  	d, err := sock.NewDialer(addr, opts)
   352  	if err != nil {
   353  		return err
   354  	}
   355  	return d.Dial()
   356  }
   357  
   358  func (sock *socket) Dial(addr string) error {
   359  	return sock.DialOptions(addr, nil)
   360  }
   361  
   362  func (sock *socket) NewDialer(addr string, options map[string]interface{}) (Dialer, error) {
   363  	var err error
   364  	d := &dialer{sock: sock, addr: addr, closeq: make(chan struct{})}
   365  	t := sock.getTransport(addr)
   366  	if t == nil {
   367  		return nil, ErrBadTran
   368  	}
   369  	if d.d, err = t.NewDialer(addr, sock); err != nil {
   370  		return nil, err
   371  	}
   372  	for n, v := range options {
   373  		if err = d.d.SetOption(n, v); err != nil {
   374  			return nil, err
   375  		}
   376  	}
   377  	return d, nil
   378  }
   379  
   380  func (sock *socket) ListenOptions(addr string, options map[string]interface{}) error {
   381  	l, err := sock.NewListener(addr, options)
   382  	if err != nil {
   383  		return err
   384  	}
   385  	if err = l.Listen(); err != nil {
   386  		return err
   387  	}
   388  	return nil
   389  }
   390  
   391  func (sock *socket) Listen(addr string) error {
   392  	return sock.ListenOptions(addr, nil)
   393  }
   394  
   395  func (sock *socket) NewListener(addr string, options map[string]interface{}) (Listener, error) {
   396  	// This function sets up a goroutine to accept inbound connections.
   397  	// The accepted connection will be added to a list of accepted
   398  	// connections.  The Listener just needs to listen continuously,
   399  	// as we assume that we want to continue to receive inbound
   400  	// connections without limit.
   401  	t := sock.getTransport(addr)
   402  	if t == nil {
   403  		return nil, ErrBadTran
   404  	}
   405  	var err error
   406  	l := &listener{sock: sock, addr: addr}
   407  	l.l, err = t.NewListener(addr, sock)
   408  	if err != nil {
   409  		return nil, err
   410  	}
   411  	for n, v := range options {
   412  		if err = l.l.SetOption(n, v); err != nil {
   413  			l.l.Close()
   414  			return nil, err
   415  		}
   416  	}
   417  	return l, nil
   418  }
   419  
   420  func (sock *socket) SetOption(name string, value interface{}) error {
   421  	matched := false
   422  	err := sock.proto.SetOption(name, value)
   423  	if err == nil {
   424  		matched = true
   425  	} else if err != ErrBadOption {
   426  		return err
   427  	}
   428  	switch name {
   429  	case OptionRecvDeadline:
   430  		sock.Lock()
   431  		sock.rdeadline = value.(time.Duration)
   432  		sock.Unlock()
   433  		return nil
   434  	case OptionSendDeadline:
   435  		sock.Lock()
   436  		sock.wdeadline = value.(time.Duration)
   437  		sock.Unlock()
   438  		return nil
   439  	case OptionLinger:
   440  		sock.Lock()
   441  		sock.linger = value.(time.Duration)
   442  		sock.Unlock()
   443  		return nil
   444  	case OptionWriteQLen:
   445  		sock.Lock()
   446  		defer sock.Unlock()
   447  		if sock.active {
   448  			return ErrBadOption
   449  		}
   450  		length := value.(int)
   451  		if length < 0 {
   452  			return ErrBadValue
   453  		}
   454  		owq := sock.uwq
   455  		sock.uwqLen = length
   456  		sock.uwq = make(chan *Message, sock.uwqLen)
   457  		close(owq)
   458  		return nil
   459  	case OptionReadQLen:
   460  		sock.Lock()
   461  		defer sock.Unlock()
   462  		if sock.active {
   463  			return ErrBadOption
   464  		}
   465  		length := value.(int)
   466  		if length < 0 {
   467  			return ErrBadValue
   468  		}
   469  		sock.urqLen = length
   470  		sock.urq = make(chan *Message, sock.urqLen)
   471  		return nil
   472  	case OptionMaxRecvSize:
   473  		sock.Lock()
   474  		defer sock.Unlock()
   475  		switch value := value.(type) {
   476  		case int:
   477  			if value < 0 {
   478  				return ErrBadValue
   479  			}
   480  			sock.maxRxSize = value
   481  			return nil
   482  		default:
   483  			return ErrBadValue
   484  		}
   485  	case OptionReconnectTime:
   486  		sock.Lock()
   487  		sock.reconntime = value.(time.Duration)
   488  		sock.Unlock()
   489  		return nil
   490  	case OptionMaxReconnectTime:
   491  		sock.Lock()
   492  		sock.reconnmax = value.(time.Duration)
   493  		sock.Unlock()
   494  		return nil
   495  	case OptionBestEffort:
   496  		sock.Lock()
   497  		sock.bestEffort = value.(bool)
   498  		sock.Unlock()
   499  		return nil
   500  	}
   501  	if matched {
   502  		return nil
   503  	}
   504  	return ErrBadOption
   505  }
   506  
   507  func (sock *socket) GetOption(name string) (interface{}, error) {
   508  	val, err := sock.proto.GetOption(name)
   509  	if err == nil {
   510  		return val, nil
   511  	}
   512  	if err != ErrBadOption {
   513  		return nil, err
   514  	}
   515  
   516  	switch name {
   517  	case OptionRecvDeadline:
   518  		sock.Lock()
   519  		defer sock.Unlock()
   520  		return sock.rdeadline, nil
   521  	case OptionSendDeadline:
   522  		sock.Lock()
   523  		defer sock.Unlock()
   524  		return sock.wdeadline, nil
   525  	case OptionLinger:
   526  		sock.Lock()
   527  		defer sock.Unlock()
   528  		return sock.linger, nil
   529  	case OptionWriteQLen:
   530  		sock.Lock()
   531  		defer sock.Unlock()
   532  		return sock.uwqLen, nil
   533  	case OptionReadQLen:
   534  		sock.Lock()
   535  		defer sock.Unlock()
   536  		return sock.urqLen, nil
   537  	case OptionMaxRecvSize:
   538  		sock.Lock()
   539  		defer sock.Unlock()
   540  		return sock.maxRxSize, nil
   541  	case OptionReconnectTime:
   542  		sock.Lock()
   543  		defer sock.Unlock()
   544  		return sock.reconntime, nil
   545  	case OptionMaxReconnectTime:
   546  		sock.Lock()
   547  		defer sock.Unlock()
   548  		return sock.reconnmax, nil
   549  	}
   550  	return nil, ErrBadOption
   551  }
   552  
   553  func (sock *socket) GetProtocol() Protocol {
   554  	return sock.proto
   555  }
   556  
   557  func (sock *socket) SetPortHook(newhook PortHook) PortHook {
   558  	sock.Lock()
   559  	oldhook := sock.porthook
   560  	sock.porthook = newhook
   561  	sock.Unlock()
   562  	return oldhook
   563  }
   564  
   565  type dialer struct {
   566  	d      PipeDialer
   567  	sock   *socket
   568  	addr   string
   569  	closed bool
   570  	active bool
   571  	closeq chan struct{}
   572  }
   573  
   574  func (d *dialer) Dial() error {
   575  	d.sock.Lock()
   576  	if d.active {
   577  		d.sock.Unlock()
   578  		return ErrAddrInUse
   579  	}
   580  	d.closeq = make(chan struct{})
   581  	d.sock.active = true
   582  	d.active = true
   583  	d.sock.Unlock()
   584  	go d.dialer()
   585  	return nil
   586  }
   587  
   588  func (d *dialer) Close() error {
   589  	d.sock.Lock()
   590  	if d.closed {
   591  		d.sock.Unlock()
   592  		return ErrClosed
   593  	}
   594  	d.closed = true
   595  	close(d.closeq)
   596  	d.sock.Unlock()
   597  	return nil
   598  }
   599  
   600  func (d *dialer) GetOption(n string) (interface{}, error) {
   601  	return d.d.GetOption(n)
   602  }
   603  
   604  func (d *dialer) SetOption(n string, v interface{}) error {
   605  	return d.d.SetOption(n, v)
   606  }
   607  
   608  func (d *dialer) Address() string {
   609  	return d.addr
   610  }
   611  
   612  // dialer is used to dial or redial from a goroutine.
   613  func (d *dialer) dialer() {
   614  	rtime := d.sock.reconntime
   615  	rtmax := d.sock.reconnmax
   616  	for {
   617  		p, err := d.d.Dial()
   618  		if err == nil {
   619  			// reset retry time
   620  			rtime = d.sock.reconntime
   621  			d.sock.Lock()
   622  			if d.closed {
   623  				d.sock.Unlock()
   624  				p.Close()
   625  				return
   626  			}
   627  			d.sock.Unlock()
   628  			if cp := d.sock.addPipe(p, d, nil); cp != nil {
   629  				select {
   630  				case <-d.sock.closeq: // parent socket closed
   631  				case <-cp.closeq: // disconnect event
   632  				case <-d.closeq: // dialer closed
   633  				}
   634  			}
   635  		}
   636  
   637  		// we're redialing here
   638  		select {
   639  		case <-d.closeq: // dialer closed
   640  			if p != nil {
   641  				p.Close()
   642  			}
   643  			return
   644  		case <-d.sock.closeq: // exit if parent socket closed
   645  			if p != nil {
   646  				p.Close()
   647  			}
   648  			return
   649  		case <-time.After(rtime):
   650  			if rtmax > 0 {
   651  				rtime *= 2
   652  				if rtime > rtmax {
   653  					rtime = rtmax
   654  				}
   655  			}
   656  			continue
   657  		}
   658  	}
   659  }
   660  
   661  type listener struct {
   662  	l    PipeListener
   663  	sock *socket
   664  	addr string
   665  }
   666  
   667  func (l *listener) GetOption(n string) (interface{}, error) {
   668  	return l.l.GetOption(n)
   669  }
   670  
   671  func (l *listener) SetOption(n string, v interface{}) error {
   672  	return l.l.SetOption(n, v)
   673  }
   674  
   675  // serve spins in a loop, calling the accepter's Accept routine.
   676  func (l *listener) serve() {
   677  	for {
   678  		select {
   679  		case <-l.sock.closeq:
   680  			return
   681  		default:
   682  		}
   683  
   684  		// If the underlying PipeListener is closed, or not
   685  		// listening, we expect to return back with an error.
   686  		if pipe, err := l.l.Accept(); err == nil {
   687  			l.sock.addPipe(pipe, nil, l)
   688  		} else if err == ErrClosed {
   689  			return
   690  		}
   691  	}
   692  }
   693  
   694  func (l *listener) Listen() error {
   695  	// This function sets up a goroutine to accept inbound connections.
   696  	// The accepted connection will be added to a list of accepted
   697  	// connections.  The Listener just needs to listen continuously,
   698  	// as we assume that we want to continue to receive inbound
   699  	// connections without limit.
   700  
   701  	if err := l.l.Listen(); err != nil {
   702  		return err
   703  	}
   704  	l.sock.Lock()
   705  	l.sock.listeners = append(l.sock.listeners, l)
   706  	l.sock.active = true
   707  	l.sock.Unlock()
   708  	go l.serve()
   709  	return nil
   710  }
   711  
   712  func (l *listener) Address() string {
   713  	return l.l.Address()
   714  }
   715  
   716  func (l *listener) Close() error {
   717  	return l.l.Close()
   718  }