
     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     4  package pubsub
     6  import (
     7  	"encoding/json"
     8  	"errors"
     9  	"fmt"
    10  	"sync/atomic"
    11  	"time"
    13  	""
    14  	""
    16  	""
    17  )
    19  var (
    20  	ErrFilterNotInitialized        = errors.New("filter not initialized")
    21  	ErrAddressLimit                = errors.New("address limit exceeded")
    22  	ErrInvalidFilterParam          = errors.New("invalid bloom filter params")
    23  	ErrInvalidCommand              = errors.New("invalid command")
    24  	_                       Filter = (*connection)(nil)
    25  )
    27  type Filter interface {
    28  	Check(addr []byte) bool
    29  }
// connection is a representation of the websocket connection.
//
// Each connection is serviced by two goroutines: readPump, which decodes and
// applies client commands, and writePump, which drains the send channel onto
// the socket. Either pump exiting deactivates the connection and removes it
// from the server.
type connection struct {
	// Owning server; used for logging, for removeConnection when a pump
	// exits, and for subscribedConnections when addresses are added.
	s *Server

	// The websocket connection.
	conn *websocket.Conn

	// Buffered channel of outbound messages. Send enqueues without blocking;
	// writePump drains it, and treats a closed channel as a request to close
	// the websocket gracefully.
	send chan interface{}

	// Filter state consulted by Check and updated by the NewBloom, NewSet and
	// AddAddresses command handlers.
	fp *FilterParam

	// Non-zero while the connection accepts outbound messages. Accessed only
	// atomically (isActive / deactivate). Presumably initialized to 1 by the
	// code that constructs the connection — confirm.
	active uint32
}
    46  func (c *connection) Check(addr []byte) bool {
    47  	return c.fp.Check(addr)
    48  }
    50  func (c *connection) isActive() bool {
    51  	active := atomic.LoadUint32(&
    52  	return active != 0
    53  }
    55  func (c *connection) deactivate() {
    56  	atomic.StoreUint32(&, 0)
    57  }
    59  func (c *connection) Send(msg interface{}) bool {
    60  	if !c.isActive() {
    61  		return false
    62  	}
    63  	select {
    64  	case c.send <- msg:
    65  		return true
    66  	default:
    67  	}
    68  	return false
    69  }
    71  // readPump pumps messages from the websocket connection to the hub.
    72  //
    73  // The application runs readPump in a per-connection goroutine. The application
    74  // ensures that there is at most one reader on a connection by executing all
    75  // reads from this goroutine.
    76  func (c *connection) readPump() {
    77  	defer func() {
    78  		c.deactivate()
    79  		c.s.removeConnection(c)
    81  		// close is called by both the writePump and the readPump so one of them
    82  		// will always error
    83  		_ = c.conn.Close()
    84  	}()
    86  	c.conn.SetReadLimit(maxMessageSize)
    87  	// SetReadDeadline returns an error if the connection is corrupted
    88  	if err := c.conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
    89  		return
    90  	}
    91  	c.conn.SetPongHandler(func(string) error {
    92  		return c.conn.SetReadDeadline(time.Now().Add(pongWait))
    93  	})
    95  	for {
    96  		err := c.readMessage()
    97  		if err != nil {
    98  			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
    99  				c.s.log.Debug("unexpected close in websockets",
   100  					zap.Error(err),
   101  				)
   102  			}
   103  			break
   104  		}
   105  	}
   106  }
   108  // writePump pumps messages from the hub to the websocket connection.
   109  //
   110  // A goroutine running writePump is started for each connection. The
   111  // application ensures that there is at most one writer to a connection by
   112  // executing all writes from this goroutine.
   113  func (c *connection) writePump() {
   114  	ticker := time.NewTicker(pingPeriod)
   115  	defer func() {
   116  		c.deactivate()
   117  		ticker.Stop()
   118  		c.s.removeConnection(c)
   120  		// close is called by both the writePump and the readPump so one of them
   121  		// will always error
   122  		_ = c.conn.Close()
   123  	}()
   124  	for {
   125  		select {
   126  		case message, ok := <-c.send:
   127  			if err := c.conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil {
   128  				c.s.log.Debug("closing the connection",
   129  					zap.String("reason", "failed to set the write deadline"),
   130  					zap.Error(err),
   131  				)
   132  				return
   133  			}
   134  			if !ok {
   135  				// The hub closed the channel. Attempt to close the connection
   136  				// gracefully.
   137  				_ = c.conn.WriteMessage(websocket.CloseMessage, []byte{})
   138  				return
   139  			}
   141  			if err := c.conn.WriteJSON(message); err != nil {
   142  				return
   143  			}
   144  		case <-ticker.C:
   145  			if err := c.conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil {
   146  				c.s.log.Debug("closing the connection",
   147  					zap.String("reason", "failed to set the write deadline"),
   148  					zap.Error(err),
   149  				)
   150  				return
   151  			}
   152  			if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
   153  				return
   154  			}
   155  		}
   156  	}
   157  }
   159  func (c *connection) readMessage() error {
   160  	_, r, err := c.conn.NextReader()
   161  	if err != nil {
   162  		return err
   163  	}
   164  	cmd := &Command{}
   165  	err = json.NewDecoder(r).Decode(cmd)
   166  	if err != nil {
   167  		return err
   168  	}
   170  	switch {
   171  	case cmd.NewBloom != nil:
   172  		err = c.handleNewBloom(cmd.NewBloom)
   173  	case cmd.NewSet != nil:
   174  		c.handleNewSet(cmd.NewSet)
   175  	case cmd.AddAddresses != nil:
   176  		err = c.handleAddAddresses(cmd.AddAddresses)
   177  	default:
   178  		err = ErrInvalidCommand
   179  	}
   180  	if err != nil {
   181  		c.Send(&errorMsg{
   182  			Error: err.Error(),
   183  		})
   184  	}
   185  	return err
   186  }
   188  func (c *connection) handleNewBloom(cmd *NewBloom) error {
   189  	if !cmd.IsParamsValid() {
   190  		return ErrInvalidFilterParam
   191  	}
   192  	filter, err := bloom.New(int(cmd.MaxElements), float64(cmd.CollisionProb), MaxBytes)
   193  	if err != nil {
   194  		return fmt.Errorf("bloom filter creation failed %w", err)
   195  	}
   196  	c.fp.SetFilter(filter)
   197  	return nil
   198  }
   200  func (c *connection) handleNewSet(_ *NewSet) {
   201  	c.fp.NewSet()
   202  }
   204  func (c *connection) handleAddAddresses(cmd *AddAddresses) error {
   205  	if err := cmd.parseAddresses(); err != nil {
   206  		return fmt.Errorf("address parse failed %w", err)
   207  	}
   208  	err := c.fp.Add(cmd.addressIds...)
   209  	if err != nil {
   210  		return fmt.Errorf("address append failed %w", err)
   211  	}
   212  	c.s.subscribedConnections.Add(c)
   213  	return nil
   214  }