github.com/badrootd/celestia-core@v0.0.0-20240305091328-aa4207a4b25d/rpc/core/events.go (about)

     1  package core
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"time"
     8  
     9  	cmtpubsub "github.com/badrootd/celestia-core/libs/pubsub"
    10  	cmtquery "github.com/badrootd/celestia-core/libs/pubsub/query"
    11  	ctypes "github.com/badrootd/celestia-core/rpc/core/types"
    12  	rpctypes "github.com/badrootd/celestia-core/rpc/jsonrpc/types"
    13  )
    14  
    15  const (
    16  	// maxQueryLength is the maximum length of a query string that will be
    17  	// accepted. This is just a safety check to avoid outlandish queries.
    18  	maxQueryLength = 512
    19  )
    20  
    21  // Subscribe for events via WebSocket.
    22  // More: https://docs.cometbft.com/v0.34/rpc/#/Websocket/subscribe
    23  func Subscribe(ctx *rpctypes.Context, query string) (*ctypes.ResultSubscribe, error) {
    24  	addr := ctx.RemoteAddr()
    25  	env := GetEnvironment()
    26  
    27  	if env.EventBus.NumClients() >= env.Config.MaxSubscriptionClients {
    28  		return nil, fmt.Errorf("max_subscription_clients %d reached", env.Config.MaxSubscriptionClients)
    29  	} else if env.EventBus.NumClientSubscriptions(addr) >= env.Config.MaxSubscriptionsPerClient {
    30  		return nil, fmt.Errorf("max_subscriptions_per_client %d reached", env.Config.MaxSubscriptionsPerClient)
    31  	} else if len(query) > maxQueryLength {
    32  		return nil, errors.New("maximum query length exceeded")
    33  	}
    34  
    35  	env.Logger.Info("Subscribe to query", "remote", addr, "query", query)
    36  
    37  	q, err := cmtquery.New(query)
    38  	if err != nil {
    39  		return nil, fmt.Errorf("failed to parse query: %w", err)
    40  	}
    41  
    42  	subCtx, cancel := context.WithTimeout(ctx.Context(), SubscribeTimeout)
    43  	defer cancel()
    44  
    45  	sub, err := env.EventBus.Subscribe(subCtx, addr, q, env.Config.SubscriptionBufferSize)
    46  	if err != nil {
    47  		return nil, err
    48  	}
    49  
    50  	closeIfSlow := env.Config.CloseOnSlowClient
    51  
    52  	// Capture the current ID, since it can change in the future.
    53  	subscriptionID := ctx.JSONReq.ID
    54  	go func() {
    55  		for sub != nil {
    56  			select {
    57  			case msg := <-sub.Out():
    58  				var (
    59  					resultEvent = &ctypes.ResultEvent{Query: query, Data: msg.Data(), Events: msg.Events()}
    60  					resp        = rpctypes.NewRPCSuccessResponse(subscriptionID, resultEvent)
    61  				)
    62  				writeCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
    63  				defer cancel()
    64  				if err := ctx.WSConn.WriteRPCResponse(writeCtx, resp); err != nil {
    65  					env.Logger.Info("Can't write response (slow client)",
    66  						"to", addr, "subscriptionID", subscriptionID, "err", err)
    67  
    68  					if closeIfSlow {
    69  						var (
    70  							err  = errors.New("subscription was cancelled (reason: slow client)")
    71  							resp = rpctypes.RPCServerError(subscriptionID, err)
    72  						)
    73  						if !ctx.WSConn.TryWriteRPCResponse(resp) {
    74  							env.Logger.Info("Can't write response (slow client)",
    75  								"to", addr, "subscriptionID", subscriptionID, "err", err)
    76  						}
    77  						return
    78  					}
    79  				}
    80  			case <-sub.Cancelled():
    81  				if sub.Err() != cmtpubsub.ErrUnsubscribed {
    82  					var reason string
    83  					if sub.Err() == nil {
    84  						reason = "CometBFT exited"
    85  					} else {
    86  						reason = sub.Err().Error()
    87  					}
    88  					var (
    89  						err  = fmt.Errorf("subscription was cancelled (reason: %s)", reason)
    90  						resp = rpctypes.RPCServerError(subscriptionID, err)
    91  					)
    92  					if !ctx.WSConn.TryWriteRPCResponse(resp) {
    93  						env.Logger.Info("Can't write response (slow client)",
    94  							"to", addr, "subscriptionID", subscriptionID, "err", err)
    95  					}
    96  				}
    97  				return
    98  			}
    99  		}
   100  	}()
   101  
   102  	return &ctypes.ResultSubscribe{}, nil
   103  }
   104  
   105  // Unsubscribe from events via WebSocket.
   106  // More: https://docs.cometbft.com/v0.34/rpc/#/Websocket/unsubscribe
   107  func Unsubscribe(ctx *rpctypes.Context, query string) (*ctypes.ResultUnsubscribe, error) {
   108  	addr := ctx.RemoteAddr()
   109  	env := GetEnvironment()
   110  	env.Logger.Info("Unsubscribe from query", "remote", addr, "query", query)
   111  	q, err := cmtquery.New(query)
   112  	if err != nil {
   113  		return nil, fmt.Errorf("failed to parse query: %w", err)
   114  	}
   115  	err = env.EventBus.Unsubscribe(context.Background(), addr, q)
   116  	if err != nil {
   117  		return nil, err
   118  	}
   119  	return &ctypes.ResultUnsubscribe{}, nil
   120  }
   121  
   122  // UnsubscribeAll from all events via WebSocket.
   123  // More: https://docs.cometbft.com/v0.34/rpc/#/Websocket/unsubscribe_all
   124  func UnsubscribeAll(ctx *rpctypes.Context) (*ctypes.ResultUnsubscribe, error) {
   125  	addr := ctx.RemoteAddr()
   126  	env := GetEnvironment()
   127  	env.Logger.Info("Unsubscribe from all", "remote", addr)
   128  	err := env.EventBus.UnsubscribeAll(context.Background(), addr)
   129  	if err != nil {
   130  		return nil, err
   131  	}
   132  	return &ctypes.ResultUnsubscribe{}, nil
   133  }