github.com/ronaksoft/rony@v0.16.26-0.20230807065236-1743dbfe6959/edgec/ws.go (about)

     1  package edgec
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"strings"
     7  	"sync"
     8  	"sync/atomic"
     9  	"time"
    10  
    11  	"go.opentelemetry.io/otel/attribute"
    12  
    13  	"go.opentelemetry.io/otel/codes"
    14  
    15  	"github.com/ronaksoft/rony/registry"
    16  
    17  	"go.opentelemetry.io/otel/propagation"
    18  
    19  	"github.com/ronaksoft/rony"
    20  	"github.com/ronaksoft/rony/errors"
    21  	"github.com/ronaksoft/rony/log"
    22  	"github.com/ronaksoft/rony/pools"
    23  	"github.com/ronaksoft/rony/tools"
    24  	"go.opentelemetry.io/otel/trace"
    25  	"go.uber.org/zap"
    26  )
    27  
    28  /*
    29     Creation Time: 2020 - Jul - 17
    30     Created by:  (ehsan)
    31     Maintainers:
    32        1.  Ehsan N. Moosa (E2)
    33     Auditor: Ehsan N. Moosa (E2)
    34     Copyright Ronak Software Group 2020
    35  */
    36  
    37  var (
    38  	_ Client = &Websocket{}
    39  )
    40  
    41  const (
    42  	requestTimeout = 3 * time.Second
    43  	requestRetry   = 5
    44  	dialTimeout    = 3 * time.Second
    45  	idleTimeout    = time.Minute
    46  )
    47  
    48  type MessageHandler func(m *rony.MessageEnvelope)
    49  
    50  type ConnectHandler func(c *Websocket)
    51  
// WebsocketConfig holds the configs for the Websocket client.
// Zero-valued timeout/retry fields, Router and Handler are replaced with
// defaults by NewWebsocket.
type WebsocketConfig struct {
	// Name prefixes the names of tracing spans created by SendWithDetails
	// ("<Name>.<constructor>").
	Name string
	// SeedHostPort is the address Start contacts first to discover the cluster.
	SeedHostPort string
	// IdleTimeout defaults to one minute when zero.
	IdleTimeout time.Duration
	// DialTimeout defaults to three seconds when zero.
	DialTimeout time.Duration
	// Handler must not block in function because other incoming messages might get blocked.
	// This handler must return quickly and pass a deep copy of the MessageEnvelope to other
	// routines. Defaults to a no-op when nil.
	Handler MessageHandler
	// HeaderFunc is not read in this file.
	// NOTE(review): presumably supplies connection headers; confirm with the wsConn code.
	HeaderFunc func() map[string]string
	// Secure is copied to every connection the client creates.
	Secure bool
	// RequestMaxRetry is the maximum number of times the client sends a request if any
	// network layer error occurs. Defaults to 5 when zero.
	RequestMaxRetry int
	// RequestTimeout is the timeout for each individual request on each try.
	// Defaults to three seconds when zero.
	RequestTimeout time.Duration
	// ContextTimeout is the amount of time the Send function will wait until it times out.
	// This includes all the retries. Defaults to RequestTimeout * RequestMaxRetry when zero.
	ContextTimeout time.Duration
	// Router is an optional parameter which gives more control over selecting the target host
	// based on each request. When nil, every request goes to the session replica set.
	Router Router
	// OnConnect will be called every time the websocket connection is established.
	OnConnect ConnectHandler
	// Tracer, when non-nil, makes SendWithDetails create a client span per request.
	Tracer trace.Tracer
}
    76  
    77  // Websocket client which could handle multiple connections
    78  type Websocket struct {
    79  	cfg            WebsocketConfig
    80  	sessionReplica uint64
    81  	nextReqID      uint64
    82  	logger         log.Logger
    83  	tracer         trace.Tracer
    84  	propagator     propagation.TraceContext
    85  
    86  	// Connection Pool
    87  	connsMtx       sync.RWMutex
    88  	connsByReplica map[uint64]map[string]*wsConn
    89  	connsByID      map[string]*wsConn
    90  
    91  	// FLusher
    92  	flusherPool *tools.FlusherPool
    93  
    94  	// Flying Requests
    95  	pendingMtx tools.SpinLock
    96  	pending    map[uint64]chan *rony.MessageEnvelope
    97  }
    98  
    99  func NewWebsocket(config WebsocketConfig) *Websocket {
   100  	c := &Websocket{
   101  		nextReqID:      tools.RandomUint64(0),
   102  		cfg:            config,
   103  		connsByReplica: make(map[uint64]map[string]*wsConn, 64),
   104  		connsByID:      make(map[string]*wsConn, 64),
   105  		pending:        make(map[uint64]chan *rony.MessageEnvelope, 1024),
   106  		logger:         log.With("EdgeC(Websocket)"),
   107  		tracer:         config.Tracer,
   108  	}
   109  
   110  	c.flusherPool = tools.NewFlusherPool(1, 100, c.sendFunc)
   111  
   112  	// Prepare default config values
   113  	if c.cfg.DialTimeout == 0 {
   114  		c.cfg.DialTimeout = dialTimeout
   115  	}
   116  	if c.cfg.IdleTimeout == 0 {
   117  		c.cfg.IdleTimeout = idleTimeout
   118  	}
   119  	if c.cfg.RequestMaxRetry == 0 {
   120  		c.cfg.RequestMaxRetry = requestRetry
   121  	}
   122  	if c.cfg.RequestTimeout == 0 {
   123  		c.cfg.RequestTimeout = requestTimeout
   124  	}
   125  	if c.cfg.ContextTimeout == 0 {
   126  		c.cfg.ContextTimeout = c.cfg.RequestTimeout * time.Duration(c.cfg.RequestMaxRetry)
   127  	}
   128  	if c.cfg.Router == nil {
   129  		c.cfg.Router = &wsRouter{
   130  			c: c,
   131  		}
   132  	}
   133  	if c.cfg.Handler == nil {
   134  		c.cfg.Handler = func(m *rony.MessageEnvelope) {}
   135  	}
   136  
   137  	return c
   138  }
   139  
   140  func (ws *Websocket) Tracer() trace.Tracer {
   141  	return ws.cfg.Tracer
   142  }
   143  
   144  func (ws *Websocket) GetRequestID() uint64 {
   145  	return atomic.AddUint64(&ws.nextReqID, 1)
   146  }
   147  
   148  func (ws *Websocket) Start() error {
   149  	initConn := ws.newConn("", 0, ws.cfg.SeedHostPort)
   150  	ws.addConn("", 0, initConn)
   151  
   152  	req := rony.PoolMessageEnvelope.Get()
   153  	defer rony.PoolMessageEnvelope.Put(req)
   154  	res := rony.PoolMessageEnvelope.Get()
   155  	defer rony.PoolMessageEnvelope.Put(res)
   156  	req.Fill(ws.GetRequestID(), rony.C_GetNodes, &rony.GetNodes{})
   157  
   158  	if err := ws.Send(context.Background(), req, res); err != nil {
   159  		return err
   160  	}
   161  	switch res.Constructor {
   162  	case rony.C_Edges:
   163  		x := &rony.Edges{}
   164  		_ = x.Unmarshal(res.Message)
   165  		_ = initConn.close()
   166  		ws.removeConn("", 0)
   167  		for _, n := range x.Nodes {
   168  			if ce := ws.logger.Check(log.DebugLevel, "NodeInfo"); ce != nil {
   169  				ce.Write(
   170  					zap.String("ServerID", n.ServerID),
   171  					zap.Uint64("RS", n.ReplicaSet),
   172  					zap.Strings("HostPorts", n.HostPorts),
   173  				)
   174  			}
   175  			wsc := ws.newConn(n.ServerID, n.ReplicaSet, n.HostPorts...)
   176  
   177  			ws.addConn(n.ServerID, n.ReplicaSet, wsc)
   178  			ws.sessionReplica = n.ReplicaSet
   179  		}
   180  
   181  	default:
   182  		return ErrUnknownResponse
   183  	}
   184  
   185  	return nil
   186  }
   187  
   188  func (ws *Websocket) addConn(serverID string, replicaSet uint64, wsc *wsConn) {
   189  	ws.logger.Debug("Pool connection added",
   190  		zap.String("ServerID", serverID),
   191  		zap.Uint64("RS", replicaSet),
   192  	)
   193  	ws.connsMtx.Lock()
   194  	defer ws.connsMtx.Unlock()
   195  
   196  	if ws.connsByReplica[replicaSet] == nil {
   197  		ws.connsByReplica[replicaSet] = make(map[string]*wsConn, 16)
   198  	}
   199  	ws.connsByID[serverID] = wsc
   200  	ws.connsByReplica[replicaSet][serverID] = wsc
   201  }
   202  
   203  func (ws *Websocket) removeConn(serverID string, replicaSet uint64) {
   204  	ws.connsMtx.Lock()
   205  	defer ws.connsMtx.Unlock()
   206  
   207  	if ws.connsByReplica[replicaSet] != nil {
   208  		delete(ws.connsByReplica[replicaSet], serverID)
   209  	}
   210  	delete(ws.connsByID, serverID)
   211  }
   212  
   213  func (ws *Websocket) getConnByReplica(replicaSet uint64) *wsConn {
   214  	ws.connsMtx.RLock()
   215  	defer ws.connsMtx.RUnlock()
   216  
   217  	m := ws.connsByReplica[replicaSet]
   218  	for _, c := range m {
   219  		return c
   220  	}
   221  
   222  	return nil
   223  }
   224  
   225  func (ws *Websocket) getConnByID(serverID string) *wsConn {
   226  	ws.connsMtx.RLock()
   227  	defer ws.connsMtx.RUnlock()
   228  
   229  	wsc := ws.connsByID[serverID]
   230  
   231  	return wsc
   232  }
   233  
   234  func (ws *Websocket) newConn(id string, replicaSet uint64, hostPorts ...string) *wsConn {
   235  	wsc := &wsConn{
   236  		serverID:   id,
   237  		ws:         ws,
   238  		replicaSet: replicaSet,
   239  		hostPorts:  hostPorts,
   240  		secure:     ws.cfg.Secure,
   241  	}
   242  
   243  	return wsc
   244  }
   245  
   246  func (ws *Websocket) sendFunc(serverID string, entries []tools.FlushEntry) {
   247  	wsc := ws.getConnByID(serverID)
   248  	if wsc == nil {
   249  		// TODO:: for each entry we must return
   250  		return
   251  	}
   252  
   253  	// Check if we have active connection
   254  	if !wsc.connected {
   255  		wsc.connect()
   256  	}
   257  
   258  	me := rony.PoolMessageEnvelope.Get()
   259  	defer rony.PoolMessageEnvelope.Put(me)
   260  
   261  	switch len(entries) {
   262  	case 0:
   263  		// There is nothing to do, Probably a bug if we are here
   264  		return
   265  	case 1:
   266  		ev := entries[0].Value().(*wsRequest)
   267  		ws.pendingMtx.Lock()
   268  		ws.pending[ev.req.GetRequestID()] = ev.resChan
   269  		ws.pendingMtx.Unlock()
   270  		ev.req.DeepCopy(me)
   271  
   272  	default:
   273  		mc := rony.PoolMessageContainer.Get()
   274  		for _, e := range entries {
   275  			ev := e.Value().(*wsRequest)
   276  			ws.pendingMtx.Lock()
   277  			ws.pending[ev.req.GetRequestID()] = ev.resChan
   278  			ws.pendingMtx.Unlock()
   279  			mc.Envelopes = append(mc.Envelopes, ev.req.Clone())
   280  			mc.Length += 1
   281  		}
   282  		me.Fill(0, rony.C_MessageContainer, mc)
   283  		rony.PoolMessageContainer.Put(mc)
   284  	}
   285  
   286  	if err := wsc.send(me); err != nil {
   287  		ws.logger.Warn("got error on sending request", zap.Error(err))
   288  	}
   289  }
   290  
   291  func (ws *Websocket) Send(ctx context.Context, req, res *rony.MessageEnvelope) (err error) {
   292  	err = ws.SendWithDetails(ctx, req, res, ws.cfg.RequestMaxRetry, ws.cfg.RequestTimeout, "")
   293  
   294  	return
   295  }
   296  
   297  func (ws *Websocket) SendTo(ctx context.Context, req, res *rony.MessageEnvelope, serverID string) error {
   298  	return ws.SendWithDetails(ctx, req, res, ws.cfg.RequestMaxRetry, ws.cfg.RequestTimeout, serverID)
   299  }
   300  
   301  func (ws *Websocket) SendWithDetails(
   302  	ctx context.Context,
   303  	req, res *rony.MessageEnvelope, retry int, timeout time.Duration, serverID string,
   304  ) error {
   305  	var (
   306  		wsc *wsConn
   307  		rs  uint64
   308  	)
   309  
   310  	if serverID != "" {
   311  		wsc = ws.getConnByID(serverID)
   312  	} else {
   313  		rs = ws.cfg.Router.GetRoute(req)
   314  		wsc = ws.getConnByReplica(rs)
   315  	}
   316  
   317  	if ws.tracer != nil {
   318  		var span trace.Span
   319  		ctx, span = ws.tracer.
   320  			Start(
   321  				ctx,
   322  				fmt.Sprintf("%s.%s", ws.cfg.Name, registry.C(req.Constructor)),
   323  				trace.WithSpanKind(trace.SpanKindClient),
   324  			)
   325  		defer span.End()
   326  
   327  		ws.propagator.Inject(ctx, req.Carrier())
   328  	}
   329  
   330  Loop:
   331  	if wsc == nil {
   332  		// TODO:: try to gather information about the target
   333  		err := ErrNoConnection
   334  		trace.SpanFromContext(ctx).SetStatus(codes.Error, err.Error())
   335  
   336  		return err
   337  	}
   338  
   339  	wsReq := &wsRequest{
   340  		req:     req,
   341  		resChan: make(chan *rony.MessageEnvelope, 1),
   342  	}
   343  	ws.flusherPool.Enter(wsc.serverID, tools.NewEntry(wsReq))
   344  
   345  	t := pools.AcquireTimer(timeout)
   346  	defer pools.ReleaseTimer(t)
   347  	select {
   348  	case x := <-wsReq.resChan:
   349  		switch x.GetConstructor() {
   350  		case rony.C_Redirect:
   351  			xx := &rony.Redirect{}
   352  			_ = xx.Unmarshal(x.GetMessage())
   353  			if retry--; retry < 0 {
   354  				return errors.ErrRetriesExceeded(fmt.Errorf("redirect"))
   355  			}
   356  			rs = ws.redirect(xx)
   357  			wsc = ws.getConnByReplica(rs)
   358  
   359  			trace.SpanFromContext(ctx).
   360  				AddEvent(
   361  					"Redirect",
   362  					trace.WithAttributes(
   363  						attribute.Int64("rony.replicaset", int64(rs)),
   364  					),
   365  				)
   366  
   367  			goto Loop
   368  		default:
   369  			x.DeepCopy(res)
   370  		}
   371  	case <-t.C:
   372  		ws.pendingMtx.Lock()
   373  		delete(ws.pending, req.GetRequestID())
   374  		ws.pendingMtx.Unlock()
   375  
   376  		err := ErrTimeout
   377  		trace.SpanFromContext(ctx).SetStatus(codes.Error, err.Error())
   378  
   379  		return err
   380  	}
   381  
   382  	trace.SpanFromContext(ctx).SetStatus(codes.Ok, "")
   383  
   384  	return nil
   385  }
   386  
   387  func (ws *Websocket) redirect(x *rony.Redirect) (replicaSet uint64) {
   388  	if ce := ws.logger.Check(log.DebugLevel, "received Redirect"); ce != nil {
   389  		ce.Write(
   390  			zap.Any("Edges", x.Edges),
   391  			zap.Any("Wait", x.WaitInSec),
   392  		)
   393  	}
   394  
   395  	if len(x.Edges) == 0 {
   396  		return
   397  	}
   398  
   399  	for _, n := range x.Edges {
   400  		ws.addConn(
   401  			n.ServerID, n.ReplicaSet,
   402  			ws.newConn(n.ServerID, n.ReplicaSet, n.HostPorts...),
   403  		)
   404  	}
   405  
   406  	switch x.Reason {
   407  	case rony.RedirectReason_ReplicaSetSession:
   408  		ws.sessionReplica = x.Edges[0].ReplicaSet
   409  	case rony.RedirectReason_ReplicaSetRequest:
   410  	default:
   411  	}
   412  
   413  	replicaSet = x.Edges[0].ReplicaSet
   414  
   415  	return
   416  }
   417  
   418  func (ws *Websocket) Close() error {
   419  	ws.connsMtx.RLock()
   420  	defer ws.connsMtx.RUnlock()
   421  
   422  	for _, conns := range ws.connsByReplica {
   423  		for _, c := range conns {
   424  			_ = c.close()
   425  		}
   426  	}
   427  
   428  	return nil
   429  }
   430  
   431  func (ws *Websocket) ConnInfo() string {
   432  	sb := strings.Builder{}
   433  	sb.WriteString("\n-----\n")
   434  	ws.connsMtx.Lock()
   435  	for id, wsc := range ws.connsByID {
   436  		sb.WriteString(
   437  			fmt.Sprintf(
   438  				"%s: [RS=%d] [HostPorts=%v] [Connected: %t]\n",
   439  				id, wsc.replicaSet, wsc.hostPorts, wsc.connected,
   440  			),
   441  		)
   442  	}
   443  	ws.connsMtx.Unlock()
   444  	sb.WriteString("-----\n")
   445  
   446  	return sb.String()
   447  }
   448  
   449  func (ws *Websocket) ClusterInfo(replicaSets ...uint64) (*rony.Edges, error) {
   450  	req := rony.PoolMessageEnvelope.Get()
   451  	defer rony.PoolMessageEnvelope.Put(req)
   452  	res := rony.PoolMessageEnvelope.Get()
   453  	defer rony.PoolMessageEnvelope.Put(res)
   454  	req.Fill(ws.GetRequestID(), rony.C_GetNodes, &rony.GetNodes{ReplicaSet: replicaSets})
   455  	if err := ws.Send(context.Background(), req, res); err != nil {
   456  		return nil, err
   457  	}
   458  	switch res.GetConstructor() {
   459  	case rony.C_Edges:
   460  		x := &rony.Edges{}
   461  		_ = x.Unmarshal(res.GetMessage())
   462  
   463  		return x, nil
   464  	case rony.C_Error:
   465  		x := &rony.Error{}
   466  		_ = x.Unmarshal(res.GetMessage())
   467  
   468  		return nil, x
   469  	}
   470  
   471  	return nil, ErrUnknownResponse
   472  }
   473  
   474  type wsRouter struct {
   475  	c *Websocket
   476  }
   477  
   478  func (d *wsRouter) UpdateRoute(_ *rony.MessageEnvelope, replicaSet uint64) {
   479  	// TODO:: implement cache maybe
   480  }
   481  
   482  func (d *wsRouter) GetRoute(_ *rony.MessageEnvelope) (replicaSet uint64) {
   483  	return d.c.sessionReplica
   484  }
   485  
// wsRequest is one queued outgoing request: the envelope to send and the
// channel the sender waits on for the reply.
type wsRequest struct {
	// req is the request envelope; sendFunc deep-copies or clones it before writing.
	req     *rony.MessageEnvelope
	// resChan receives the reply. SendWithDetails creates it with capacity 1, and
	// sendFunc registers it in Websocket.pending under req's request ID.
	resChan chan *rony.MessageEnvelope
}