github.com/prysmaticlabs/prysm@v1.4.4/shared/gateway/api_middleware_processing.go (about)

     1  package gateway
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/base64"
     6  	"encoding/json"
     7  	"fmt"
     8  	"io"
     9  	"io/ioutil"
    10  	"net/http"
    11  	"reflect"
    12  	"strconv"
    13  	"strings"
    14  	"time"
    15  
    16  	"github.com/ethereum/go-ethereum/common/hexutil"
    17  	"github.com/pkg/errors"
    18  	"github.com/prysmaticlabs/prysm/shared/grpcutils"
    19  	"github.com/wealdtech/go-bytesutil"
    20  )
    21  
    22  // DeserializeRequestBodyIntoContainer deserializes the request's body into an endpoint-specific struct.
    23  func DeserializeRequestBodyIntoContainer(body io.Reader, requestContainer interface{}) ErrorJson {
    24  	if err := json.NewDecoder(body).Decode(&requestContainer); err != nil {
    25  		return InternalServerErrorWithMessage(err, "could not decode request body")
    26  	}
    27  	return nil
    28  }
    29  
    30  // ProcessRequestContainerFields processes fields of an endpoint-specific container according to field tags.
    31  func ProcessRequestContainerFields(requestContainer interface{}) ErrorJson {
    32  	if err := processField(requestContainer, []fieldProcessor{
    33  		{
    34  			tag: "hex",
    35  			f:   hexToBase64Processor,
    36  		},
    37  	}); err != nil {
    38  		return InternalServerErrorWithMessage(err, "could not process request data")
    39  	}
    40  	return nil
    41  }
    42  
    43  // SetRequestBodyToRequestContainer makes the endpoint-specific container the new body of the request.
    44  func SetRequestBodyToRequestContainer(requestContainer interface{}, req *http.Request) ErrorJson {
    45  	// Serialize the struct, which now includes a base64-encoded value, into JSON.
    46  	j, err := json.Marshal(requestContainer)
    47  	if err != nil {
    48  		return InternalServerErrorWithMessage(err, "could not marshal request")
    49  	}
    50  	// Set the body to the new JSON.
    51  	req.Body = ioutil.NopCloser(bytes.NewReader(j))
    52  	req.Header.Set("Content-Length", strconv.Itoa(len(j)))
    53  	req.ContentLength = int64(len(j))
    54  	return nil
    55  }
    56  
    57  // PrepareRequestForProxying applies additional logic to the request so that it can be correctly proxied to grpc-gateway.
    58  func (m *ApiProxyMiddleware) PrepareRequestForProxying(endpoint Endpoint, req *http.Request) ErrorJson {
    59  	req.URL.Scheme = "http"
    60  	req.URL.Host = m.GatewayAddress
    61  	req.RequestURI = ""
    62  	if errJson := HandleURLParameters(endpoint.Path, req, endpoint.GetRequestURLLiterals); errJson != nil {
    63  		return errJson
    64  	}
    65  	return HandleQueryParameters(req, endpoint.GetRequestQueryParams)
    66  }
    67  
    68  // ProxyRequest proxies the request to grpc-gateway.
    69  func ProxyRequest(req *http.Request) (*http.Response, ErrorJson) {
    70  	// We do not use http.DefaultClient because it does not have any timeout.
    71  	netClient := &http.Client{Timeout: time.Minute * 2}
    72  	grpcResp, err := netClient.Do(req)
    73  	if err != nil {
    74  		return nil, InternalServerErrorWithMessage(err, "could not proxy request")
    75  	}
    76  	if grpcResp == nil {
    77  		return nil, &DefaultErrorJson{Message: "nil response from gRPC-gateway", Code: http.StatusInternalServerError}
    78  	}
    79  	return grpcResp, nil
    80  }
    81  
    82  // ReadGrpcResponseBody reads the body from the grpc-gateway's response.
    83  func ReadGrpcResponseBody(r io.Reader) ([]byte, ErrorJson) {
    84  	body, err := ioutil.ReadAll(r)
    85  	if err != nil {
    86  		return nil, InternalServerErrorWithMessage(err, "could not read response body")
    87  	}
    88  	return body, nil
    89  }
    90  
    91  // DeserializeGrpcResponseBodyIntoErrorJson deserializes the body from the grpc-gateway's response into an error struct.
    92  // The struct can be later examined to check if the request resulted in an error.
    93  func DeserializeGrpcResponseBodyIntoErrorJson(errJson ErrorJson, body []byte) ErrorJson {
    94  	if err := json.Unmarshal(body, errJson); err != nil {
    95  		return InternalServerErrorWithMessage(err, "could not unmarshal error")
    96  	}
    97  	return nil
    98  }
    99  
   100  // HandleGrpcResponseError acts on an error that resulted from a grpc-gateway's response.
   101  func HandleGrpcResponseError(errJson ErrorJson, resp *http.Response, w http.ResponseWriter) {
   102  	// Something went wrong, but the request completed, meaning we can write headers and the error message.
   103  	for h, vs := range resp.Header {
   104  		for _, v := range vs {
   105  			w.Header().Set(h, v)
   106  		}
   107  	}
   108  	// Set code to HTTP code because unmarshalled body contained gRPC code.
   109  	errJson.SetCode(resp.StatusCode)
   110  	WriteError(w, errJson, resp.Header)
   111  }
   112  
   113  // GrpcResponseIsStatusCodeOnly checks whether a grpc-gateway's response contained no body.
   114  func GrpcResponseIsStatusCodeOnly(req *http.Request, responseContainer interface{}) bool {
   115  	return req.Method == "GET" && responseContainer == nil
   116  }
   117  
   118  // DeserializeGrpcResponseBodyIntoContainer deserializes the grpc-gateway's response body into an endpoint-specific struct.
   119  func DeserializeGrpcResponseBodyIntoContainer(body []byte, responseContainer interface{}) ErrorJson {
   120  	if err := json.Unmarshal(body, &responseContainer); err != nil {
   121  		return InternalServerErrorWithMessage(err, "could not unmarshal response")
   122  	}
   123  	return nil
   124  }
   125  
   126  // ProcessMiddlewareResponseFields processes fields of an endpoint-specific container according to field tags.
   127  func ProcessMiddlewareResponseFields(responseContainer interface{}) ErrorJson {
   128  	if err := processField(responseContainer, []fieldProcessor{
   129  		{
   130  			tag: "hex",
   131  			f:   base64ToHexProcessor,
   132  		},
   133  		{
   134  			tag: "enum",
   135  			f:   enumToLowercaseProcessor,
   136  		},
   137  		{
   138  			tag: "time",
   139  			f:   timeToUnixProcessor,
   140  		},
   141  	}); err != nil {
   142  		return InternalServerErrorWithMessage(err, "could not process response data")
   143  	}
   144  	return nil
   145  }
   146  
   147  // SerializeMiddlewareResponseIntoJson serializes the endpoint-specific response struct into a JSON representation.
   148  func SerializeMiddlewareResponseIntoJson(responseContainer interface{}) (jsonResponse []byte, errJson ErrorJson) {
   149  	j, err := json.Marshal(responseContainer)
   150  	if err != nil {
   151  		return nil, InternalServerErrorWithMessage(err, "could not marshal response")
   152  	}
   153  	return j, nil
   154  }
   155  
   156  // WriteMiddlewareResponseHeadersAndBody populates headers and the body of the final response.
   157  func WriteMiddlewareResponseHeadersAndBody(req *http.Request, grpcResp *http.Response, responseJson []byte, w http.ResponseWriter) ErrorJson {
   158  	var statusCodeHeader string
   159  	for h, vs := range grpcResp.Header {
   160  		// We don't want to expose any gRPC metadata in the HTTP response, so we skip forwarding metadata headers.
   161  		if strings.HasPrefix(h, "Grpc-Metadata") {
   162  			if h == "Grpc-Metadata-"+grpcutils.HttpCodeMetadataKey {
   163  				statusCodeHeader = vs[0]
   164  			}
   165  		} else {
   166  			for _, v := range vs {
   167  				w.Header().Set(h, v)
   168  			}
   169  		}
   170  	}
   171  	if req.Method == "GET" {
   172  		w.Header().Set("Content-Length", strconv.Itoa(len(responseJson)))
   173  		if statusCodeHeader != "" {
   174  			code, err := strconv.Atoi(statusCodeHeader)
   175  			if err != nil {
   176  				return InternalServerErrorWithMessage(err, "could not parse status code")
   177  			}
   178  			w.WriteHeader(code)
   179  		} else {
   180  			w.WriteHeader(grpcResp.StatusCode)
   181  		}
   182  		if _, err := io.Copy(w, ioutil.NopCloser(bytes.NewReader(responseJson))); err != nil {
   183  			return InternalServerErrorWithMessage(err, "could not write response message")
   184  		}
   185  	} else if req.Method == "POST" {
   186  		w.WriteHeader(grpcResp.StatusCode)
   187  	}
   188  	return nil
   189  }
   190  
   191  // WriteError writes the error by manipulating headers and the body of the final response.
   192  func WriteError(w http.ResponseWriter, errJson ErrorJson, responseHeader http.Header) {
   193  	// Include custom error in the error JSON.
   194  	if responseHeader != nil {
   195  		customError, ok := responseHeader["Grpc-Metadata-"+grpcutils.CustomErrorMetadataKey]
   196  		if ok {
   197  			// Assume header has only one value and read the 0 index.
   198  			if err := json.Unmarshal([]byte(customError[0]), errJson); err != nil {
   199  				log.WithError(err).Error("Could not unmarshal custom error message")
   200  				return
   201  			}
   202  		}
   203  	}
   204  
   205  	j, err := json.Marshal(errJson)
   206  	if err != nil {
   207  		log.WithError(err).Error("Could not marshal error message")
   208  		return
   209  	}
   210  
   211  	w.Header().Set("Content-Length", strconv.Itoa(len(j)))
   212  	w.Header().Set("Content-Type", "application/json")
   213  	w.WriteHeader(errJson.StatusCode())
   214  	if _, err := io.Copy(w, ioutil.NopCloser(bytes.NewReader(j))); err != nil {
   215  		log.WithError(err).Error("Could not write error message")
   216  	}
   217  }
   218  
   219  // Cleanup performs final cleanup on the initial response from grpc-gateway.
   220  func Cleanup(grpcResponseBody io.ReadCloser) ErrorJson {
   221  	if err := grpcResponseBody.Close(); err != nil {
   222  		return InternalServerErrorWithMessage(err, "could not close response body")
   223  	}
   224  	return nil
   225  }
   226  
   227  // processField calls each processor function on any field that has the matching tag set.
   228  // It is a recursive function.
   229  func processField(s interface{}, processors []fieldProcessor) error {
   230  	kind := reflect.TypeOf(s).Kind()
   231  	if kind != reflect.Ptr && kind != reflect.Slice && kind != reflect.Array {
   232  		return fmt.Errorf("processing fields of kind '%v' is unsupported", kind)
   233  	}
   234  
   235  	t := reflect.TypeOf(s).Elem()
   236  	v := reflect.Indirect(reflect.ValueOf(s))
   237  
   238  	for i := 0; i < t.NumField(); i++ {
   239  		switch v.Field(i).Kind() {
   240  		case reflect.Slice:
   241  			sliceElem := t.Field(i).Type.Elem()
   242  			kind := sliceElem.Kind()
   243  			// Recursively process slices to struct pointers.
   244  			if kind == reflect.Ptr && sliceElem.Elem().Kind() == reflect.Struct {
   245  				for j := 0; j < v.Field(i).Len(); j++ {
   246  					if err := processField(v.Field(i).Index(j).Interface(), processors); err != nil {
   247  						return errors.Wrapf(err, "could not process field '%s'", t.Field(i).Name)
   248  					}
   249  				}
   250  			}
   251  			// Process each string in string slices.
   252  			if kind == reflect.String {
   253  				for _, proc := range processors {
   254  					_, hasTag := t.Field(i).Tag.Lookup(proc.tag)
   255  					if hasTag {
   256  						for j := 0; j < v.Field(i).Len(); j++ {
   257  							if err := proc.f(v.Field(i).Index(j)); err != nil {
   258  								return errors.Wrapf(err, "could not process field '%s'", t.Field(i).Name)
   259  							}
   260  						}
   261  					}
   262  				}
   263  
   264  			}
   265  		// Recursively process struct pointers.
   266  		case reflect.Ptr:
   267  			if v.Field(i).Elem().Kind() == reflect.Struct {
   268  				if err := processField(v.Field(i).Interface(), processors); err != nil {
   269  					return errors.Wrapf(err, "could not process field '%s'", t.Field(i).Name)
   270  				}
   271  			}
   272  		default:
   273  			field := t.Field(i)
   274  			for _, proc := range processors {
   275  				if _, hasTag := field.Tag.Lookup(proc.tag); hasTag {
   276  					if err := proc.f(v.Field(i)); err != nil {
   277  						return errors.Wrapf(err, "could not process field '%s'", t.Field(i).Name)
   278  					}
   279  				}
   280  			}
   281  		}
   282  	}
   283  	return nil
   284  }
   285  
   286  func hexToBase64Processor(v reflect.Value) error {
   287  	b, err := bytesutil.FromHexString(v.String())
   288  	if err != nil {
   289  		return err
   290  	}
   291  	v.SetString(base64.StdEncoding.EncodeToString(b))
   292  	return nil
   293  }
   294  
   295  func base64ToHexProcessor(v reflect.Value) error {
   296  	b, err := base64.StdEncoding.DecodeString(v.String())
   297  	if err != nil {
   298  		return err
   299  	}
   300  	v.SetString(hexutil.Encode(b))
   301  	return nil
   302  }
   303  
   304  func enumToLowercaseProcessor(v reflect.Value) error {
   305  	v.SetString(strings.ToLower(v.String()))
   306  	return nil
   307  }
   308  
   309  func timeToUnixProcessor(v reflect.Value) error {
   310  	t, err := time.Parse(time.RFC3339, v.String())
   311  	if err != nil {
   312  		return err
   313  	}
   314  	v.SetString(strconv.FormatUint(uint64(t.Unix()), 10))
   315  	return nil
   316  }