github.com/iotexproject/iotex-core@v1.14.1-rc1/api/websocket.go (about)

     1  package api
     2  
     3  import (
     4  	"context"
     5  	"math"
     6  	"net/http"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/gorilla/websocket"
    11  	"go.uber.org/zap"
    12  	"golang.org/x/time/rate"
    13  
    14  	apitypes "github.com/iotexproject/iotex-core/api/types"
    15  	"github.com/iotexproject/iotex-core/pkg/log"
    16  )
    17  
    18  const (
    19  	// Time allowed to write a message to the peer.
    20  	writeWait = 10 * time.Second
    21  
    22  	// Time allowed to read the next pong message from the peer.
    23  	pongWait = 60 * time.Second
    24  
    25  	// Send pings to peer with this period. Must be less than pongWait.
    26  	pingPeriod = (pongWait * 9) / 10
    27  
    28  	// Maximum message size allowed from peer.
    29  	maxMessageSize = 15 * 1024 * 1024
    30  )
    31  
// WebsocketHandler handles requests from websocket protocol.
// Each message read from a connection is passed to a Web3Handler, and the
// handler's responses are written back over the same connection.
type WebsocketHandler struct {
	// msgHandler processes every inbound message (via HandlePOSTReq).
	msgHandler Web3Handler
	// limiter paces message reads. A single limiter is shared by every
	// connection served by this handler; NewWebsocketHandler never leaves it nil.
	limiter    *rate.Limiter
}
    37  
    38  var upgrader = websocket.Upgrader{
    39  	ReadBufferSize:  1024,
    40  	WriteBufferSize: 1024,
    41  }
    42  
    43  // type safeWebsocketConn wraps websocket.Conn with a mutex
    44  // to avoid concurrent write to the connection
    45  // https://pkg.go.dev/github.com/gorilla/websocket#hdr-Concurrency
    46  type safeWebsocketConn struct {
    47  	ws *websocket.Conn
    48  	mu sync.Mutex
    49  }
    50  
    51  // WiteJSON writes a JSON message to the connection in a thread-safe way
    52  func (c *safeWebsocketConn) WriteJSON(message interface{}) error {
    53  	c.mu.Lock()
    54  	defer c.mu.Unlock()
    55  	return c.ws.WriteJSON(message)
    56  }
    57  
    58  // WriteMessage writes a message to the connection in a thread-safe way
    59  func (c *safeWebsocketConn) WriteMessage(messageType int, data []byte) error {
    60  	c.mu.Lock()
    61  	defer c.mu.Unlock()
    62  	return c.ws.WriteMessage(messageType, data)
    63  }
    64  
    65  // Close closes the underlying network connection without sending or waiting for a close frame
    66  func (c *safeWebsocketConn) Close() error {
    67  	return c.ws.Close()
    68  }
    69  
    70  // SetWriteDeadline sets the write deadline on the underlying network connection
    71  func (c *safeWebsocketConn) SetWriteDeadline(t time.Time) error {
    72  	c.mu.Lock()
    73  	defer c.mu.Unlock()
    74  	return c.ws.SetWriteDeadline(t)
    75  }
    76  
    77  // NewWebsocketHandler creates a new websocket handler
    78  func NewWebsocketHandler(web3Handler Web3Handler, limiter *rate.Limiter) *WebsocketHandler {
    79  	if limiter == nil {
    80  		// set the limiter to the maximum possible rate
    81  		limiter = rate.NewLimiter(rate.Limit(math.MaxFloat64), 1)
    82  	}
    83  	return &WebsocketHandler{
    84  		msgHandler: web3Handler,
    85  		limiter:    limiter,
    86  	}
    87  }
    88  
    89  func (wsSvr *WebsocketHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
    90  	upgrader.CheckOrigin = func(_ *http.Request) bool { return true }
    91  
    92  	// upgrade this connection to a WebSocket connection
    93  	ws, err := upgrader.Upgrade(w, req, nil)
    94  	if err != nil {
    95  		log.Logger("api").Warn("failed to upgrade http server to websocket", zap.Error(err))
    96  		return
    97  	}
    98  
    99  	wsSvr.handleConnection(req.Context(), ws)
   100  }
   101  
   102  func (wsSvr *WebsocketHandler) handleConnection(ctx context.Context, ws *websocket.Conn) {
   103  	defer ws.Close()
   104  	if err := ws.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
   105  		log.Logger("api").Warn("failed to set read deadline timeout.", zap.Error(err))
   106  	}
   107  	ws.SetReadLimit(maxMessageSize)
   108  	ws.SetPongHandler(func(string) error {
   109  		if err := ws.SetReadDeadline(time.Now().Add(pongWait)); err != nil {
   110  			log.Logger("api").Warn("failed to set read deadline timeout.", zap.Error(err))
   111  		}
   112  		return nil
   113  	})
   114  
   115  	ctx, cancel := context.WithCancel(ctx)
   116  	safeWs := &safeWebsocketConn{ws: ws}
   117  	go ping(ctx, safeWs, cancel)
   118  
   119  	for {
   120  		select {
   121  		case <-ctx.Done():
   122  			return
   123  		default:
   124  			if err := wsSvr.limiter.Wait(ctx); err != nil {
   125  				cancel()
   126  				return
   127  			}
   128  			_, reader, err := ws.NextReader()
   129  			if err != nil {
   130  				log.Logger("api").Debug("Client Disconnected", zap.Error(err))
   131  				cancel()
   132  				return
   133  			}
   134  
   135  			err = wsSvr.msgHandler.HandlePOSTReq(ctx, reader,
   136  				apitypes.NewResponseWriter(
   137  					func(resp interface{}) (int, error) {
   138  						if err = safeWs.SetWriteDeadline(time.Now().Add(writeWait)); err != nil {
   139  							log.Logger("api").Warn("failed to set write deadline timeout.", zap.Error(err))
   140  						}
   141  						return 0, safeWs.WriteJSON(resp)
   142  					}),
   143  			)
   144  			if err != nil {
   145  				log.Logger("api").Warn("fail to respond request.", zap.Error(err))
   146  				cancel()
   147  				return
   148  			}
   149  		}
   150  	}
   151  }
   152  
   153  func ping(ctx context.Context, ws *safeWebsocketConn, cancel context.CancelFunc) {
   154  	pingTicker := time.NewTicker(pingPeriod)
   155  	defer func() {
   156  		pingTicker.Stop()
   157  		if err := ws.Close(); err != nil {
   158  			log.Logger("api").Warn("fail to close websocket connection.", zap.Error(err))
   159  		}
   160  	}()
   161  
   162  	for {
   163  		select {
   164  		case <-ctx.Done():
   165  			return
   166  		case <-pingTicker.C:
   167  			if err := ws.SetWriteDeadline(time.Now().Add(writeWait)); err != nil {
   168  				log.Logger("api").Warn("failed to set write deadline timeout.", zap.Error(err))
   169  			}
   170  			if err := ws.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
   171  				log.Logger("api").Warn("fail to respond request.", zap.Error(err))
   172  				cancel()
   173  				return
   174  			}
   175  		}
   176  	}
   177  }