github.com/vmware/transport-go@v1.3.4/service/rest_service.go (about)

     1  // Copyright 2019-2020 VMware, Inc.
     2  // SPDX-License-Identifier: BSD-2-Clause
     3  
     4  package service
     5  
     6  import (
     7  	"bytes"
     8  	"encoding/json"
     9  	"github.com/vmware/transport-go/model"
    10  	"io"
    11  	"net/http"
    12  	"net/url"
    13  	"reflect"
    14  	"strings"
    15  )
    16  
    17  const (
    18  	restServiceChannel = "fabric-rest"
    19  )
    20  
    21  type RestServiceRequest struct {
    22  	// The destination URL of the request.
    23  	Uri string `json:"uri"`
    24  	// HTTP Method to use, e.g. GET, POST, PATCH etc.
    25  	Method string `json:"method"`
    26  	// The body of the request. String and []byte payloads will be sent as is,
    27  	// all other payloads will be serialized as json.
    28  	Body interface{} `json:"body"`
    29  	//  HTTP headers of the request.
    30  	Headers map[string]string `json:"headers"`
    31  	// Optional type of the response body. If provided the service will try to deserialize
    32  	// the response to this type.
    33  	// If omitted the response body will be deserialized as map[string]interface{}
    34  	// Note that if the response body is not a valid json you should set
    35  	// the ResponseType to string or []byte otherwise you might get deserialization error
    36  	// or empty result.
    37  	ResponseType reflect.Type
    38  	// Shouldn't be populated directly, the field is used to deserialize
    39  	// com.vmware.bifrost.core.model.RestServiceRequest Java/Typescript requests
    40  	ApiClass string `json:"apiClass"`
    41  }
    42  
    43  func (request *RestServiceRequest) marshalBody() ([]byte, error) {
    44  	// don't marshal string and []byte payloads as json
    45  	stringPayload, ok := request.Body.(string)
    46  	if ok {
    47  		return []byte(stringPayload), nil
    48  	}
    49  	bytePayload, ok := request.Body.([]byte)
    50  	if ok {
    51  		return bytePayload, nil
    52  	}
    53  	// encode the message payload as JSON
    54  	return json.Marshal(request.Body)
    55  }
    56  
// restService executes the HTTP calls described by RestServiceRequest
// payloads and replies with the deserialized responses (see
// HandleServiceRequest).
//
// NOTE(review): httpClient is held by value; unless it is configured where
// restService is constructed (not visible here), the zero-value http.Client
// has no Timeout, so a slow upstream could block HandleServiceRequest
// indefinitely — confirm.
type restService struct {
	httpClient http.Client // client used for every outgoing request
	baseHost   string      // when non-empty, replaces the host of each request URL
}
    61  
    62  func (rs *restService) setBaseHost(host string) {
    63  	rs.baseHost = host
    64  }
    65  
    66  func (rs *restService) HandleServiceRequest(request *model.Request, core FabricServiceCore) {
    67  
    68  	restReq, ok := rs.getRestServiceRequest(request)
    69  	if !ok {
    70  		core.SendErrorResponse(request, 500, "invalid RestServiceRequest payload")
    71  		return
    72  	}
    73  
    74  	body, err := restReq.marshalBody()
    75  	if err != nil {
    76  		core.SendErrorResponse(request, 500, "cannot marshal request body: "+err.Error())
    77  		return
    78  	}
    79  
    80  	httpReq, err := http.NewRequest(restReq.Method,
    81  		rs.getRequestUrl(restReq.Uri, core), bytes.NewBuffer(body))
    82  
    83  	if err != nil {
    84  		core.SendErrorResponse(request, 500, err.Error())
    85  		return
    86  	}
    87  
    88  	// update headers
    89  	for k, v := range restReq.Headers {
    90  		httpReq.Header.Add(k, v)
    91  	}
    92  
    93  	// add default Content-Type header if such is not provided in the request
    94  	if httpReq.Header.Get("Content-Type") == "" {
    95  		httpReq.Header.Add("Content-Type", "application/merge-patch+json")
    96  	}
    97  
    98  	contentType := httpReq.Header.Get("Content-Type")
    99  	if strings.Contains(contentType, "json") {
   100  		// leaving restReq.ResponseType empty is equivalent to treating the response as JSON. see deserializeResponse().
   101  	} else {
   102  		// otherwise default to byte slice. note that we have an arm for the string type, but defaulting to the byte
   103  		// slice makes the payload more flexible to handle in downstream consumers
   104  		restReq.ResponseType = reflect.TypeOf([]byte{})
   105  	}
   106  
   107  	httpResp, err := rs.httpClient.Do(httpReq)
   108  	if err != nil {
   109  		core.SendErrorResponse(request, 500, err.Error())
   110  		return
   111  	}
   112  	defer httpResp.Body.Close()
   113  
   114  	if httpResp.StatusCode >= 300 {
   115  		core.SendErrorResponseWithPayload(request, httpResp.StatusCode,
   116  			"rest-service error, unable to complete request: "+httpResp.Status,
   117  			map[string]interface{}{
   118  				"errorCode": httpResp.StatusCode,
   119  				"message":   "rest-service error, unable to complete request: " + httpResp.Status,
   120  			})
   121  		return
   122  	}
   123  
   124  	result, err := rs.deserializeResponse(httpResp.Body, restReq.ResponseType)
   125  	if err != nil {
   126  		core.SendErrorResponse(request, 500, "failed to deserialize response:"+err.Error())
   127  	} else {
   128  		core.SendResponse(request, result)
   129  	}
   130  }
   131  
   132  func (rs *restService) getRestServiceRequest(request *model.Request) (*RestServiceRequest, bool) {
   133  	restReq, ok := request.Payload.(*RestServiceRequest)
   134  	if ok {
   135  		return restReq, true
   136  	}
   137  
   138  	// check if the request.Payload is map[string]interface{} and convert it to RestServiceRequest
   139  	// This is needed to handle requests coming from Java/Typescript Transport clients.
   140  	reqAsMap, ok := request.Payload.(map[string]interface{})
   141  	if ok {
   142  		restServReqInt, err := model.ConvertValueToType(reqAsMap, reflect.TypeOf(&RestServiceRequest{}))
   143  		if err == nil && restServReqInt != nil {
   144  			restServReq := restServReqInt.(*RestServiceRequest)
   145  			if restServReq.ApiClass == "java.lang.String" {
   146  				restServReq.ResponseType = reflect.TypeOf("")
   147  			}
   148  			return restServReq, true
   149  		}
   150  	}
   151  
   152  	return nil, false
   153  }
   154  
   155  func (rs *restService) getRequestUrl(address string, core FabricServiceCore) string {
   156  	if rs.baseHost == "" {
   157  		return address
   158  	}
   159  
   160  	result, err := url.Parse(address)
   161  	if err != nil {
   162  		return address
   163  	}
   164  	result.Host = rs.baseHost
   165  	return result.String()
   166  }
   167  
   168  func (rs *restService) deserializeResponse(
   169  	body io.ReadCloser, responseType reflect.Type) (interface{}, error) {
   170  
   171  	if responseType != nil {
   172  
   173  		// check for string responseType
   174  		if responseType.Kind() == reflect.String {
   175  			buf := new(bytes.Buffer)
   176  			_, err := buf.ReadFrom(body)
   177  			if err != nil {
   178  				return nil, err
   179  			}
   180  			return buf.String(), nil
   181  		}
   182  
   183  		// check for []byte responseType
   184  		if responseType.Kind() == reflect.Slice &&
   185  			responseType == reflect.TypeOf([]byte{}) {
   186  			buf := new(bytes.Buffer)
   187  			_, err := buf.ReadFrom(body)
   188  			if err != nil {
   189  				return nil, err
   190  			}
   191  			return buf.Bytes(), nil
   192  		}
   193  
   194  		var returnResultAsPointer bool
   195  		if responseType.Kind() == reflect.Ptr {
   196  			returnResultAsPointer = true
   197  			responseType = responseType.Elem()
   198  		}
   199  		decodedValuePtr := reflect.New(responseType).Interface()
   200  		err := json.NewDecoder(body).Decode(&decodedValuePtr)
   201  		if err != nil {
   202  			return nil, err
   203  		}
   204  		if returnResultAsPointer {
   205  			return decodedValuePtr, nil
   206  		} else {
   207  			return reflect.ValueOf(decodedValuePtr).Elem().Interface(), nil
   208  		}
   209  	} else {
   210  		var result map[string]interface{}
   211  		err := json.NewDecoder(body).Decode(&result)
   212  		return result, err
   213  	}
   214  }