nanomsg.org/go/mangos/v2@v2.0.9-0.20200203084354-8a092611e461/protocol/surveyor/surveyor.go (about)

     1  // Copyright 2019 The Mangos Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use file except in compliance with the License.
     5  // You may obtain a copy of the license at
     6  //
     7  //    http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package surveyor implements the SURVEYOR protocol. This sends messages
    16  // out to RESPONDENT partners, and receives their responses.
    17  package surveyor
    18  
    19  import (
    20  	"encoding/binary"
    21  	"sync"
    22  	"sync/atomic"
    23  	"time"
    24  
    25  	"nanomsg.org/go/mangos/v2/protocol"
    26  )
    27  
    28  // Protocol identity information.
    29  const (
    30  	Self     = protocol.ProtoSurveyor
    31  	Peer     = protocol.ProtoRespondent
    32  	SelfName = "surveyor"
    33  	PeerName = "respondent"
    34  )
    35  
    36  const defaultSurveyTime = time.Second
    37  
    38  type pipe struct {
    39  	s      *socket
    40  	p      protocol.Pipe
    41  	closeQ chan struct{}
    42  	sendQ  chan *protocol.Message
    43  }
    44  
    45  type survey struct {
    46  	timer  *time.Timer
    47  	recvQ  chan *protocol.Message
    48  	active bool
    49  	id     uint32
    50  	ctx    *context
    51  	sock   *socket
    52  	err    error
    53  	once   sync.Once
    54  }
    55  
    56  type context struct {
    57  	s          *socket
    58  	closed     bool
    59  	closeQ     chan struct{}
    60  	recvQLen   int
    61  	recvExpire time.Duration
    62  	survExpire time.Duration
    63  	surv       *survey
    64  }
    65  
    66  type socket struct {
    67  	master   *context              // default context
    68  	ctxs     map[*context]struct{} // all contexts
    69  	surveys  map[uint32]*survey    // contexts by survey ID
    70  	pipes    map[uint32]*pipe      // all pipes by pipe ID
    71  	nextID   uint32                // next survey ID
    72  	closed   bool                  // true if closed
    73  	sendQLen int                   // send Q depth
    74  	sync.Mutex
    75  }
    76  
    77  var (
    78  	nilQ <-chan time.Time
    79  )
    80  
    81  const defaultQLen = 128
    82  
    83  func (s *survey) cancel(err error) {
    84  
    85  	s.once.Do(func() {
    86  		sock := s.sock
    87  		ctx := s.ctx
    88  
    89  		s.err = err
    90  		sock.Lock()
    91  		s.timer.Stop()
    92  		if ctx.surv == s {
    93  			ctx.surv = nil
    94  		}
    95  		delete(sock.surveys, s.id)
    96  		sock.Unlock()
    97  
    98  		// Don't close this until after we have removed it from
    99  		// the list of pending surveys, to prevent the receiver
   100  		// from trying to write to a closed channel.
   101  		close(s.recvQ)
   102  		for m := range s.recvQ {
   103  			m.Free()
   104  		}
   105  	})
   106  }
   107  
   108  func (s *survey) start(qLen int, expire time.Duration) {
   109  	// NB: Called with the socket lock held
   110  	s.recvQ = make(chan *protocol.Message, qLen)
   111  	s.sock.surveys[s.id] = s
   112  	s.ctx.surv = s
   113  	s.timer = time.AfterFunc(expire, func() {
   114  		s.cancel(protocol.ErrProtoState)
   115  	})
   116  }
   117  
   118  func (c *context) SendMsg(m *protocol.Message) error {
   119  	s := c.s
   120  
   121  	newsurv := &survey{
   122  		active: true,
   123  		id:     atomic.AddUint32(&s.nextID, 1) | 0x80000000,
   124  		ctx:    c,
   125  		sock:   s,
   126  	}
   127  
   128  	m.MakeUnique()
   129  	m.Header = make([]byte, 4)
   130  	binary.BigEndian.PutUint32(m.Header, newsurv.id)
   131  
   132  	s.Lock()
   133  	if s.closed || c.closed {
   134  		s.Unlock()
   135  		return protocol.ErrClosed
   136  	}
   137  	oldsurv := c.surv
   138  	newsurv.start(c.recvQLen, c.survExpire)
   139  	if oldsurv != nil {
   140  		go oldsurv.cancel(protocol.ErrCanceled)
   141  	}
   142  	pipes := make([]*pipe, 0, len(s.pipes))
   143  	for _, p := range s.pipes {
   144  		pipes = append(pipes, p)
   145  	}
   146  	s.Unlock()
   147  
   148  	// Best-effort broadcast on all pipes
   149  	for _, p := range pipes {
   150  		m.Clone()
   151  		select {
   152  		case p.sendQ <- m:
   153  		default:
   154  			m.Free()
   155  		}
   156  	}
   157  	m.Free()
   158  	return nil
   159  }
   160  
   161  func (c *context) RecvMsg() (*protocol.Message, error) {
   162  	s := c.s
   163  
   164  	s.Lock()
   165  	if s.closed {
   166  		s.Unlock()
   167  		return nil, protocol.ErrClosed
   168  	}
   169  	surv := c.surv
   170  	timeq := nilQ
   171  	if c.recvExpire > 0 {
   172  		timeq = time.After(c.recvExpire)
   173  	}
   174  	s.Unlock()
   175  
   176  	if surv == nil {
   177  		return nil, protocol.ErrProtoState
   178  	}
   179  	select {
   180  	case <-c.closeQ:
   181  		return nil, protocol.ErrClosed
   182  
   183  	case m := <-surv.recvQ:
   184  		if m == nil {
   185  			// Sometimes the recvQ can get closed ahead of
   186  			// the closeQ, but the closeQ takes precedence.
   187  			return nil, surv.err
   188  		}
   189  		return m, nil
   190  
   191  	case <-timeq:
   192  		return nil, protocol.ErrRecvTimeout
   193  	}
   194  }
   195  
   196  func (c *context) close() {
   197  	if !c.closed {
   198  		c.closed = true
   199  		close(c.closeQ)
   200  		if surv := c.surv; surv != nil {
   201  			c.surv = nil
   202  			go surv.cancel(protocol.ErrClosed)
   203  		}
   204  	}
   205  }
   206  
   207  func (c *context) Close() error {
   208  	c.s.Lock()
   209  	defer c.s.Unlock()
   210  	if c.closed {
   211  		return protocol.ErrClosed
   212  	}
   213  	c.close()
   214  	return nil
   215  }
   216  
   217  func (c *context) SetOption(name string, value interface{}) error {
   218  	switch name {
   219  	case protocol.OptionSurveyTime:
   220  		if v, ok := value.(time.Duration); ok {
   221  			c.s.Lock()
   222  			c.survExpire = v
   223  			c.s.Unlock()
   224  			return nil
   225  		}
   226  		return protocol.ErrBadValue
   227  
   228  	case protocol.OptionRecvDeadline:
   229  		if v, ok := value.(time.Duration); ok {
   230  			c.s.Lock()
   231  			c.recvExpire = v
   232  			c.s.Unlock()
   233  			return nil
   234  		}
   235  		return protocol.ErrBadValue
   236  
   237  	case protocol.OptionReadQLen:
   238  		if v, ok := value.(int); ok && v >= 0 {
   239  			// this will only affect new surveys
   240  			c.s.Lock()
   241  			c.recvQLen = v
   242  			c.s.Unlock()
   243  			return nil
   244  		}
   245  		return protocol.ErrBadValue
   246  	}
   247  
   248  	return protocol.ErrBadOption
   249  }
   250  
   251  func (c *context) GetOption(option string) (interface{}, error) {
   252  	switch option {
   253  	case protocol.OptionSurveyTime:
   254  		c.s.Lock()
   255  		v := c.survExpire
   256  		c.s.Unlock()
   257  		return v, nil
   258  	case protocol.OptionRecvDeadline:
   259  		c.s.Lock()
   260  		v := c.recvExpire
   261  		c.s.Unlock()
   262  		return v, nil
   263  	case protocol.OptionReadQLen:
   264  		c.s.Lock()
   265  		v := c.recvQLen
   266  		c.s.Unlock()
   267  		return v, nil
   268  	}
   269  
   270  	return nil, protocol.ErrBadOption
   271  }
   272  
   273  func (p *pipe) close() {
   274  	_ = p.p.Close()
   275  }
   276  
   277  func (p *pipe) sender() {
   278  outer:
   279  	for {
   280  		var m *protocol.Message
   281  		select {
   282  		case <-p.closeQ:
   283  			break outer
   284  		case m = <-p.sendQ:
   285  		}
   286  
   287  		if err := p.p.SendMsg(m); err != nil {
   288  			m.Free()
   289  			break
   290  		}
   291  	}
   292  	p.close()
   293  }
   294  
   295  func (p *pipe) receiver() {
   296  	s := p.s
   297  	for {
   298  		m := p.p.RecvMsg()
   299  		if m == nil {
   300  			break
   301  		}
   302  		if len(m.Body) < 4 {
   303  			m.Free()
   304  			continue
   305  		}
   306  		m.Header = append(m.Header, m.Body[:4]...)
   307  		m.Body = m.Body[4:]
   308  
   309  		id := binary.BigEndian.Uint32(m.Header)
   310  
   311  		s.Lock()
   312  		if surv, ok := s.surveys[id]; ok {
   313  			select {
   314  			case surv.recvQ <- m:
   315  				m = nil
   316  			default:
   317  			}
   318  		}
   319  		s.Unlock()
   320  
   321  		if m != nil {
   322  			m.Free()
   323  		}
   324  	}
   325  }
   326  
   327  func (s *socket) OpenContext() (protocol.Context, error) {
   328  	s.Lock()
   329  	defer s.Unlock()
   330  	if s.closed {
   331  		return nil, protocol.ErrClosed
   332  	}
   333  	c := &context{
   334  		s:          s,
   335  		closeQ:     make(chan struct{}),
   336  		survExpire: s.master.survExpire,
   337  		recvExpire: s.master.recvExpire,
   338  		recvQLen:   s.master.recvQLen,
   339  	}
   340  	s.ctxs[c] = struct{}{}
   341  	return c, nil
   342  }
   343  
   344  func (s *socket) SendMsg(m *protocol.Message) error {
   345  	return s.master.SendMsg(m)
   346  }
   347  
   348  func (s *socket) RecvMsg() (*protocol.Message, error) {
   349  	return s.master.RecvMsg()
   350  }
   351  
   352  func (s *socket) AddPipe(pp protocol.Pipe) error {
   353  	p := &pipe{
   354  		p:      pp,
   355  		s:      s,
   356  		sendQ:  make(chan *protocol.Message, s.sendQLen),
   357  		closeQ: make(chan struct{}),
   358  	}
   359  	pp.SetPrivate(p)
   360  	s.Lock()
   361  	defer s.Unlock()
   362  	if s.closed {
   363  		return protocol.ErrClosed
   364  	}
   365  	s.pipes[p.p.ID()] = p
   366  	go p.receiver()
   367  	go p.sender()
   368  	return nil
   369  }
   370  
   371  func (s *socket) RemovePipe(pp protocol.Pipe) {
   372  	p := pp.GetPrivate().(*pipe)
   373  	close(p.closeQ)
   374  	s.Lock()
   375  	delete(s.pipes, pp.ID())
   376  	s.Unlock()
   377  }
   378  
   379  func (s *socket) Close() error {
   380  	s.Lock()
   381  	if s.closed {
   382  		s.Unlock()
   383  		return protocol.ErrClosed
   384  	}
   385  	s.closed = true
   386  	for c := range s.ctxs {
   387  		c.close()
   388  	}
   389  	s.Unlock()
   390  	return nil
   391  }
   392  
   393  func (s *socket) GetOption(option string) (interface{}, error) {
   394  	switch option {
   395  	case protocol.OptionRaw:
   396  		return false, nil
   397  	case protocol.OptionWriteQLen:
   398  		s.Lock()
   399  		v := s.sendQLen
   400  		s.Unlock()
   401  		return v, nil
   402  
   403  	default:
   404  		return s.master.GetOption(option)
   405  	}
   406  }
   407  
   408  func (s *socket) SetOption(option string, value interface{}) error {
   409  	switch option {
   410  	case protocol.OptionWriteQLen:
   411  		if v, ok := value.(int); ok && v >= 0 {
   412  			s.Lock()
   413  			s.sendQLen = v
   414  			s.Unlock()
   415  			return nil
   416  		}
   417  		return protocol.ErrBadValue
   418  	}
   419  	return s.master.SetOption(option, value)
   420  }
   421  
   422  func (*socket) Info() protocol.Info {
   423  	return protocol.Info{
   424  		Self:     Self,
   425  		Peer:     Peer,
   426  		SelfName: SelfName,
   427  		PeerName: PeerName,
   428  	}
   429  }
   430  
   431  // NewProtocol returns a new protocol implementation.
   432  func NewProtocol() protocol.Protocol {
   433  	s := &socket{
   434  		pipes:    make(map[uint32]*pipe),
   435  		surveys:  make(map[uint32]*survey),
   436  		ctxs:     make(map[*context]struct{}),
   437  		sendQLen: defaultQLen,
   438  		nextID:   uint32(time.Now().UnixNano()), // quasi-random
   439  	}
   440  	s.master = &context{
   441  		s:          s,
   442  		closeQ:     make(chan struct{}),
   443  		recvQLen:   defaultQLen,
   444  		survExpire: defaultSurveyTime,
   445  	}
   446  	s.ctxs[s.master] = struct{}{}
   447  	return s
   448  }
   449  
   450  // NewSocket allocates a new Socket using the RESPONDENT protocol.
   451  func NewSocket() (protocol.Socket, error) {
   452  	return protocol.MakeSocket(NewProtocol()), nil
   453  }