github.com/ronaksoft/rony@v0.16.26-0.20230807065236-1743dbfe6959/edgec/ws_conn.go (about)

     1  package edgec
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net"
     7  	"strings"
     8  	"time"
     9  
    10  	"github.com/gobwas/ws"
    11  	"github.com/ronaksoft/rony"
    12  	wsutil "github.com/ronaksoft/rony/internal/gateway/tcp/util"
    13  	"github.com/ronaksoft/rony/pools"
    14  	"github.com/ronaksoft/rony/tools"
    15  	"go.uber.org/zap"
    16  	"google.golang.org/protobuf/proto"
    17  )
    18  
    19  /*
    20     Creation Time: 2021 - Jan - 04
    21     Created by:  (ehsan)
    22     Maintainers:
    23        1.  Ehsan N. Moosa (E2)
    24     Auditor: Ehsan N. Moosa (E2)
    25     Copyright Ronak Software Group 2020
    26  */
    27  
    28  type wsConn struct {
    29  	replicaSet uint64
    30  	serverID   string
    31  	ws         *Websocket
    32  	stop       bool
    33  	conn       net.Conn
    34  	dialer     ws.Dialer
    35  	connected  bool
    36  	hostPorts  []string
    37  	secure     bool
    38  }
    39  
    40  func (c *wsConn) createDialer(timeout time.Duration) {
    41  	c.dialer = ws.Dialer{
    42  		ReadBufferSize:  32 * 1024, // 32kB
    43  		WriteBufferSize: 32 * 1024, // 32kB
    44  		Timeout:         timeout,
    45  		NetDial: func(ctx context.Context, network, addr string) (conn net.Conn, err error) {
    46  			host, port, err := net.SplitHostPort(addr)
    47  			if err != nil {
    48  				return nil, err
    49  			}
    50  			ips, err := net.LookupIP(host)
    51  			if err != nil {
    52  				return nil, err
    53  			}
    54  			c.ws.logger.Debug("DNS LookIP", zap.String("Addr", addr), zap.Any("IPs", ips))
    55  			d := net.Dialer{Timeout: timeout}
    56  			for _, ip := range ips {
    57  				if ip.To4() != nil {
    58  					conn, err = d.DialContext(ctx, "tcp4", net.JoinHostPort(ip.String(), port))
    59  					if err != nil {
    60  						continue
    61  					}
    62  
    63  					return
    64  				}
    65  			}
    66  
    67  			return nil, ErrNoConnection
    68  		},
    69  		OnStatusError: nil,
    70  		OnHeader:      nil,
    71  		TLSClient:     nil,
    72  		TLSConfig:     nil,
    73  		WrapConn:      nil,
    74  	}
    75  }
    76  
    77  func (c *wsConn) connect() {
    78  	urlPrefix := "ws://"
    79  	if c.secure {
    80  		urlPrefix = "wss://"
    81  	}
    82  ConnectLoop:
    83  	c.ws.logger.Debug("Connect", zap.Strings("H", c.hostPorts))
    84  	c.createDialer(c.ws.cfg.DialTimeout)
    85  
    86  	sb := strings.Builder{}
    87  	if hf := c.ws.cfg.HeaderFunc; hf != nil {
    88  		for k, v := range hf() {
    89  			sb.WriteString(k)
    90  			sb.WriteString(": ")
    91  			sb.WriteString(v)
    92  			sb.WriteRune('\n')
    93  		}
    94  	}
    95  
    96  	c.dialer.Header = ws.HandshakeHeaderString(sb.String())
    97  	conn, _, _, err := c.dialer.Dial(context.Background(), fmt.Sprintf("%s%s", urlPrefix, c.hostPorts[0]))
    98  	if err != nil {
    99  		c.ws.logger.Debug("Dial failed", zap.Error(err), zap.Strings("Host", c.hostPorts))
   100  		time.Sleep(time.Duration(tools.RandomInt64(2000))*time.Millisecond + time.Second)
   101  
   102  		goto ConnectLoop
   103  	}
   104  	c.conn = conn
   105  	c.connected = true
   106  
   107  	go c.receiver()
   108  
   109  	if c.ws.cfg.OnConnect != nil {
   110  		c.ws.cfg.OnConnect(c.ws)
   111  	}
   112  }
   113  
   114  func (c *wsConn) receiver() {
   115  	var (
   116  		ms []wsutil.Message
   117  	)
   118  	// Receive Loop
   119  	for {
   120  		ms = ms[:0]
   121  		_ = c.conn.SetReadDeadline(time.Now().Add(c.ws.cfg.IdleTimeout))
   122  		ms, err := wsutil.ReadMessage(c.conn, ws.StateClientSide, ms)
   123  		if err != nil {
   124  			_ = c.conn.Close()
   125  			if !c.stop {
   126  				c.connected = false
   127  				c.connect()
   128  			}
   129  
   130  			break
   131  		}
   132  		for idx := range ms {
   133  			switch ms[idx].OpCode {
   134  			case ws.OpBinary, ws.OpText:
   135  				e := rony.PoolMessageEnvelope.Get()
   136  				_ = e.Unmarshal(ms[idx].Payload)
   137  				c.extractor(e)
   138  				rony.PoolMessageEnvelope.Put(e)
   139  			default:
   140  			}
   141  		}
   142  	}
   143  }
   144  
   145  func (c *wsConn) extractor(e *rony.MessageEnvelope) {
   146  	switch e.GetConstructor() {
   147  	case rony.C_MessageContainer:
   148  		x := rony.PoolMessageContainer.Get()
   149  		_ = x.Unmarshal(e.Message)
   150  		for idx := range x.Envelopes {
   151  			c.handler(x.Envelopes[idx])
   152  		}
   153  		rony.PoolMessageContainer.Put(x)
   154  	default:
   155  		c.handler(e)
   156  	}
   157  }
   158  
   159  func (c *wsConn) handler(e *rony.MessageEnvelope) {
   160  	defaultHandler := c.ws.cfg.Handler
   161  	if e.GetRequestID() == 0 {
   162  		if defaultHandler != nil {
   163  			defaultHandler(e)
   164  		}
   165  
   166  		return
   167  	}
   168  
   169  	c.ws.pendingMtx.Lock()
   170  	ch := c.ws.pending[e.GetRequestID()]
   171  	delete(c.ws.pending, e.GetRequestID())
   172  	c.ws.pendingMtx.Unlock()
   173  
   174  	if ch != nil {
   175  		ch <- e.Clone()
   176  	} else {
   177  		defaultHandler(e)
   178  	}
   179  }
   180  
   181  func (c *wsConn) close() error {
   182  	// by setting the stop flag, we are making sure no reconnection will happen
   183  	c.stop = true
   184  
   185  	if c.conn == nil {
   186  		return nil
   187  	}
   188  
   189  	_ = wsutil.WriteMessage(c.conn, ws.StateClientSide, ws.OpClose, nil)
   190  
   191  	// by setting the read deadline we make the receiver() routine stops
   192  	return c.conn.SetReadDeadline(time.Now())
   193  }
   194  
   195  func (c *wsConn) send(req *rony.MessageEnvelope) error {
   196  	if !c.connected {
   197  		c.connect()
   198  	}
   199  	mo := proto.MarshalOptions{UseCachedSize: true}
   200  	buf := pools.Buffer.GetCap(mo.Size(req))
   201  	defer pools.Buffer.Put(buf)
   202  
   203  	b, err := mo.MarshalAppend(*buf.Bytes(), req)
   204  	if err != nil {
   205  		return err
   206  	}
   207  
   208  	return wsutil.WriteMessage(c.conn, ws.StateClientSide, ws.OpBinary, b)
   209  }