github.com/aidoskuneen/adk-node@v0.0.0-20220315131952-2e32567cb7f4/rpc/websocket.go (about)

     1  // Copyright 2021 The adkgo Authors
     2  // This file is part of the adkgo library (adapted for adkgo from go--ethereum v1.10.8).
     3  //
     4  // the adkgo library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // the adkgo library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the adkgo library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package rpc
    18  
    19  import (
    20  	"context"
    21  	"encoding/base64"
    22  	"fmt"
    23  	"net/http"
    24  	"net/url"
    25  	"os"
    26  	"strings"
    27  	"sync"
    28  	"time"
    29  
    30  	mapset "github.com/deckarep/golang-set"
    31  	"github.com/aidoskuneen/adk-node/log"
    32  	"github.com/gorilla/websocket"
    33  )
    34  
const (
	// wsReadBuffer is passed as ReadBufferSize to the websocket Upgrader and Dialer.
	wsReadBuffer       = 1024
	// wsWriteBuffer is passed as WriteBufferSize to the websocket Upgrader and Dialer.
	wsWriteBuffer      = 1024
	// wsPingInterval is how long pingLoop waits without a successful write
	// before it sends a ping frame.
	wsPingInterval     = 60 * time.Second
	// wsPingWriteTimeout bounds the write deadline set before each ping frame.
	wsPingWriteTimeout = 5 * time.Second
	// wsMessageSizeLimit is handed to conn.SetReadLimit for every new codec.
	wsMessageSizeLimit = 15 * 1024 * 1024
)

// wsBufferPool is shared as the WriteBufferPool of every Upgrader and Dialer
// created in this file.
var wsBufferPool = new(sync.Pool)
    44  
    45  // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
    46  //
    47  // allowedOrigins should be a comma-separated list of allowed origin URLs.
    48  // To allow connections with any origin, pass "*".
    49  func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
    50  	var upgrader = websocket.Upgrader{
    51  		ReadBufferSize:  wsReadBuffer,
    52  		WriteBufferSize: wsWriteBuffer,
    53  		WriteBufferPool: wsBufferPool,
    54  		CheckOrigin:     wsHandshakeValidator(allowedOrigins),
    55  	}
    56  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    57  		conn, err := upgrader.Upgrade(w, r, nil)
    58  		if err != nil {
    59  			log.Debug("WebSocket upgrade failed", "err", err)
    60  			return
    61  		}
    62  		codec := newWebsocketCodec(conn)
    63  		s.ServeCodec(codec, 0)
    64  	})
    65  }
    66  
    67  // wsHandshakeValidator returns a handler that verifies the origin during the
    68  // websocket upgrade process. When a '*' is specified as an allowed origins all
    69  // connections are accepted.
    70  func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool {
    71  	origins := mapset.NewSet()
    72  	allowAllOrigins := false
    73  
    74  	for _, origin := range allowedOrigins {
    75  		if origin == "*" {
    76  			allowAllOrigins = true
    77  		}
    78  		if origin != "" {
    79  			origins.Add(origin)
    80  		}
    81  	}
    82  	// allow localhost if no allowedOrigins are specified.
    83  	if len(origins.ToSlice()) == 0 {
    84  		origins.Add("http://localhost")
    85  		if hostname, err := os.Hostname(); err == nil {
    86  			origins.Add("http://" + hostname)
    87  		}
    88  	}
    89  	log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice()))
    90  
    91  	f := func(req *http.Request) bool {
    92  		// Skip origin verification if no Origin header is present. The origin check
    93  		// is supposed to protect against browser based attacks. Browsers always set
    94  		// Origin. Non-browser software can put anything in origin and checking it doesn't
    95  		// provide additional security.
    96  		if _, ok := req.Header["Origin"]; !ok {
    97  			return true
    98  		}
    99  		// Verify origin against allow list.
   100  		origin := strings.ToLower(req.Header.Get("Origin"))
   101  		if allowAllOrigins || originIsAllowed(origins, origin) {
   102  			return true
   103  		}
   104  		log.Warn("Rejected WebSocket connection", "origin", origin)
   105  		return false
   106  	}
   107  
   108  	return f
   109  }
   110  
   111  type wsHandshakeError struct {
   112  	err    error
   113  	status string
   114  }
   115  
   116  func (e wsHandshakeError) Error() string {
   117  	s := e.err.Error()
   118  	if e.status != "" {
   119  		s += " (HTTP status " + e.status + ")"
   120  	}
   121  	return s
   122  }
   123  
   124  func originIsAllowed(allowedOrigins mapset.Set, browserOrigin string) bool {
   125  	it := allowedOrigins.Iterator()
   126  	for origin := range it.C {
   127  		if ruleAllowsOrigin(origin.(string), browserOrigin) {
   128  			return true
   129  		}
   130  	}
   131  	return false
   132  }
   133  
   134  func ruleAllowsOrigin(allowedOrigin string, browserOrigin string) bool {
   135  	var (
   136  		allowedScheme, allowedHostname, allowedPort string
   137  		browserScheme, browserHostname, browserPort string
   138  		err                                         error
   139  	)
   140  	allowedScheme, allowedHostname, allowedPort, err = parseOriginURL(allowedOrigin)
   141  	if err != nil {
   142  		log.Warn("Error parsing allowed origin specification", "spec", allowedOrigin, "error", err)
   143  		return false
   144  	}
   145  	browserScheme, browserHostname, browserPort, err = parseOriginURL(browserOrigin)
   146  	if err != nil {
   147  		log.Warn("Error parsing browser 'Origin' field", "Origin", browserOrigin, "error", err)
   148  		return false
   149  	}
   150  	if allowedScheme != "" && allowedScheme != browserScheme {
   151  		return false
   152  	}
   153  	if allowedHostname != "" && allowedHostname != browserHostname {
   154  		return false
   155  	}
   156  	if allowedPort != "" && allowedPort != browserPort {
   157  		return false
   158  	}
   159  	return true
   160  }
   161  
   162  func parseOriginURL(origin string) (string, string, string, error) {
   163  	parsedURL, err := url.Parse(strings.ToLower(origin))
   164  	if err != nil {
   165  		return "", "", "", err
   166  	}
   167  	var scheme, hostname, port string
   168  	if strings.Contains(origin, "://") {
   169  		scheme = parsedURL.Scheme
   170  		hostname = parsedURL.Hostname()
   171  		port = parsedURL.Port()
   172  	} else {
   173  		scheme = ""
   174  		hostname = parsedURL.Scheme
   175  		port = parsedURL.Opaque
   176  		if hostname == "" {
   177  			hostname = origin
   178  		}
   179  	}
   180  	return scheme, hostname, port, nil
   181  }
   182  
   183  // DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server
   184  // that is listening on the given endpoint using the provided dialer.
   185  func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) {
   186  	endpoint, header, err := wsClientHeaders(endpoint, origin)
   187  	if err != nil {
   188  		return nil, err
   189  	}
   190  	return newClient(ctx, func(ctx context.Context) (ServerCodec, error) {
   191  		conn, resp, err := dialer.DialContext(ctx, endpoint, header)
   192  		if err != nil {
   193  			hErr := wsHandshakeError{err: err}
   194  			if resp != nil {
   195  				hErr.status = resp.Status
   196  			}
   197  			return nil, hErr
   198  		}
   199  		return newWebsocketCodec(conn), nil
   200  	})
   201  }
   202  
   203  // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server
   204  // that is listening on the given endpoint.
   205  //
   206  // The context is used for the initial connection establishment. It does not
   207  // affect subsequent interactions with the client.
   208  func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) {
   209  	dialer := websocket.Dialer{
   210  		ReadBufferSize:  wsReadBuffer,
   211  		WriteBufferSize: wsWriteBuffer,
   212  		WriteBufferPool: wsBufferPool,
   213  	}
   214  	return DialWebsocketWithDialer(ctx, endpoint, origin, dialer)
   215  }
   216  
   217  func wsClientHeaders(endpoint, origin string) (string, http.Header, error) {
   218  	endpointURL, err := url.Parse(endpoint)
   219  	if err != nil {
   220  		return endpoint, nil, err
   221  	}
   222  	header := make(http.Header)
   223  	if origin != "" {
   224  		header.Add("origin", origin)
   225  	}
   226  	if endpointURL.User != nil {
   227  		b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String()))
   228  		header.Add("authorization", "Basic "+b64auth)
   229  		endpointURL.User = nil
   230  	}
   231  	return endpointURL.String(), header, nil
   232  }
   233  
// websocketCodec is the codec returned by newWebsocketCodec. It embeds a
// jsonCodec built on the connection's ReadJSON/WriteJSON and adds a background
// ping loop.
type websocketCodec struct {
	*jsonCodec
	// conn is the underlying websocket connection; pingLoop writes ping
	// frames to it directly.
	conn *websocket.Conn

	// wg tracks the pingLoop goroutine; close waits on it.
	wg        sync.WaitGroup
	// pingReset is signalled (non-blocking, capacity 1) by writeJSON after a
	// successful write so that pingLoop restarts its idle timer.
	pingReset chan struct{}
}
   241  
   242  func newWebsocketCodec(conn *websocket.Conn) ServerCodec {
   243  	conn.SetReadLimit(wsMessageSizeLimit)
   244  	wc := &websocketCodec{
   245  		jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec),
   246  		conn:      conn,
   247  		pingReset: make(chan struct{}, 1),
   248  	}
   249  	wc.wg.Add(1)
   250  	go wc.pingLoop()
   251  	return wc
   252  }
   253  
// close closes the embedded JSON codec and then blocks until the goroutine
// registered on wc.wg (the ping loop started by newWebsocketCodec) has
// finished.
func (wc *websocketCodec) close() {
	// Close first: pingLoop returns when wc.closed() fires, which is
	// presumably triggered by closing the JSON codec (confirm in jsonCodec).
	wc.jsonCodec.close()
	wc.wg.Wait()
}
   258  
   259  func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}) error {
   260  	err := wc.jsonCodec.writeJSON(ctx, v)
   261  	if err == nil {
   262  		// Notify pingLoop to delay the next idle ping.
   263  		select {
   264  		case wc.pingReset <- struct{}{}:
   265  		default:
   266  		}
   267  	}
   268  	return err
   269  }
   270  
// pingLoop sends periodic ping frames when the connection is idle.
//
// It runs until wc.closed() fires. Each signal on wc.pingReset restarts the
// idle timer; when the timer expires a ping frame is written and the timer is
// armed again. Errors from setting the deadline or writing the ping are
// ignored here.
func (wc *websocketCodec) pingLoop() {
	var timer = time.NewTimer(wsPingInterval)
	defer wc.wg.Done()
	defer timer.Stop()

	for {
		select {
		case <-wc.closed():
			return
		case <-wc.pingReset:
			// A write just succeeded. Stop reports false if the timer has
			// already fired; in that case drain the pending value so the
			// Reset below starts from an empty channel.
			if !timer.Stop() {
				<-timer.C
			}
			timer.Reset(wsPingInterval)
		case <-timer.C:
			// Idle for a full interval: send a ping while holding encMu
			// (presumably the lock that also guards JSON writes; confirm in
			// jsonCodec) and bound the write with wsPingWriteTimeout.
			wc.jsonCodec.encMu.Lock()
			wc.conn.SetWriteDeadline(time.Now().Add(wsPingWriteTimeout))
			wc.conn.WriteMessage(websocket.PingMessage, nil)
			wc.jsonCodec.encMu.Unlock()
			timer.Reset(wsPingInterval)
		}
	}
}
   294  }