github.com/vmware/transport-go@v1.3.4/service/rest_service_test.go (about)

     1  // Copyright 2019-2020 VMware, Inc.
     2  // SPDX-License-Identifier: BSD-2-Clause
     3  
     4  package service
     5  
     6  import (
     7  	"bytes"
     8  	"encoding/json"
     9  	"errors"
    10  	"github.com/stretchr/testify/assert"
    11  	"github.com/vmware/transport-go/model"
    12  	"io/ioutil"
    13  	"net/http"
    14  	"reflect"
    15  	"strings"
    16  	"sync"
    17  	"testing"
    18  )
    19  
    20  type testItem struct {
    21  	Name  string `json:"name"`
    22  	Count int    `json:"count"`
    23  }
    24  
    25  func TestRestServiceRequest_marshalBody(t *testing.T) {
    26  	reqWithStringBody := &RestServiceRequest{Body: "test-body"}
    27  	body, err := reqWithStringBody.marshalBody()
    28  	assert.Nil(t, err)
    29  	assert.Equal(t, []byte("test-body"), body)
    30  
    31  	reqWithBytesBody := &RestServiceRequest{Body: []byte{1, 2, 3, 4}}
    32  	body, err = reqWithBytesBody.marshalBody()
    33  	assert.Nil(t, err)
    34  	assert.Equal(t, reqWithBytesBody.Body, body)
    35  
    36  	item := testItem{Name: "test-name", Count: 5}
    37  	reqWithTestItem := &RestServiceRequest{Body: item}
    38  	body, err = reqWithTestItem.marshalBody()
    39  	assert.Nil(t, err)
    40  	expectedValue, _ := json.Marshal(item)
    41  	assert.Equal(t, expectedValue, body)
    42  }
    43  
    44  func TestRestService_AutoRegistration(t *testing.T) {
    45  	assert.NotNil(t, GetServiceRegistry().(*serviceRegistry).services[restServiceChannel])
    46  }
    47  
    48  // RoundTripFunc .
    49  type RoundTripFunc func(req *http.Request) (*http.Response, error)
    50  
    51  // RoundTrip .
    52  func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
    53  	return f(req)
    54  }
    55  
    56  func TestRestService_HandleServiceRequest(t *testing.T) {
    57  	core := newTestFabricCore(restServiceChannel)
    58  
    59  	restService := &restService{}
    60  	var lastHttpRequest *http.Request
    61  	restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) {
    62  		lastHttpRequest = req
    63  		return &http.Response{
    64  			StatusCode: 200,
    65  			Body:       ioutil.NopCloser(bytes.NewBufferString("test-response-body")),
    66  			Header:     make(http.Header),
    67  		}, nil
    68  	})
    69  
    70  	var lastResponse *model.Response
    71  
    72  	wg := sync.WaitGroup{}
    73  	wg.Add(1)
    74  
    75  	mh, _ := core.Bus().ListenStream(restServiceChannel)
    76  	mh.Handle(
    77  		func(message *model.Message) {
    78  			lastResponse = message.Payload.(*model.Response)
    79  			wg.Done()
    80  		},
    81  		func(e error) {
    82  			assert.Fail(t, "unexpected error")
    83  		})
    84  
    85  	restService.HandleServiceRequest(&model.Request{
    86  		Payload: &RestServiceRequest{
    87  			Uri:          "http://localhost:4444/test-url",
    88  			Headers:      map[string]string{"header1": "value1", "header2": "value2"},
    89  			Method:       "UPDATE",
    90  			Body:         "test-body",
    91  			ResponseType: reflect.TypeOf(""),
    92  		},
    93  	}, core)
    94  
    95  	wg.Wait()
    96  
    97  	assert.NotNil(t, lastHttpRequest)
    98  	assert.Equal(t, lastHttpRequest.URL.String(), "http://localhost:4444/test-url")
    99  	assert.Equal(t, lastHttpRequest.Method, "UPDATE")
   100  	assert.Equal(t, lastHttpRequest.Header.Get("header1"), "value1")
   101  	assert.Equal(t, lastHttpRequest.Header.Get("header2"), "value2")
   102  	assert.Equal(t, lastHttpRequest.Header.Get("Content-Type"), "application/merge-patch+json")
   103  	sentBody, _ := ioutil.ReadAll(lastHttpRequest.Body)
   104  	assert.Equal(t, sentBody, []byte("test-body"))
   105  
   106  	assert.NotNil(t, lastResponse)
   107  	assert.Equal(t, lastResponse.Payload, "test-response-body")
   108  	assert.False(t, lastResponse.Error)
   109  
   110  	restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) {
   111  		lastHttpRequest = req
   112  		return &http.Response{
   113  			StatusCode: 200,
   114  			Body:       ioutil.NopCloser(bytes.NewBufferString(`{"name": "test-name", "count": 2}`)),
   115  			Header:     make(http.Header),
   116  		}, nil
   117  	})
   118  
   119  	wg.Add(1)
   120  	restService.HandleServiceRequest(&model.Request{
   121  		Payload: &RestServiceRequest{
   122  			Uri:          "http://localhost:4444/test-url",
   123  			Headers:      map[string]string{"Content-Type": "json"},
   124  			ResponseType: reflect.TypeOf(testItem{}),
   125  		},
   126  	}, core)
   127  
   128  	wg.Wait()
   129  
   130  	assert.Equal(t, lastHttpRequest.Header.Get("Content-Type"), "json")
   131  	assert.Equal(t, lastResponse.Payload, testItem{Name: "test-name", Count: 2})
   132  
   133  	wg.Add(1)
   134  	restService.HandleServiceRequest(&model.Request{
   135  		Payload: &RestServiceRequest{
   136  			Uri:          "http://localhost:4444/test-url",
   137  			ResponseType: reflect.TypeOf(&testItem{}),
   138  		},
   139  	}, core)
   140  
   141  	wg.Wait()
   142  
   143  	assert.Equal(t, lastResponse.Payload, &testItem{Name: "test-name", Count: 2})
   144  
   145  	wg.Add(1)
   146  	restService.HandleServiceRequest(&model.Request{
   147  		Payload: &RestServiceRequest{
   148  			Uri: "http://localhost:4444/test-url",
   149  		},
   150  	}, core)
   151  
   152  	wg.Wait()
   153  
   154  	assert.Equal(t, lastResponse.Payload, map[string]interface{}{"name": "test-name", "count": float64(2)})
   155  
   156  	restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) {
   157  		lastHttpRequest = req
   158  		return &http.Response{
   159  			StatusCode: 200,
   160  			Body:       ioutil.NopCloser(bytes.NewBuffer([]byte{1, 2, 3, 4, 5})),
   161  			Header:     make(http.Header),
   162  		}, nil
   163  	})
   164  
   165  	wg.Add(1)
   166  	restService.HandleServiceRequest(&model.Request{
   167  		Payload: &RestServiceRequest{
   168  			Uri:          "http://localhost:4444/test-url",
   169  			ResponseType: reflect.TypeOf([]byte{}),
   170  		},
   171  	}, core)
   172  
   173  	wg.Wait()
   174  
   175  	assert.Equal(t, lastResponse.Payload, []byte{1, 2, 3, 4, 5})
   176  }
   177  
   178  func TestRestService_HandleJavaServiceRequest(t *testing.T) {
   179  	core := newTestFabricCore(restServiceChannel)
   180  
   181  	wg := sync.WaitGroup{}
   182  
   183  	restService := &restService{}
   184  	var lastHttpRequest *http.Request
   185  	restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) {
   186  		lastHttpRequest = req
   187  		defer wg.Done()
   188  		return &http.Response{
   189  			StatusCode: 200,
   190  			Body:       ioutil.NopCloser(bytes.NewBufferString("test-response-body")),
   191  			Header:     make(http.Header),
   192  		}, nil
   193  	})
   194  
   195  	wg.Add(1)
   196  	restService.HandleServiceRequest(&model.Request{
   197  		Payload: map[string]interface{}{
   198  			"uri":      "http://localhost:4444/test-url",
   199  			"headers":  map[string]string{"header1": "value1", "header2": "value2"},
   200  			"method":   "UPDATE",
   201  			"Body":     "test-body",
   202  			"apiClass": "java.lang.String",
   203  		},
   204  	}, core)
   205  
   206  	wg.Wait()
   207  
   208  	assert.NotNil(t, lastHttpRequest)
   209  	assert.Equal(t, lastHttpRequest.URL.String(), "http://localhost:4444/test-url")
   210  	assert.Equal(t, lastHttpRequest.Method, "UPDATE")
   211  	assert.Equal(t, lastHttpRequest.Header.Get("header1"), "value1")
   212  	assert.Equal(t, lastHttpRequest.Header.Get("header2"), "value2")
   213  	assert.Equal(t, lastHttpRequest.Header.Get("Content-Type"), "application/merge-patch+json")
   214  	sentBody, _ := ioutil.ReadAll(lastHttpRequest.Body)
   215  	assert.Equal(t, sentBody, []byte("test-body"))
   216  }
   217  
   218  func TestRestService_HandleServiceRequest_InvalidInput(t *testing.T) {
   219  	core := newTestFabricCore(restServiceChannel)
   220  
   221  	restService := &restService{}
   222  	var lastResponse *model.Response
   223  
   224  	wg := sync.WaitGroup{}
   225  	wg.Add(1)
   226  	mh, _ := core.Bus().ListenStream(restServiceChannel)
   227  	mh.Handle(
   228  		func(message *model.Message) {
   229  			lastResponse = message.Payload.(*model.Response)
   230  			wg.Done()
   231  		},
   232  		func(e error) {
   233  			assert.Fail(t, "unexpected error")
   234  		})
   235  
   236  	restService.HandleServiceRequest(&model.Request{
   237  		Payload: RestServiceRequest{
   238  			Uri:    "http://localhost:4444/test-url",
   239  			Method: "UPDATE",
   240  		},
   241  	}, core)
   242  
   243  	wg.Wait()
   244  
   245  	assert.NotNil(t, lastResponse)
   246  	assert.True(t, lastResponse.Error)
   247  	assert.Equal(t, lastResponse.ErrorCode, 500)
   248  	assert.Equal(t, lastResponse.ErrorMessage, "invalid RestServiceRequest payload")
   249  
   250  	wg.Add(1)
   251  
   252  	restService.HandleServiceRequest(&model.Request{
   253  		Payload: &RestServiceRequest{
   254  			Uri:    "http://localhost:4444/test-url",
   255  			Method: "@!#$%^&**()",
   256  		},
   257  	}, core)
   258  
   259  	wg.Wait()
   260  	assert.True(t, lastResponse.Error)
   261  	assert.Equal(t, lastResponse.ErrorCode, 500)
   262  
   263  	restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) {
   264  		return nil, errors.New("custom-rest-error")
   265  	})
   266  
   267  	wg.Add(1)
   268  	restService.HandleServiceRequest(&model.Request{
   269  		Payload: &RestServiceRequest{
   270  			Uri: "http://localhost:4444/test-url",
   271  		},
   272  	}, core)
   273  	wg.Wait()
   274  
   275  	assert.True(t, lastResponse.Error)
   276  	assert.Equal(t, lastResponse.ErrorCode, 500)
   277  	assert.True(t, strings.Contains(lastResponse.ErrorMessage, "custom-rest-error"))
   278  
   279  	restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) {
   280  		return &http.Response{
   281  			StatusCode: 404,
   282  			Status:     "404 Not Found",
   283  			Body:       ioutil.NopCloser(bytes.NewBufferString("error-response")),
   284  			Header:     make(http.Header),
   285  		}, nil
   286  	})
   287  
   288  	wg.Add(1)
   289  	restService.HandleServiceRequest(&model.Request{
   290  		Payload: &RestServiceRequest{
   291  			Uri: "http://localhost:4444/test-url",
   292  		},
   293  	}, core)
   294  	wg.Wait()
   295  
   296  	assert.True(t, lastResponse.Error)
   297  	assert.Equal(t, lastResponse.ErrorCode, 404)
   298  	assert.Equal(t, lastResponse.ErrorMessage, "rest-service error, unable to complete request: 404 Not Found")
   299  
   300  	restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) {
   301  		return &http.Response{
   302  			StatusCode: 200,
   303  			Body:       ioutil.NopCloser(bytes.NewBufferString("}")),
   304  			Header:     make(http.Header),
   305  		}, nil
   306  	})
   307  
   308  	wg.Add(1)
   309  	restService.HandleServiceRequest(&model.Request{
   310  		Payload: &RestServiceRequest{
   311  			Uri:          "http://localhost:4444/test-url",
   312  			ResponseType: reflect.TypeOf(&testItem{}),
   313  		},
   314  	}, core)
   315  	wg.Wait()
   316  
   317  	assert.True(t, lastResponse.Error)
   318  	assert.Equal(t, lastResponse.ErrorCode, 500)
   319  	assert.True(t, strings.HasPrefix(lastResponse.ErrorMessage, "failed to deserialize response:"))
   320  }
   321  
   322  func TestRestService_setBaseHost(t *testing.T) {
   323  	core := newTestFabricCore(restServiceChannel)
   324  	restService := &restService{}
   325  
   326  	wg := sync.WaitGroup{}
   327  
   328  	var lastHttpRequest *http.Request
   329  	restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) {
   330  		lastHttpRequest = req
   331  		wg.Done()
   332  		return &http.Response{
   333  			StatusCode: 200,
   334  			Body:       ioutil.NopCloser(bytes.NewBufferString("test-response-body")),
   335  			Header:     make(http.Header),
   336  		}, nil
   337  	})
   338  
   339  	restService.setBaseHost("appfabric.vmware.com:9999")
   340  
   341  	wg.Add(1)
   342  	restService.HandleServiceRequest(&model.Request{
   343  		Payload: &RestServiceRequest{
   344  			Uri: "http://localhost:4444/test-url",
   345  		},
   346  	}, core)
   347  
   348  	wg.Wait()
   349  
   350  	assert.Equal(t, lastHttpRequest.Host, "appfabric.vmware.com:9999")
   351  
   352  	restService.setBaseHost("")
   353  
   354  	wg.Add(1)
   355  	restService.HandleServiceRequest(&model.Request{
   356  		Payload: &RestServiceRequest{
   357  			Uri: "http://localhost:4444/test-url",
   358  		},
   359  	}, core)
   360  	wg.Wait()
   361  
   362  	assert.Equal(t, lastHttpRequest.Host, "localhost:4444")
   363  }