github.com/Finschia/ostracon@v1.1.5/rpc/jsonrpc/server/http_json_handler.go (about)

     1  package server
     2  
import (
	"bytes"
	"encoding/json"
	"fmt"
	"html"
	"io"
	"net/http"
	"reflect"
	"sort"
	"strconv"

	tmjson "github.com/Finschia/ostracon/libs/json"
	"github.com/Finschia/ostracon/libs/log"
	types "github.com/Finschia/ostracon/rpc/jsonrpc/types"
)
    17  
    18  // HTTP + JSON handler
    19  
    20  // jsonrpc calls grab the given method's function info and runs reflect.Call
    21  func makeJSONRPCHandler(funcMap map[string]*RPCFunc, logger log.Logger) http.HandlerFunc {
    22  	return func(w http.ResponseWriter, r *http.Request) {
    23  		b, err := io.ReadAll(r.Body)
    24  		if err != nil {
    25  			res := types.RPCInvalidRequestError(nil,
    26  				fmt.Errorf("error reading request body: %w", err),
    27  			)
    28  			if wErr := WriteRPCResponseHTTPError(w, http.StatusBadRequest, res); wErr != nil {
    29  				logger.Error("failed to write response", "res", res, "err", wErr)
    30  			}
    31  			return
    32  		}
    33  
    34  		// if its an empty request (like from a browser), just display a list of
    35  		// functions
    36  		if len(b) == 0 {
    37  			writeListOfEndpoints(w, r, funcMap)
    38  			return
    39  		}
    40  
    41  		// first try to unmarshal the incoming request as an array of RPC requests
    42  		var (
    43  			requests  []types.RPCRequest
    44  			responses []types.RPCResponse
    45  		)
    46  		if err := json.Unmarshal(b, &requests); err != nil {
    47  			// next, try to unmarshal as a single request
    48  			var request types.RPCRequest
    49  			if err := json.Unmarshal(b, &request); err != nil {
    50  				res := types.RPCParseError(fmt.Errorf("error unmarshaling request: %w", err))
    51  				if wErr := WriteRPCResponseHTTPError(w, http.StatusInternalServerError, res); wErr != nil {
    52  					logger.Error("failed to write response", "res", res, "err", wErr)
    53  				}
    54  				return
    55  			}
    56  			requests = []types.RPCRequest{request}
    57  		}
    58  		// read the Max-Batch-Request-Num from header
    59  		maxBatchRequestNum, err := strconv.Atoi(r.Header.Get("Max-Batch-Request-Num"))
    60  		if err != nil {
    61  			res := types.RPCInvalidRequestError(nil,
    62  				fmt.Errorf("error reading request header key"),
    63  			)
    64  			if wErr := WriteRPCResponseHTTPError(w, http.StatusInternalServerError, res); wErr != nil {
    65  				logger.Error("failed to write response", "res", res, "err", wErr)
    66  			}
    67  			return
    68  		}
    69  
    70  		// if the number of requests in the batch exceeds the max_batch_request_num
    71  		// return a invalid request error
    72  		if len(requests) > maxBatchRequestNum {
    73  			res := types.RPCInvalidRequestError(nil,
    74  				fmt.Errorf("too many requests in a request batch, current is %d, where the upper limit is %d", len(requests), maxBatchRequestNum),
    75  			)
    76  			if wErr := WriteRPCResponseHTTPError(w, http.StatusBadRequest, res); wErr != nil {
    77  				logger.Error("failed to write response", "res", res, "err", wErr)
    78  			}
    79  			return
    80  		}
    81  
    82  		// Set the default response cache to true unless
    83  		// 1. Any RPC request error.
    84  		// 2. Any RPC request doesn't allow to be cached.
    85  		// 3. Any RPC request has the height argument and the value is 0 (the default).
    86  		cache := true
    87  		for _, request := range requests {
    88  			request := request
    89  
    90  			// A Notification is a Request object without an "id" member.
    91  			// The Server MUST NOT reply to a Notification, including those that are within a batch request.
    92  			if request.ID == nil {
    93  				logger.Debug(
    94  					"HTTPJSONRPC received a notification, skipping... (please send a non-empty ID if you want to call a method)",
    95  					"req", request,
    96  				)
    97  				continue
    98  			}
    99  			if len(r.URL.Path) > 1 {
   100  				responses = append(
   101  					responses,
   102  					types.RPCInvalidRequestError(request.ID, fmt.Errorf("path %s is invalid", r.URL.Path)),
   103  				)
   104  				cache = false
   105  				continue
   106  			}
   107  			rpcFunc, ok := funcMap[request.Method]
   108  			if !ok || (rpcFunc.ws) {
   109  				responses = append(responses, types.RPCMethodNotFoundError(request.ID))
   110  				cache = false
   111  				continue
   112  			}
   113  			ctx := &types.Context{JSONReq: &request, HTTPReq: r}
   114  			args := []reflect.Value{reflect.ValueOf(ctx)}
   115  			if len(request.Params) > 0 {
   116  				fnArgs, err := jsonParamsToArgs(rpcFunc, request.Params)
   117  				if err != nil {
   118  					responses = append(
   119  						responses,
   120  						types.RPCInvalidParamsError(request.ID, fmt.Errorf("error converting json params to arguments: %w", err)),
   121  					)
   122  					cache = false
   123  					continue
   124  				}
   125  				args = append(args, fnArgs...)
   126  			}
   127  
   128  			if cache && !rpcFunc.cacheableWithArgs(args) {
   129  				cache = false
   130  			}
   131  
   132  			returns := rpcFunc.f.Call(args)
   133  			result, err := unreflectResult(returns)
   134  			if err != nil {
   135  				responses = append(responses, types.RPCInternalError(request.ID, err))
   136  				continue
   137  			}
   138  			responses = append(responses, types.NewRPCSuccessResponse(request.ID, result))
   139  		}
   140  
   141  		if len(responses) > 0 {
   142  			var wErr error
   143  			if cache {
   144  				wErr = WriteCacheableRPCResponseHTTP(w, responses...)
   145  			} else {
   146  				wErr = WriteRPCResponseHTTP(w, responses...)
   147  			}
   148  			if wErr != nil {
   149  				logger.Error("failed to write responses", "res", responses, "err", wErr)
   150  			}
   151  		}
   152  	}
   153  }
   154  
   155  func handleInvalidJSONRPCPaths(next http.HandlerFunc) http.HandlerFunc {
   156  	return func(w http.ResponseWriter, r *http.Request) {
   157  		// Since the pattern "/" matches all paths not matched by other registered patterns,
   158  		//  we check whether the path is indeed "/", otherwise return a 404 error
   159  		if r.URL.Path != "/" {
   160  			http.NotFound(w, r)
   161  			return
   162  		}
   163  
   164  		next(w, r)
   165  	}
   166  }
   167  
   168  func mapParamsToArgs(
   169  	rpcFunc *RPCFunc,
   170  	params map[string]json.RawMessage,
   171  	argsOffset int,
   172  ) ([]reflect.Value, error) {
   173  	values := make([]reflect.Value, len(rpcFunc.argNames))
   174  	for i, argName := range rpcFunc.argNames {
   175  		argType := rpcFunc.args[i+argsOffset]
   176  
   177  		if p, ok := params[argName]; ok && p != nil && len(p) > 0 {
   178  			val := reflect.New(argType)
   179  			err := tmjson.Unmarshal(p, val.Interface())
   180  			if err != nil {
   181  				return nil, err
   182  			}
   183  			values[i] = val.Elem()
   184  		} else { // use default for that type
   185  			values[i] = reflect.Zero(argType)
   186  		}
   187  	}
   188  
   189  	return values, nil
   190  }
   191  
   192  func arrayParamsToArgs(
   193  	rpcFunc *RPCFunc,
   194  	params []json.RawMessage,
   195  	argsOffset int,
   196  ) ([]reflect.Value, error) {
   197  	if len(rpcFunc.argNames) != len(params) {
   198  		return nil, fmt.Errorf("expected %v parameters (%v), got %v (%v)",
   199  			len(rpcFunc.argNames), rpcFunc.argNames, len(params), params)
   200  	}
   201  
   202  	values := make([]reflect.Value, len(params))
   203  	for i, p := range params {
   204  		argType := rpcFunc.args[i+argsOffset]
   205  		val := reflect.New(argType)
   206  		err := tmjson.Unmarshal(p, val.Interface())
   207  		if err != nil {
   208  			return nil, err
   209  		}
   210  		values[i] = val.Elem()
   211  	}
   212  	return values, nil
   213  }
   214  
   215  // raw is unparsed json (from json.RawMessage) encoding either a map or an
   216  // array.
   217  //
   218  // Example:
   219  //
   220  //	rpcFunc.args = [rpctypes.Context string]
   221  //	rpcFunc.argNames = ["arg"]
   222  func jsonParamsToArgs(rpcFunc *RPCFunc, raw []byte) ([]reflect.Value, error) {
   223  	const argsOffset = 1
   224  
   225  	// TODO: Make more efficient, perhaps by checking the first character for '{' or '['?
   226  	// First, try to get the map.
   227  	var m map[string]json.RawMessage
   228  	err := json.Unmarshal(raw, &m)
   229  	if err == nil {
   230  		return mapParamsToArgs(rpcFunc, m, argsOffset)
   231  	}
   232  
   233  	// Otherwise, try an array.
   234  	var a []json.RawMessage
   235  	err = json.Unmarshal(raw, &a)
   236  	if err == nil {
   237  		return arrayParamsToArgs(rpcFunc, a, argsOffset)
   238  	}
   239  
   240  	// Otherwise, bad format, we cannot parse
   241  	return nil, fmt.Errorf("unknown type for JSON params: %v. Expected map or array", err)
   242  }
   243  
   244  // writes a list of available rpc endpoints as an html page
   245  func writeListOfEndpoints(w http.ResponseWriter, r *http.Request, funcMap map[string]*RPCFunc) {
   246  	noArgNames := []string{}
   247  	argNames := []string{}
   248  	for name, funcData := range funcMap {
   249  		if len(funcData.args) == 0 {
   250  			noArgNames = append(noArgNames, name)
   251  		} else {
   252  			argNames = append(argNames, name)
   253  		}
   254  	}
   255  	sort.Strings(noArgNames)
   256  	sort.Strings(argNames)
   257  	buf := new(bytes.Buffer)
   258  	buf.WriteString("<html><body>")
   259  	buf.WriteString("<br>Available endpoints:<br>")
   260  
   261  	for _, name := range noArgNames {
   262  		link := fmt.Sprintf("//%s/%s", r.Host, name)
   263  		buf.WriteString(fmt.Sprintf("<a href=\"%s\">%s</a></br>", link, link))
   264  	}
   265  
   266  	buf.WriteString("<br>Endpoints that require arguments:<br>")
   267  	for _, name := range argNames {
   268  		link := fmt.Sprintf("//%s/%s?", r.Host, name)
   269  		funcData := funcMap[name]
   270  		for i, argName := range funcData.argNames {
   271  			link += argName + "=_"
   272  			if i < len(funcData.argNames)-1 {
   273  				link += "&"
   274  			}
   275  		}
   276  		buf.WriteString(fmt.Sprintf("<a href=\"%s\">%s</a></br>", link, link))
   277  	}
   278  	buf.WriteString("</body></html>")
   279  	w.Header().Set("Content-Type", "text/html")
   280  	w.WriteHeader(200)
   281  	w.Write(buf.Bytes()) // nolint: errcheck
   282  }