github.com/vipernet-xyz/tm@v0.34.24/rpc/jsonrpc/server/http_json_handler.go (about)

     1  package server
     2  
import (
	"bytes"
	"encoding/json"
	"fmt"
	"html"
	"io"
	"net/http"
	"reflect"
	"sort"
	"strings"

	tmjson "github.com/vipernet-xyz/tm/libs/json"
	"github.com/vipernet-xyz/tm/libs/log"
	types "github.com/vipernet-xyz/tm/rpc/jsonrpc/types"
)
    16  
    17  // HTTP + JSON handler
    18  
    19  // jsonrpc calls grab the given method's function info and runs reflect.Call
    20  func makeJSONRPCHandler(funcMap map[string]*RPCFunc, logger log.Logger) http.HandlerFunc {
    21  	return func(w http.ResponseWriter, r *http.Request) {
    22  		b, err := io.ReadAll(r.Body)
    23  		if err != nil {
    24  			res := types.RPCInvalidRequestError(nil,
    25  				fmt.Errorf("error reading request body: %w", err),
    26  			)
    27  			if wErr := WriteRPCResponseHTTPError(w, http.StatusBadRequest, res); wErr != nil {
    28  				logger.Error("failed to write response", "res", res, "err", wErr)
    29  			}
    30  			return
    31  		}
    32  
    33  		// if its an empty request (like from a browser), just display a list of
    34  		// functions
    35  		if len(b) == 0 {
    36  			writeListOfEndpoints(w, r, funcMap)
    37  			return
    38  		}
    39  
    40  		// first try to unmarshal the incoming request as an array of RPC requests
    41  		var (
    42  			requests  []types.RPCRequest
    43  			responses []types.RPCResponse
    44  		)
    45  		if err := json.Unmarshal(b, &requests); err != nil {
    46  			// next, try to unmarshal as a single request
    47  			var request types.RPCRequest
    48  			if err := json.Unmarshal(b, &request); err != nil {
    49  				res := types.RPCParseError(fmt.Errorf("error unmarshaling request: %w", err))
    50  				if wErr := WriteRPCResponseHTTPError(w, http.StatusInternalServerError, res); wErr != nil {
    51  					logger.Error("failed to write response", "res", res, "err", wErr)
    52  				}
    53  				return
    54  			}
    55  			requests = []types.RPCRequest{request}
    56  		}
    57  
    58  		// Set the default response cache to true unless
    59  		// 1. Any RPC request error.
    60  		// 2. Any RPC request doesn't allow to be cached.
    61  		// 3. Any RPC request has the height argument and the value is 0 (the default).
    62  		cache := true
    63  		for _, request := range requests {
    64  			request := request
    65  
    66  			// A Notification is a Request object without an "id" member.
    67  			// The Server MUST NOT reply to a Notification, including those that are within a batch request.
    68  			if request.ID == nil {
    69  				logger.Debug(
    70  					"HTTPJSONRPC received a notification, skipping... (please send a non-empty ID if you want to call a method)",
    71  					"req", request,
    72  				)
    73  				continue
    74  			}
    75  			if len(r.URL.Path) > 1 {
    76  				responses = append(
    77  					responses,
    78  					types.RPCInvalidRequestError(request.ID, fmt.Errorf("path %s is invalid", r.URL.Path)),
    79  				)
    80  				cache = false
    81  				continue
    82  			}
    83  			rpcFunc, ok := funcMap[request.Method]
    84  			if !ok || (rpcFunc.ws) {
    85  				responses = append(responses, types.RPCMethodNotFoundError(request.ID))
    86  				cache = false
    87  				continue
    88  			}
    89  			ctx := &types.Context{JSONReq: &request, HTTPReq: r}
    90  			args := []reflect.Value{reflect.ValueOf(ctx)}
    91  			if len(request.Params) > 0 {
    92  				fnArgs, err := jsonParamsToArgs(rpcFunc, request.Params)
    93  				if err != nil {
    94  					responses = append(
    95  						responses,
    96  						types.RPCInvalidParamsError(request.ID, fmt.Errorf("error converting json params to arguments: %w", err)),
    97  					)
    98  					cache = false
    99  					continue
   100  				}
   101  				args = append(args, fnArgs...)
   102  			}
   103  
   104  			if cache && !rpcFunc.cacheableWithArgs(args) {
   105  				cache = false
   106  			}
   107  
   108  			returns := rpcFunc.f.Call(args)
   109  			result, err := unreflectResult(returns)
   110  			if err != nil {
   111  				responses = append(responses, types.RPCInternalError(request.ID, err))
   112  				continue
   113  			}
   114  			responses = append(responses, types.NewRPCSuccessResponse(request.ID, result))
   115  		}
   116  
   117  		if len(responses) > 0 {
   118  			var wErr error
   119  			if cache {
   120  				wErr = WriteCacheableRPCResponseHTTP(w, responses...)
   121  			} else {
   122  				wErr = WriteRPCResponseHTTP(w, responses...)
   123  			}
   124  			if wErr != nil {
   125  				logger.Error("failed to write responses", "res", responses, "err", wErr)
   126  			}
   127  		}
   128  	}
   129  }
   130  
   131  func handleInvalidJSONRPCPaths(next http.HandlerFunc) http.HandlerFunc {
   132  	return func(w http.ResponseWriter, r *http.Request) {
   133  		// Since the pattern "/" matches all paths not matched by other registered patterns,
   134  		//  we check whether the path is indeed "/", otherwise return a 404 error
   135  		if r.URL.Path != "/" {
   136  			http.NotFound(w, r)
   137  			return
   138  		}
   139  
   140  		next(w, r)
   141  	}
   142  }
   143  
   144  func mapParamsToArgs(
   145  	rpcFunc *RPCFunc,
   146  	params map[string]json.RawMessage,
   147  	argsOffset int,
   148  ) ([]reflect.Value, error) {
   149  	values := make([]reflect.Value, len(rpcFunc.argNames))
   150  	for i, argName := range rpcFunc.argNames {
   151  		argType := rpcFunc.args[i+argsOffset]
   152  
   153  		if p, ok := params[argName]; ok && p != nil && len(p) > 0 {
   154  			val := reflect.New(argType)
   155  			err := tmjson.Unmarshal(p, val.Interface())
   156  			if err != nil {
   157  				return nil, err
   158  			}
   159  			values[i] = val.Elem()
   160  		} else { // use default for that type
   161  			values[i] = reflect.Zero(argType)
   162  		}
   163  	}
   164  
   165  	return values, nil
   166  }
   167  
   168  func arrayParamsToArgs(
   169  	rpcFunc *RPCFunc,
   170  	params []json.RawMessage,
   171  	argsOffset int,
   172  ) ([]reflect.Value, error) {
   173  	if len(rpcFunc.argNames) != len(params) {
   174  		return nil, fmt.Errorf("expected %v parameters (%v), got %v (%v)",
   175  			len(rpcFunc.argNames), rpcFunc.argNames, len(params), params)
   176  	}
   177  
   178  	values := make([]reflect.Value, len(params))
   179  	for i, p := range params {
   180  		argType := rpcFunc.args[i+argsOffset]
   181  		val := reflect.New(argType)
   182  		err := tmjson.Unmarshal(p, val.Interface())
   183  		if err != nil {
   184  			return nil, err
   185  		}
   186  		values[i] = val.Elem()
   187  	}
   188  	return values, nil
   189  }
   190  
   191  // raw is unparsed json (from json.RawMessage) encoding either a map or an
   192  // array.
   193  //
   194  // Example:
   195  //
   196  //	rpcFunc.args = [rpctypes.Context string]
   197  //	rpcFunc.argNames = ["arg"]
   198  func jsonParamsToArgs(rpcFunc *RPCFunc, raw []byte) ([]reflect.Value, error) {
   199  	const argsOffset = 1
   200  
   201  	// TODO: Make more efficient, perhaps by checking the first character for '{' or '['?
   202  	// First, try to get the map.
   203  	var m map[string]json.RawMessage
   204  	err := json.Unmarshal(raw, &m)
   205  	if err == nil {
   206  		return mapParamsToArgs(rpcFunc, m, argsOffset)
   207  	}
   208  
   209  	// Otherwise, try an array.
   210  	var a []json.RawMessage
   211  	err = json.Unmarshal(raw, &a)
   212  	if err == nil {
   213  		return arrayParamsToArgs(rpcFunc, a, argsOffset)
   214  	}
   215  
   216  	// Otherwise, bad format, we cannot parse
   217  	return nil, fmt.Errorf("unknown type for JSON params: %v. Expected map or array", err)
   218  }
   219  
   220  // writes a list of available rpc endpoints as an html page
   221  func writeListOfEndpoints(w http.ResponseWriter, r *http.Request, funcMap map[string]*RPCFunc) {
   222  	noArgNames := []string{}
   223  	argNames := []string{}
   224  	for name, funcData := range funcMap {
   225  		if len(funcData.args) == 0 {
   226  			noArgNames = append(noArgNames, name)
   227  		} else {
   228  			argNames = append(argNames, name)
   229  		}
   230  	}
   231  	sort.Strings(noArgNames)
   232  	sort.Strings(argNames)
   233  	buf := new(bytes.Buffer)
   234  	buf.WriteString("<html><body>")
   235  	buf.WriteString("<br>Available endpoints:<br>")
   236  
   237  	for _, name := range noArgNames {
   238  		link := fmt.Sprintf("//%s/%s", r.Host, name)
   239  		buf.WriteString(fmt.Sprintf("<a href=\"%s\">%s</a></br>", link, link))
   240  	}
   241  
   242  	buf.WriteString("<br>Endpoints that require arguments:<br>")
   243  	for _, name := range argNames {
   244  		link := fmt.Sprintf("//%s/%s?", r.Host, name)
   245  		funcData := funcMap[name]
   246  		for i, argName := range funcData.argNames {
   247  			link += argName + "=_"
   248  			if i < len(funcData.argNames)-1 {
   249  				link += "&"
   250  			}
   251  		}
   252  		buf.WriteString(fmt.Sprintf("<a href=\"%s\">%s</a></br>", link, link))
   253  	}
   254  	buf.WriteString("</body></html>")
   255  	w.Header().Set("Content-Type", "text/html")
   256  	w.WriteHeader(200)
   257  	w.Write(buf.Bytes()) //nolint: errcheck
   258  }