github.com/cmd-stream/base-go@v0.0.0-20230813145615-dd6ac24c16f5/client/client.go (about)

     1  package client
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"net"
     7  	"sync"
     8  	"sync/atomic"
     9  	"time"
    10  
    11  	"github.com/cmd-stream/base-go"
    12  )
    13  
// Client lifecycle states stored in Client.state.
const (
	// inProgress is the zero value, so it is the state of a Client right after
	// New returns.
	inProgress int = iota
	// closed is set by Client.Close before it closes the delegate (and is
	// rolled back to inProgress if closing the delegate fails).
	closed
)
    18  
// UnexpectedResultHandler is a handler which is used to handle unexpected
// results.
//
// A result is unexpected when its sequence number does not match the sequence
// number of any command (waiting for the result) that was sent by Client.
//
// The handler is invoked from the Client's receive loop, so the next result is
// not received until the handler returns.
type UnexpectedResultHandler func(seq base.Seq, result base.Result)
    25  
    26  // New creates a new Client.
    27  //
    28  // The handler parameter may be nil.
    29  func New[T any](delegate base.ClientDelegate[T],
    30  	handler UnexpectedResultHandler) *Client[T] {
    31  	var (
    32  		ctx, cancel        = context.WithCancel(context.Background())
    33  		flagFl      uint32 = 0
    34  		client             = Client[T]{
    35  			cancel:   cancel,
    36  			delegate: delegate,
    37  			waiting:  make(map[base.Seq]chan<- base.AsyncResult),
    38  			handler:  handler,
    39  			done:     make(chan struct{}),
    40  			flagFl:   &flagFl,
    41  			chFl:     make(chan error, 1),
    42  		}
    43  	)
    44  	go receive[T](ctx, &client)
    45  	return &client
    46  }
    47  
// Client is an asynchronous cmd-stream client. It is thread-safe, so you can
// use it from different goroutines. Here are some of its features:
//   - Once created, Client is in the connected state, where it can send
//     commands and receive results.
//   - It uses only one connection.
//   - For each command Client generates a unique sequence number, thanks to
//     which it maps results to commands.
//   - If a command timeout elapsed, you can call the Client.Forget() method, so
//     that Client stops waiting for the results of this command.
//   - You can set a handler to handle unexpected results received from the
//     server.
//   - To terminate the connection to the server and stop Client, call the
//     Client.Close() method.
//   - You can find out when Client is done with the Client.Done() method.
//   - If a connection fails with an error Client fails too.
//   - If Client was failed, you can get a connection error using the
//     Client.Err() method.
//   - If Client is closed, all commands waiting for the results will
//     receive an error (AsyncResult.Error != nil).
type Client[T any] struct {
	// cancel cancels the context handed to the receive goroutine; it is called
	// by Close.
	cancel context.CancelFunc
	// state is inProgress or closed; accessed under muSt.
	state int
	// delegate performs the actual Send/Flush/Receive/Close operations.
	delegate base.ClientDelegate[T]
	// seq is the sequence number of the last sent command; accessed under muSn.
	seq base.Seq
	// waiting maps the sequence number of a command to the channel that should
	// receive its results; accessed under muWt.
	waiting map[base.Seq]chan<- base.AsyncResult
	// handler handles results that match no waiting command; may be nil.
	handler UnexpectedResultHandler
	// err is the cause Client terminated with, set by exit; accessed under muEr.
	err error
	// done is closed by exit when Client terminates.
	done chan struct{}
	// flagFl is set to 1 (atomically) while some goroutine is flushing and
	// reset to 0 once that flush has completed.
	flagFl *uint32
	// chFl delivers the outcome of the current flush to the goroutines that
	// did not perform it themselves; it is replaced after every flush and
	// accessed under muSn.
	chFl chan error
	// muSn guards seq, chFl and the delegate's Send/Flush calls.
	muSn sync.Mutex
	// muWt guards waiting.
	muWt sync.Mutex
	// muEr guards err.
	muEr sync.Mutex
	// muSt guards state.
	muSt sync.Mutex
}
    83  
    84  // Send sends a command.
    85  //
    86  // Adds the command results received from the server to the results channel. If
    87  // the last one is not large enough, getting results for all commands may hang.
    88  //
    89  // For each command, generates a unique sequence number, starting with 1.
    90  // Thus, a command with seq == 1 is sent first, with seq == 2 is sent second,
    91  // and so on. 0 is reserved for the Ping-Pong game, which keeps a connection
    92  // alive.
    93  //
    94  // Returns the sequence number and an error != nil if the command was not send.
    95  func (c *Client[T]) Send(cmd base.Cmd[T], results chan<- base.AsyncResult) (
    96  	seq base.Seq, err error) {
    97  	var chFl chan error
    98  	c.muSn.Lock()
    99  	chFl = c.chFl
   100  	c.seq++
   101  	seq = c.seq
   102  	c.memorize(seq, results)
   103  	err = c.delegate.Send(seq, cmd)
   104  	if err != nil {
   105  		c.muSn.Unlock()
   106  		c.Forget(seq)
   107  		return
   108  	}
   109  	c.muSn.Unlock()
   110  	return seq, c.flush(seq, chFl)
   111  }
   112  
   113  // SendWithDeadline sends a command with a deadline.
   114  //
   115  // Use this method if you want to send a command and specify the send deadline.
   116  // In all other it performs like the Send method.
   117  func (c *Client[T]) SendWithDeadline(deadline time.Time, cmd base.Cmd[T],
   118  	results chan<- base.AsyncResult) (seq base.Seq, err error) {
   119  	var chFl chan error
   120  	c.muSn.Lock()
   121  	chFl = c.chFl
   122  	c.seq++
   123  	seq = c.seq
   124  	c.memorize(seq, results)
   125  	err = c.delegate.SetSendDeadline(deadline)
   126  	if err != nil {
   127  		c.muSn.Unlock()
   128  		c.Forget(seq)
   129  		return
   130  	}
   131  	err = c.delegate.Send(seq, cmd)
   132  	if err != nil {
   133  		c.muSn.Unlock()
   134  		c.Forget(seq)
   135  		return
   136  	}
   137  	c.muSn.Unlock()
   138  	return seq, c.flush(seq, chFl)
   139  }
   140  
   141  // Has checks if the command with the specified sequence number has been sent
   142  // by Client and still waiting for the result.
   143  func (c *Client[T]) Has(seq base.Seq) bool {
   144  	_, pst := c.load(seq)
   145  	return pst
   146  }
   147  
   148  // Forget makes Client to forget about the command which still waiting for the
   149  // result.
   150  //
   151  // After calling Forget, all the results of the corresponding command will be
   152  // handled with UnexpectedResultHandler.
   153  func (c *Client[T]) Forget(seq base.Seq) {
   154  	c.unmemorize(seq)
   155  }
   156  
   157  // Done returns a channel that is closed when Client terminates.
   158  func (c *Client[T]) Done() <-chan struct{} {
   159  	return c.done
   160  }
   161  
   162  // Err returns a connection error, when Client is failed.
   163  func (c *Client[T]) Err() (err error) {
   164  	c.muEr.Lock()
   165  	err = c.err
   166  	c.muEr.Unlock()
   167  	return
   168  }
   169  
   170  // Close closes Client.
   171  func (c *Client[T]) Close() (err error) {
   172  	c.muSt.Lock()
   173  	defer c.muSt.Unlock()
   174  	c.state = closed
   175  	if err = c.delegate.Close(); err != nil {
   176  		c.state = inProgress
   177  		return
   178  	}
   179  	c.cancel()
   180  	return
   181  }
   182  
   183  func (c *Client[T]) receive(ctx context.Context) (err error) {
   184  	defer c.unmemorizeAll(err)
   185  	var (
   186  		seq     base.Seq
   187  		result  base.Result
   188  		results chan<- base.AsyncResult
   189  		pst     bool
   190  	)
   191  	for {
   192  		seq, result, err = c.delegate.Receive()
   193  		if err != nil {
   194  			return
   195  		}
   196  		if result.LastOne() {
   197  			results, pst = c.loadAndUnmemorize(seq)
   198  		} else {
   199  			results, pst = c.load(seq)
   200  		}
   201  		if !pst && c.handler != nil {
   202  			c.handler(seq, result)
   203  			continue
   204  		}
   205  		select {
   206  		case <-ctx.Done():
   207  			return context.Canceled
   208  		case results <- base.AsyncResult{Seq: seq, Result: result}:
   209  			continue
   210  		}
   211  	}
   212  }
   213  
   214  func (c *Client[T]) memorize(seq base.Seq, results chan<- base.AsyncResult) {
   215  	c.muWt.Lock()
   216  	c.waiting[seq] = results
   217  	c.muWt.Unlock()
   218  }
   219  
   220  func (c *Client[T]) unmemorize(seq base.Seq) {
   221  	c.muWt.Lock()
   222  	delete(c.waiting, seq)
   223  	c.muWt.Unlock()
   224  }
   225  
   226  func (c *Client[T]) loadAndUnmemorize(seq base.Seq) (
   227  	results chan<- base.AsyncResult, pst bool) {
   228  	c.muWt.Lock()
   229  	results, pst = c.waiting[seq]
   230  	if pst {
   231  		delete(c.waiting, seq)
   232  	}
   233  	c.muWt.Unlock()
   234  	return
   235  }
   236  
   237  func (c *Client[T]) load(seq base.Seq) (results chan<- base.AsyncResult,
   238  	pst bool) {
   239  	c.muWt.Lock()
   240  	results, pst = c.waiting[seq]
   241  	c.muWt.Unlock()
   242  	return
   243  }
   244  
   245  func (c *Client[T]) flush(seq base.Seq, chFl chan error) (err error) {
   246  	if swapped := atomic.CompareAndSwapUint32(c.flagFl, 0, 1); !swapped {
   247  		err = <-chFl
   248  		if err != nil {
   249  			chFl <- err
   250  			c.Forget(seq)
   251  		}
   252  		return
   253  	}
   254  	c.muSn.Lock()
   255  	err = c.delegate.Flush()
   256  	if err != nil {
   257  		c.chFl <- err
   258  		c.changeChFl()
   259  		c.muSn.Unlock()
   260  		c.Forget(seq)
   261  		return
   262  	}
   263  	close(c.chFl)
   264  	c.changeChFl()
   265  	c.muSn.Unlock()
   266  	return
   267  }
   268  
   269  func (c *Client[T]) changeChFl() {
   270  	c.chFl = make(chan error, 1)
   271  	atomic.CompareAndSwapUint32(c.flagFl, 1, 0)
   272  }
   273  
   274  func (c *Client[T]) rangeAndUnmemorize(
   275  	fn func(seq base.Seq, results chan<- base.AsyncResult)) {
   276  	c.muWt.Lock()
   277  	for seq, results := range c.waiting {
   278  		fn(seq, results)
   279  		delete(c.waiting, seq)
   280  	}
   281  	c.muWt.Unlock()
   282  }
   283  
   284  func (c *Client[T]) exit(cause error) (err error) {
   285  	c.muSt.Lock()
   286  	if c.state != closed {
   287  		if err = c.delegate.Close(); err != nil {
   288  			c.muSt.Unlock()
   289  			return
   290  		}
   291  	}
   292  	c.muSt.Unlock()
   293  
   294  	c.muEr.Lock()
   295  	c.err = cause
   296  	c.muEr.Unlock()
   297  	close(c.done)
   298  	return
   299  }
   300  
   301  func (c *Client[T]) unmemorizeAll(cause error) {
   302  	c.rangeAndUnmemorize(func(seq base.Seq, results chan<- base.AsyncResult) {
   303  		queueErrResult(seq, cause, results)
   304  	})
   305  }
   306  
   307  func (c *Client[T]) correctErr(err error) error {
   308  	c.muSt.Lock()
   309  	defer c.muSt.Unlock()
   310  	switch c.state {
   311  	case inProgress:
   312  		return err
   313  	case closed:
   314  		return ErrClosed
   315  	default:
   316  		panic("unexpected state")
   317  	}
   318  }
   319  
   320  func receive[T any](ctx context.Context, client *Client[T]) {
   321  Start:
   322  	err := client.receive(ctx)
   323  	if err != nil {
   324  		err = client.correctErr(err)
   325  		if netError(err) || err == io.EOF { // TODO Test EOF.
   326  			if rdelegate, ok := client.delegate.(base.ClientReconnectDelegate[T]); ok {
   327  				if err = rdelegate.Reconnect(); err == nil {
   328  					goto Start
   329  				}
   330  			}
   331  		}
   332  	}
   333  	if err = client.exit(err); err != nil {
   334  		panic(err)
   335  	}
   336  }
   337  
   338  func queueErrResult(seq base.Seq, err error, results chan<- base.AsyncResult) {
   339  	select {
   340  	case results <- base.AsyncResult{Seq: seq, Error: err}:
   341  	default:
   342  	}
   343  }
   344  
   345  func netError(err error) bool {
   346  	_, ok := err.(net.Error)
   347  	return ok
   348  }