code.vegaprotocol.io/vega@v0.79.0/wallet/service/v2/endpoint_handle_request.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package v2
    17  
    18  import (
    19  	"context"
    20  	"encoding/json"
    21  	"errors"
    22  	"fmt"
    23  	"io"
    24  	"net/http"
    25  	"strings"
    26  
    27  	vfmt "code.vegaprotocol.io/vega/libs/fmt"
    28  	"code.vegaprotocol.io/vega/libs/jsonrpc"
    29  	vgrand "code.vegaprotocol.io/vega/libs/rand"
    30  	"code.vegaprotocol.io/vega/logging"
    31  	"code.vegaprotocol.io/vega/wallet/api"
    32  
    33  	"github.com/julienschmidt/httprouter"
    34  	"go.uber.org/zap"
    35  )
    36  
    37  func (a *API) HandleRequest(w http.ResponseWriter, httpRequest *http.Request, _ httprouter.Params) {
    38  	traceID := vgrand.RandomStr(64)
    39  	ctx := context.WithValue(httpRequest.Context(), jsonrpc.TraceIDKey, traceID)
    40  
    41  	a.log.Info("New request",
    42  		logging.String("url", vfmt.Escape(httpRequest.URL.String())),
    43  		logging.String("trace-id", traceID),
    44  	)
    45  
    46  	lw := newResponseWriter(w, traceID)
    47  	defer logResponse(a.log, lw)
    48  
    49  	rpcRequest, errDetails := a.unmarshallRequest(traceID, httpRequest)
    50  	if errDetails != nil {
    51  		lw.SetStatusCode(http.StatusBadRequest)
    52  		// Failing to unmarshall the request prevent us from retrieving the
    53  		// request ID. So, it's left empty.
    54  		lw.WriteJSONRPCResponse(jsonrpc.NewErrorResponse("", errDetails))
    55  		return
    56  	}
    57  
    58  	response := a.processJSONRPCRequest(ctx, traceID, lw, httpRequest, rpcRequest)
    59  
    60  	// If the request doesn't have an ID, it's a notification. Notifications do
    61  	// not send content back, even if an error occurred.
    62  	if rpcRequest.IsNotification() {
    63  		lw.SetStatusCode(http.StatusNoContent)
    64  		return
    65  	}
    66  
    67  	if response.Error == nil {
    68  		lw.SetStatusCode(http.StatusOK)
    69  	} else {
    70  		if response.Error.Code == api.ErrorCodeAuthenticationFailure {
    71  			lw.SetStatusCode(401)
    72  		} else if response.Error.IsInternalError() {
    73  			lw.SetStatusCode(http.StatusInternalServerError)
    74  		} else {
    75  			lw.SetStatusCode(http.StatusBadRequest)
    76  		}
    77  	}
    78  	lw.WriteJSONRPCResponse(response)
    79  }
    80  
    81  func (a *API) unmarshallRequest(traceID string, r *http.Request) (jsonrpc.Request, *jsonrpc.ErrorDetails) {
    82  	defer func() {
    83  		_ = r.Body.Close()
    84  	}()
    85  
    86  	body, err := io.ReadAll(r.Body)
    87  	if err != nil {
    88  		return jsonrpc.Request{}, jsonrpc.NewParseError(ErrCouldNotReadRequestBody)
    89  	}
    90  
    91  	if len(body) == 0 {
    92  		return jsonrpc.Request{}, jsonrpc.NewParseError(ErrRequestCannotBeBlank)
    93  	}
    94  
    95  	request := jsonrpc.Request{}
    96  	if err := json.Unmarshal(body, &request); err != nil {
    97  		a.log.Error("Request could not be parsed",
    98  			logging.String("trace-id", traceID),
    99  			logging.Error(err),
   100  		)
   101  
   102  		var syntaxError *json.SyntaxError
   103  		var unmarshallTypeError *json.UnmarshalTypeError
   104  		if errors.As(err, &syntaxError) || errors.As(err, &unmarshallTypeError) || errors.As(err, &unmarshallTypeError) {
   105  			return jsonrpc.Request{}, jsonrpc.NewParseError(err)
   106  		}
   107  
   108  		return jsonrpc.Request{}, jsonrpc.NewInvalidRequest(err)
   109  	}
   110  
   111  	strReq, _ := json.Marshal(&request)
   112  	a.log.Info("Request successfully parsed",
   113  		logging.String("request", vfmt.Escape(string(strReq))),
   114  		logging.String("trace-id", traceID),
   115  	)
   116  
   117  	return request, nil
   118  }
   119  
   120  func (a *API) processJSONRPCRequest(ctx context.Context, traceID string, lw *responseWriter, httpRequest *http.Request, rpcRequest jsonrpc.Request) *jsonrpc.Response {
   121  	// check for unicode headers
   122  	for k, h := range httpRequest.Header {
   123  		for _, v := range h {
   124  			if len([]rune(v)) != len(v) {
   125  				return jsonrpc.NewErrorResponse(rpcRequest.ID, &jsonrpc.ErrorDetails{
   126  					Code:    jsonrpc.ErrorCodeInvalidRequest,
   127  					Message: fmt.Sprintf("Header %s contains invalid characters", k),
   128  				})
   129  			}
   130  		}
   131  	}
   132  	if err := rpcRequest.Check(); err != nil {
   133  		a.log.Info("invalid RPC request",
   134  			zap.String("trace-id", traceID),
   135  			zap.Error(err))
   136  		return jsonrpc.NewErrorResponse(rpcRequest.ID, jsonrpc.NewInvalidRequest(err))
   137  	}
   138  
   139  	// We add this pre-check so users stop asking why they can't access the
   140  	// administrative endpoints.
   141  	if strings.HasPrefix(rpcRequest.Method, "admin.") {
   142  		a.log.Debug("attempt to call administrative endpoint rejected",
   143  			zap.String("trace-id", traceID),
   144  			zap.String("method", vfmt.Escape(rpcRequest.Method)))
   145  		return jsonrpc.NewErrorResponse(rpcRequest.ID, jsonrpc.NewUnsupportedMethod(ErrAdminEndpointsNotExposed))
   146  	}
   147  
   148  	command, ok := a.commands[rpcRequest.Method]
   149  	if !ok {
   150  		a.log.Debug("unknown RPC method",
   151  			zap.String("trace-id", traceID),
   152  			zap.String("method", vfmt.Escape(rpcRequest.Method)))
   153  		return jsonrpc.NewErrorResponse(rpcRequest.ID, jsonrpc.NewMethodNotFound(rpcRequest.Method))
   154  	}
   155  
   156  	result, errDetails := command(ctx, lw, httpRequest, rpcRequest)
   157  	if errDetails != nil {
   158  		a.log.Info("RPC request failed",
   159  			zap.String("trace-id", traceID),
   160  			zap.Error(errDetails))
   161  
   162  		return jsonrpc.NewErrorResponse(rpcRequest.ID, errDetails)
   163  	}
   164  
   165  	a.log.Info("RPC request succeeded",
   166  		zap.String("trace-id", traceID))
   167  
   168  	return jsonrpc.NewSuccessfulResponse(rpcRequest.ID, result)
   169  }
   170  
   171  func logResponse(logger *zap.Logger, lw *responseWriter) {
   172  	if lw.statusCode >= 400 && lw.statusCode <= 499 {
   173  		logger.Error("Client error",
   174  			logging.Int("http-status", lw.statusCode),
   175  			logging.String("response", string(lw.response)),
   176  			logging.String("request-id", vfmt.Escape(lw.requestID)),
   177  			logging.String("trace-id", lw.traceID),
   178  		)
   179  		return
   180  	}
   181  	if lw.statusCode >= 500 && lw.statusCode <= 599 {
   182  		logger.Error("Internal error",
   183  			logging.Int("http-status", lw.statusCode),
   184  			logging.Error(lw.internalError),
   185  			logging.String("request-id", vfmt.Escape(lw.requestID)),
   186  			logging.String("trace-id", lw.traceID),
   187  		)
   188  		return
   189  	}
   190  	logger.Info("Successful response",
   191  		logging.Int("http-status", lw.statusCode),
   192  		logging.String("response", string(lw.response)),
   193  		logging.String("request-id", vfmt.Escape(lw.requestID)),
   194  		logging.String("trace-id", lw.traceID),
   195  	)
   196  }