github.com/klaytn/klaytn@v1.12.1/networks/rpc/websocket.go (about)

     1  // Modifications Copyright 2018 The klaytn Authors
     2  // Copyright 2015 The go-ethereum Authors
     3  // This file is part of the go-ethereum library.
     4  //
     5  // The go-ethereum library is free software: you can redistribute it and/or modify
     6  // it under the terms of the GNU Lesser General Public License as published by
     7  // the Free Software Foundation, either version 3 of the License, or
     8  // (at your option) any later version.
     9  //
    10  // The go-ethereum library is distributed in the hope that it will be useful,
    11  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    12  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    13  // GNU Lesser General Public License for more details.
    14  //
    15  // You should have received a copy of the GNU Lesser General Public License
    16  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    17  //
    18  // This file is derived from rpc/websocket.go (2018/06/04).
    19  // Modified and improved for the klaytn development.
    20  
    21  package rpc
    22  
    23  import (
    24  	"bufio"
    25  	"bytes"
    26  	"context"
    27  	"encoding/base64"
    28  	"encoding/json"
    29  	"fmt"
    30  	"net/http"
    31  	"net/url"
    32  	"os"
    33  	"strings"
    34  	"sync"
    35  	"sync/atomic"
    36  	"time"
    37  
    38  	fastws "github.com/clevergo/websocket"
    39  	mapset "github.com/deckarep/golang-set"
    40  	"github.com/gorilla/websocket"
    41  	"github.com/klaytn/klaytn/common"
    42  	"github.com/valyala/fasthttp"
    43  )
    44  
    45  const (
    46  	wsReadBuffer  = 1024
    47  	wsWriteBuffer = 1024
    48  )
    49  
    50  var wsBufferPool = new(sync.Pool)
    51  
    52  func newWebsocketCodec(conn *websocket.Conn) ServerCodec {
    53  	conn.SetReadLimit(int64(common.MaxRequestContentLength))
    54  	if WebsocketReadDeadline != 0 {
    55  		conn.SetReadDeadline(time.Now().Add(time.Duration(WebsocketReadDeadline) * time.Second))
    56  	}
    57  	if WebsocketWriteDeadline != 0 {
    58  		conn.SetWriteDeadline(time.Now().Add(time.Duration(WebsocketWriteDeadline) * time.Second))
    59  	}
    60  	return NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON)
    61  }
    62  
    63  // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections.
    64  //
    65  // allowedOrigins should be a comma-separated list of allowed origin URLs.
    66  // To allow connections with any origin, pass "*".
    67  func (srv *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
    68  	upgrader := websocket.Upgrader{
    69  		ReadBufferSize:  wsReadBuffer,
    70  		WriteBufferSize: wsWriteBuffer,
    71  		WriteBufferPool: wsBufferPool,
    72  		CheckOrigin:     wsHandshakeValidator(allowedOrigins),
    73  	}
    74  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    75  		if atomic.LoadInt32(&srv.wsConnCount) >= MaxWebsocketConnections {
    76  			return
    77  		}
    78  		atomic.AddInt32(&srv.wsConnCount, 1)
    79  		wsConnCounter.Inc(1)
    80  		defer func() {
    81  			atomic.AddInt32(&srv.wsConnCount, -1)
    82  			wsConnCounter.Dec(1)
    83  		}()
    84  		conn, err := upgrader.Upgrade(w, r, nil)
    85  		if err != nil {
    86  			return
    87  		}
    88  		codec := newWebsocketCodec(conn)
    89  		srv.ServeCodec(codec, 0)
    90  	})
    91  }
    92  
    93  var upgrader = fastws.Upgrader{
    94  	ReadBufferSize:  1024,
    95  	WriteBufferSize: 1024,
    96  }
    97  
    98  func (srv *Server) FastWebsocketHandler(ctx *fasthttp.RequestCtx) {
    99  	// TODO-Klaytn handle websocket protocol
   100  	protocol := ctx.Request.Header.Peek("Sec-WebSocket-Protocol")
   101  	if protocol != nil {
   102  		ctx.Response.Header.Set("Sec-WebSocket-Protocol", string(protocol))
   103  	}
   104  
   105  	err := upgrader.Upgrade(ctx, func(conn *fastws.Conn) {
   106  		if atomic.LoadInt32(&srv.wsConnCount) >= MaxWebsocketConnections {
   107  			return
   108  		}
   109  		atomic.AddInt32(&srv.wsConnCount, 1)
   110  		wsConnCounter.Inc(1)
   111  		defer func() {
   112  			atomic.AddInt32(&srv.wsConnCount, -1)
   113  			wsConnCounter.Dec(1)
   114  		}()
   115  		if WebsocketReadDeadline != 0 {
   116  			conn.SetReadDeadline(time.Now().Add(time.Duration(WebsocketReadDeadline) * time.Second))
   117  		}
   118  		if WebsocketWriteDeadline != 0 {
   119  			conn.SetWriteDeadline(time.Now().Add(time.Duration(WebsocketWriteDeadline) * time.Second))
   120  		}
   121  		// Create a custom encode/decode pair to enforce payload size and number encoding
   122  		encoder := func(v interface{}) error {
   123  			msg, err := json.Marshal(v)
   124  			if err != nil {
   125  				return err
   126  			}
   127  			err = conn.WriteMessage(websocket.TextMessage, msg)
   128  			if err != nil {
   129  				return err
   130  			}
   131  			return err
   132  		}
   133  		decoder := func(v interface{}) error {
   134  			_, data, err := conn.ReadMessage()
   135  			if err != nil {
   136  				return err
   137  			}
   138  			dec := json.NewDecoder(bytes.NewReader(data))
   139  			dec.UseNumber()
   140  			return dec.Decode(v)
   141  		}
   142  
   143  		reader := bufio.NewReaderSize(bytes.NewReader(ctx.Request.Body()), common.MaxRequestContentLength)
   144  		srv.ServeCodec(NewFuncCodec(&httpReadWriteNopCloser{reader, ctx.Response.BodyWriter()}, encoder, decoder), 0)
   145  	})
   146  	if err != nil {
   147  		logger.Error("FastWebsocketHandler fail to upgrade message", "err", err)
   148  		return
   149  	}
   150  }
   151  
   152  // NewWSServer creates a new websocket RPC server around an API provider.
   153  //
   154  // Deprecated: use Server.WebsocketHandler
   155  func NewWSServer(allowedOrigins []string, srv *Server) *http.Server {
   156  	return &http.Server{
   157  		Handler: srv.WebsocketHandler(allowedOrigins),
   158  	}
   159  }
   160  
   161  func NewFastWSServer(allowedOrigins []string, srv *Server) *fasthttp.Server {
   162  	upgrader.CheckOrigin = wsFastHandshakeValidator(allowedOrigins)
   163  
   164  	// TODO-Klaytn concurreny default (256 * 1024), goroutine limit (8192)
   165  	return &fasthttp.Server{
   166  		Concurrency:        ConcurrencyLimit,
   167  		MaxRequestBodySize: common.MaxRequestContentLength,
   168  		Handler:            srv.FastWebsocketHandler,
   169  	}
   170  }
   171  
   172  func wsFastHandshakeValidator(allowedOrigins []string) func(ctx *fasthttp.RequestCtx) bool {
   173  	origins := mapset.NewSet()
   174  	allowAllOrigins := false
   175  
   176  	for _, origin := range allowedOrigins {
   177  		if origin == "*" {
   178  			allowAllOrigins = true
   179  		}
   180  		if origin != "" {
   181  			origins.Add(strings.ToLower(origin))
   182  		}
   183  	}
   184  
   185  	// allow localhost if no allowedOrigins are specified.
   186  	if len(origins.ToSlice()) == 0 {
   187  		origins.Add("http://localhost")
   188  		if hostname, err := os.Hostname(); err == nil {
   189  			origins.Add("http://" + strings.ToLower(hostname))
   190  		}
   191  	}
   192  
   193  	logger.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v\n", origins.ToSlice()))
   194  
   195  	f := func(ctx *fasthttp.RequestCtx) bool {
   196  		// Skip origin verification if no Origin header is present. The origin check
   197  		// is supposed to protect against browser based attacks. Browsers always set
   198  		// Origin. Non-browser software can put anything in origin and checking it doesn't
   199  		// provide additional security.
   200  
   201  		origin := strings.ToLower(string(ctx.Request.Header.Peek("Origin")))
   202  		if allowAllOrigins || origins.Contains(origin) || origin == "" {
   203  			return true
   204  		}
   205  		logger.Warn(fmt.Sprintf("origin '%s' not allowed on WS-RPC interface\n", origin))
   206  		return false
   207  	}
   208  
   209  	return f
   210  }
   211  
   212  // wsHandshakeValidator returns a handler that verifies the origin during the
   213  // websocket upgrade process. When a '*' is specified as an allowed origins all
   214  // connections are accepted.
   215  func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool {
   216  	origins := mapset.NewSet()
   217  	allowAllOrigins := false
   218  
   219  	for _, origin := range allowedOrigins {
   220  		if origin == "*" {
   221  			allowAllOrigins = true
   222  		}
   223  		if origin != "" {
   224  			origins.Add(strings.ToLower(origin))
   225  		}
   226  	}
   227  
   228  	// allow localhost if no allowedOrigins are specified.
   229  	if len(origins.ToSlice()) == 0 {
   230  		origins.Add("http://localhost")
   231  		if hostname, err := os.Hostname(); err == nil {
   232  			origins.Add("http://" + strings.ToLower(hostname))
   233  		}
   234  	}
   235  	f := func(req *http.Request) bool {
   236  		// Skip origin verification if no Origin header is present. The origin check
   237  		// is supposed to protect against browser based attacks. Browsers always set
   238  		// Origin. Non-browser software can put anything in origin and checking it doesn't
   239  		// provide additional security.
   240  		if _, ok := req.Header["Origin"]; !ok {
   241  			return true
   242  		}
   243  		// Verify origin against whitelist.
   244  		origin := strings.ToLower(req.Header.Get("Origin"))
   245  		if allowAllOrigins || origins.Contains(origin) {
   246  			return true
   247  		}
   248  
   249  		return false
   250  	}
   251  
   252  	return f
   253  }
   254  
   255  type wsHandshakeError struct {
   256  	err    error
   257  	status string
   258  }
   259  
   260  func (e wsHandshakeError) Error() string {
   261  	s := e.err.Error()
   262  	if e.status != "" {
   263  		s += " (HTTP status " + e.status + ")"
   264  	}
   265  	return s
   266  }
   267  
   268  // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server
   269  // that is listening on the given endpoint.
   270  //
   271  // The context is used for the initial connection establishment. It does not
   272  // affect subsequent interactions with the client.
   273  func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) {
   274  	endpoint, header, err := wsClientHeaders(endpoint, origin)
   275  	if err != nil {
   276  		return nil, err
   277  	}
   278  
   279  	dialer := websocket.Dialer{
   280  		ReadBufferSize:  wsReadBuffer,
   281  		WriteBufferSize: wsWriteBuffer,
   282  		WriteBufferPool: wsBufferPool,
   283  	}
   284  
   285  	return NewClient(ctx, func(ctx context.Context) (ServerCodec, error) {
   286  		conn, resp, err := dialer.DialContext(ctx, endpoint, header)
   287  		if resp != nil && resp.Body != nil {
   288  			defer resp.Body.Close()
   289  		}
   290  
   291  		if err != nil {
   292  			hErr := wsHandshakeError{err: err}
   293  			if resp != nil {
   294  				hErr.status = resp.Status
   295  			}
   296  			return nil, hErr
   297  		}
   298  		return newWebsocketCodec(conn), nil
   299  	})
   300  }
   301  
   302  func wsClientHeaders(endpoint, origin string) (string, http.Header, error) {
   303  	endpointURL, err := url.Parse(endpoint)
   304  	if err != nil {
   305  		return endpoint, nil, err
   306  	}
   307  
   308  	header := make(http.Header)
   309  
   310  	if origin != "" {
   311  		header.Add("origin", origin)
   312  	}
   313  
   314  	if endpointURL.User != nil {
   315  		b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String()))
   316  		header.Add("authorization", "Basic "+b64auth)
   317  		endpointURL.User = nil
   318  	}
   319  	return endpointURL.String(), header, nil
   320  }