github.com/vmware/transport-go@v1.3.4/service/rest_service_test.go (about) 1 // Copyright 2019-2020 VMware, Inc. 2 // SPDX-License-Identifier: BSD-2-Clause 3 4 package service 5 6 import ( 7 "bytes" 8 "encoding/json" 9 "errors" 10 "github.com/stretchr/testify/assert" 11 "github.com/vmware/transport-go/model" 12 "io/ioutil" 13 "net/http" 14 "reflect" 15 "strings" 16 "sync" 17 "testing" 18 ) 19 20 type testItem struct { 21 Name string `json:"name"` 22 Count int `json:"count"` 23 } 24 25 func TestRestServiceRequest_marshalBody(t *testing.T) { 26 reqWithStringBody := &RestServiceRequest{Body: "test-body"} 27 body, err := reqWithStringBody.marshalBody() 28 assert.Nil(t, err) 29 assert.Equal(t, []byte("test-body"), body) 30 31 reqWithBytesBody := &RestServiceRequest{Body: []byte{1, 2, 3, 4}} 32 body, err = reqWithBytesBody.marshalBody() 33 assert.Nil(t, err) 34 assert.Equal(t, reqWithBytesBody.Body, body) 35 36 item := testItem{Name: "test-name", Count: 5} 37 reqWithTestItem := &RestServiceRequest{Body: item} 38 body, err = reqWithTestItem.marshalBody() 39 assert.Nil(t, err) 40 expectedValue, _ := json.Marshal(item) 41 assert.Equal(t, expectedValue, body) 42 } 43 44 func TestRestService_AutoRegistration(t *testing.T) { 45 assert.NotNil(t, GetServiceRegistry().(*serviceRegistry).services[restServiceChannel]) 46 } 47 48 // RoundTripFunc . 49 type RoundTripFunc func(req *http.Request) (*http.Response, error) 50 51 // RoundTrip . 52 func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { 53 return f(req) 54 } 55 56 func TestRestService_HandleServiceRequest(t *testing.T) { 57 core := newTestFabricCore(restServiceChannel) 58 59 restService := &restService{} 60 var lastHttpRequest *http.Request 61 restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { 62 lastHttpRequest = req 63 return &http.Response{ 64 StatusCode: 200, 65 Body: ioutil.NopCloser(bytes.NewBufferString("test-response-body")), 66 Header: make(http.Header), 67 }, nil 68 }) 69 70 var lastResponse *model.Response 71 72 wg := sync.WaitGroup{} 73 wg.Add(1) 74 75 mh, _ := core.Bus().ListenStream(restServiceChannel) 76 mh.Handle( 77 func(message *model.Message) { 78 lastResponse = message.Payload.(*model.Response) 79 wg.Done() 80 }, 81 func(e error) { 82 assert.Fail(t, "unexpected error") 83 }) 84 85 restService.HandleServiceRequest(&model.Request{ 86 Payload: &RestServiceRequest{ 87 Uri: "http://localhost:4444/test-url", 88 Headers: map[string]string{"header1": "value1", "header2": "value2"}, 89 Method: "UPDATE", 90 Body: "test-body", 91 ResponseType: reflect.TypeOf(""), 92 }, 93 }, core) 94 95 wg.Wait() 96 97 assert.NotNil(t, lastHttpRequest) 98 assert.Equal(t, lastHttpRequest.URL.String(), "http://localhost:4444/test-url") 99 assert.Equal(t, lastHttpRequest.Method, "UPDATE") 100 assert.Equal(t, lastHttpRequest.Header.Get("header1"), "value1") 101 assert.Equal(t, lastHttpRequest.Header.Get("header2"), "value2") 102 assert.Equal(t, lastHttpRequest.Header.Get("Content-Type"), "application/merge-patch+json") 103 sentBody, _ := ioutil.ReadAll(lastHttpRequest.Body) 104 assert.Equal(t, sentBody, []byte("test-body")) 105 106 assert.NotNil(t, lastResponse) 107 assert.Equal(t, lastResponse.Payload, "test-response-body") 108 assert.False(t, lastResponse.Error) 109 110 restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { 111 lastHttpRequest = req 112 return &http.Response{ 113 StatusCode: 200, 114 Body: ioutil.NopCloser(bytes.NewBufferString(`{"name": "test-name", "count": 2}`)), 115 Header: make(http.Header), 116 }, nil 117 }) 118 119 wg.Add(1) 120 restService.HandleServiceRequest(&model.Request{ 121 Payload: &RestServiceRequest{ 122 Uri: "http://localhost:4444/test-url", 123 Headers: map[string]string{"Content-Type": "json"}, 124 ResponseType: reflect.TypeOf(testItem{}), 125 }, 126 }, core) 127 128 wg.Wait() 129 130 assert.Equal(t, lastHttpRequest.Header.Get("Content-Type"), "json") 131 assert.Equal(t, lastResponse.Payload, testItem{Name: "test-name", Count: 2}) 132 133 wg.Add(1) 134 restService.HandleServiceRequest(&model.Request{ 135 Payload: &RestServiceRequest{ 136 Uri: "http://localhost:4444/test-url", 137 ResponseType: reflect.TypeOf(&testItem{}), 138 }, 139 }, core) 140 141 wg.Wait() 142 143 assert.Equal(t, lastResponse.Payload, &testItem{Name: "test-name", Count: 2}) 144 145 wg.Add(1) 146 restService.HandleServiceRequest(&model.Request{ 147 Payload: &RestServiceRequest{ 148 Uri: "http://localhost:4444/test-url", 149 }, 150 }, core) 151 152 wg.Wait() 153 154 assert.Equal(t, lastResponse.Payload, map[string]interface{}{"name": "test-name", "count": float64(2)}) 155 156 restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { 157 lastHttpRequest = req 158 return &http.Response{ 159 StatusCode: 200, 160 Body: ioutil.NopCloser(bytes.NewBuffer([]byte{1, 2, 3, 4, 5})), 161 Header: make(http.Header), 162 }, nil 163 }) 164 165 wg.Add(1) 166 restService.HandleServiceRequest(&model.Request{ 167 Payload: &RestServiceRequest{ 168 Uri: "http://localhost:4444/test-url", 169 ResponseType: reflect.TypeOf([]byte{}), 170 }, 171 }, core) 172 173 wg.Wait() 174 175 assert.Equal(t, lastResponse.Payload, []byte{1, 2, 3, 4, 5}) 176 } 177 178 func TestRestService_HandleJavaServiceRequest(t *testing.T) { 179 core := newTestFabricCore(restServiceChannel) 180 181 wg := sync.WaitGroup{} 182 183 restService := &restService{} 184 var lastHttpRequest *http.Request 185 restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { 186 lastHttpRequest = req 187 defer wg.Done() 188 return &http.Response{ 189 StatusCode: 200, 190 Body: ioutil.NopCloser(bytes.NewBufferString("test-response-body")), 191 Header: make(http.Header), 192 }, nil 193 }) 194 195 wg.Add(1) 196 restService.HandleServiceRequest(&model.Request{ 197 Payload: map[string]interface{}{ 198 "uri": "http://localhost:4444/test-url", 199 "headers": map[string]string{"header1": "value1", "header2": "value2"}, 200 "method": "UPDATE", 201 "Body": "test-body", 202 "apiClass": "java.lang.String", 203 }, 204 }, core) 205 206 wg.Wait() 207 208 assert.NotNil(t, lastHttpRequest) 209 assert.Equal(t, lastHttpRequest.URL.String(), "http://localhost:4444/test-url") 210 assert.Equal(t, lastHttpRequest.Method, "UPDATE") 211 assert.Equal(t, lastHttpRequest.Header.Get("header1"), "value1") 212 assert.Equal(t, lastHttpRequest.Header.Get("header2"), "value2") 213 assert.Equal(t, lastHttpRequest.Header.Get("Content-Type"), "application/merge-patch+json") 214 sentBody, _ := ioutil.ReadAll(lastHttpRequest.Body) 215 assert.Equal(t, sentBody, []byte("test-body")) 216 } 217 218 func TestRestService_HandleServiceRequest_InvalidInput(t *testing.T) { 219 core := newTestFabricCore(restServiceChannel) 220 221 restService := &restService{} 222 var lastResponse *model.Response 223 224 wg := sync.WaitGroup{} 225 wg.Add(1) 226 mh, _ := core.Bus().ListenStream(restServiceChannel) 227 mh.Handle( 228 func(message *model.Message) { 229 lastResponse = message.Payload.(*model.Response) 230 wg.Done() 231 }, 232 func(e error) { 233 assert.Fail(t, "unexpected error") 234 }) 235 236 restService.HandleServiceRequest(&model.Request{ 237 Payload: RestServiceRequest{ 238 Uri: "http://localhost:4444/test-url", 239 Method: "UPDATE", 240 }, 241 }, core) 242 243 wg.Wait() 244 245 assert.NotNil(t, lastResponse) 246 assert.True(t, lastResponse.Error) 247 assert.Equal(t, lastResponse.ErrorCode, 500) 248 assert.Equal(t, lastResponse.ErrorMessage, "invalid RestServiceRequest payload") 249 250 wg.Add(1) 251 252 restService.HandleServiceRequest(&model.Request{ 253 Payload: &RestServiceRequest{ 254 Uri: "http://localhost:4444/test-url", 255 Method: "@!#$%^&**()", 256 }, 257 }, core) 258 259 wg.Wait() 260 assert.True(t, lastResponse.Error) 261 assert.Equal(t, lastResponse.ErrorCode, 500) 262 263 restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { 264 return nil, errors.New("custom-rest-error") 265 }) 266 267 wg.Add(1) 268 restService.HandleServiceRequest(&model.Request{ 269 Payload: &RestServiceRequest{ 270 Uri: "http://localhost:4444/test-url", 271 }, 272 }, core) 273 wg.Wait() 274 275 assert.True(t, lastResponse.Error) 276 assert.Equal(t, lastResponse.ErrorCode, 500) 277 assert.True(t, strings.Contains(lastResponse.ErrorMessage, "custom-rest-error")) 278 279 restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { 280 return &http.Response{ 281 StatusCode: 404, 282 Status: "404 Not Found", 283 Body: ioutil.NopCloser(bytes.NewBufferString("error-response")), 284 Header: make(http.Header), 285 }, nil 286 }) 287 288 wg.Add(1) 289 restService.HandleServiceRequest(&model.Request{ 290 Payload: &RestServiceRequest{ 291 Uri: "http://localhost:4444/test-url", 292 }, 293 }, core) 294 wg.Wait() 295 296 assert.True(t, lastResponse.Error) 297 assert.Equal(t, lastResponse.ErrorCode, 404) 298 assert.Equal(t, lastResponse.ErrorMessage, "rest-service error, unable to complete request: 404 Not Found") 299 300 restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { 301 return &http.Response{ 302 StatusCode: 200, 303 Body: ioutil.NopCloser(bytes.NewBufferString("}")), 304 Header: make(http.Header), 305 }, nil 306 }) 307 308 wg.Add(1) 309 restService.HandleServiceRequest(&model.Request{ 310 Payload: &RestServiceRequest{ 311 Uri: "http://localhost:4444/test-url", 312 ResponseType: reflect.TypeOf(&testItem{}), 313 }, 314 }, core) 315 wg.Wait() 316 317 assert.True(t, lastResponse.Error) 318 assert.Equal(t, lastResponse.ErrorCode, 500) 319 assert.True(t, strings.HasPrefix(lastResponse.ErrorMessage, "failed to deserialize response:")) 320 } 321 322 func TestRestService_setBaseHost(t *testing.T) { 323 core := newTestFabricCore(restServiceChannel) 324 restService := &restService{} 325 326 wg := sync.WaitGroup{} 327 328 var lastHttpRequest *http.Request 329 restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { 330 lastHttpRequest = req 331 wg.Done() 332 return &http.Response{ 333 StatusCode: 200, 334 Body: ioutil.NopCloser(bytes.NewBufferString("test-response-body")), 335 Header: make(http.Header), 336 }, nil 337 }) 338 339 restService.setBaseHost("appfabric.vmware.com:9999") 340 341 wg.Add(1) 342 restService.HandleServiceRequest(&model.Request{ 343 Payload: &RestServiceRequest{ 344 Uri: "http://localhost:4444/test-url", 345 }, 346 }, core) 347 348 wg.Wait() 349 350 assert.Equal(t, lastHttpRequest.Host, "appfabric.vmware.com:9999") 351 352 restService.setBaseHost("") 353 354 wg.Add(1) 355 restService.HandleServiceRequest(&model.Request{ 356 Payload: &RestServiceRequest{ 357 Uri: "http://localhost:4444/test-url", 358 }, 359 }, core) 360 wg.Wait() 361 362 assert.Equal(t, lastHttpRequest.Host, "localhost:4444") 363 }