github.com/Jeffail/benthos/v3@v3.65.0/lib/input/http_server_test.go (about) 1 package input_test 2 3 import ( 4 "bytes" 5 "context" 6 "errors" 7 "fmt" 8 "io" 9 "mime" 10 "mime/multipart" 11 "net" 12 "net/http" 13 "net/http/httptest" 14 "net/textproto" 15 "net/url" 16 "sync" 17 "testing" 18 "time" 19 20 "github.com/Jeffail/benthos/v3/lib/api" 21 "github.com/Jeffail/benthos/v3/lib/input" 22 "github.com/Jeffail/benthos/v3/lib/log" 23 "github.com/Jeffail/benthos/v3/lib/manager" 24 "github.com/Jeffail/benthos/v3/lib/message" 25 "github.com/Jeffail/benthos/v3/lib/message/roundtrip" 26 "github.com/Jeffail/benthos/v3/lib/metrics" 27 "github.com/Jeffail/benthos/v3/lib/ratelimit" 28 "github.com/Jeffail/benthos/v3/lib/response" 29 "github.com/Jeffail/benthos/v3/lib/types" 30 "github.com/gorilla/mux" 31 "github.com/gorilla/websocket" 32 "github.com/stretchr/testify/assert" 33 "github.com/stretchr/testify/require" 34 35 _ "github.com/Jeffail/benthos/v3/public/components/all" 36 ) 37 38 /* 39 type apiRegGorillaMutWrapper struct { 40 mut *http.ServeMux 41 } 42 43 func (a apiRegGorillaMutWrapper) RegisterEndpoint(path, desc string, h http.HandlerFunc) { 44 a.mut.HandleFunc(path, h) 45 } 46 */ 47 48 type apiRegGorillaMutWrapper struct { 49 mut *mux.Router 50 } 51 52 func (a apiRegGorillaMutWrapper) RegisterEndpoint(path, desc string, h http.HandlerFunc) { 53 a.mut.HandleFunc(path, h) 54 } 55 56 func TestHTTPBasic(t *testing.T) { 57 t.Parallel() 58 59 nTestLoops := 100 60 61 reg := apiRegGorillaMutWrapper{mut: mux.NewRouter()} 62 mgr, err := manager.New(manager.NewConfig(), reg, log.Noop(), metrics.Noop()) 63 if err != nil { 64 t.Fatal(err) 65 } 66 67 conf := input.NewConfig() 68 conf.HTTPServer.Path = "/testpost" 69 70 h, err := input.NewHTTPServer(conf, mgr, log.Noop(), metrics.Noop()) 71 if err != nil { 72 t.Fatal(err) 73 } 74 75 server := httptest.NewServer(reg.mut) 76 defer server.Close() 77 78 // Test both single and multipart messages. 79 for i := 0; i < nTestLoops; i++ { 80 testStr := fmt.Sprintf("test%v", i) 81 testResponse := fmt.Sprintf("response%v", i) 82 // Send it as single part 83 go func(input, output string) { 84 res, err := http.Post( 85 server.URL+"/testpost", 86 "application/octet-stream", 87 bytes.NewBuffer([]byte(input)), 88 ) 89 if err != nil { 90 t.Error(err) 91 } else if res.StatusCode != 200 { 92 t.Errorf("Wrong error code returned: %v", res.StatusCode) 93 } 94 resBytes, err := io.ReadAll(res.Body) 95 if err != nil { 96 t.Error(err) 97 } 98 if exp, act := output, string(resBytes); exp != act { 99 t.Errorf("Wrong sync response: %v != %v", act, exp) 100 } 101 }(testStr, testResponse) 102 103 var ts types.Transaction 104 select { 105 case ts = <-h.TransactionChan(): 106 if res := string(ts.Payload.Get(0).Get()); res != testStr { 107 t.Errorf("Wrong result, %v != %v", ts.Payload, res) 108 } 109 ts.Payload.Get(0).Set([]byte(testResponse)) 110 roundtrip.SetAsResponse(ts.Payload) 111 case <-time.After(time.Second): 112 t.Error("Timed out waiting for message") 113 } 114 select { 115 case ts.ResponseChan <- response.NewAck(): 116 case <-time.After(time.Second): 117 t.Error("Timed out waiting for response") 118 } 119 } 120 121 // Test MIME multipart parsing, as defined in RFC 2046 122 for i := 0; i < nTestLoops; i++ { 123 partOne := fmt.Sprintf("test%v part one", i) 124 partTwo := fmt.Sprintf("test%v part two", i) 125 126 testStr := fmt.Sprintf( 127 "--foo\r\n"+ 128 "Content-Type: application/octet-stream\r\n\r\n"+ 129 "%v\r\n"+ 130 "--foo\r\n"+ 131 "Content-Type: application/octet-stream\r\n\r\n"+ 132 "%v\r\n"+ 133 "--foo--\r\n", 134 partOne, partTwo) 135 136 // Send it as multi part 137 go func() { 138 if res, err := http.Post( 139 server.URL+"/testpost", 140 "multipart/mixed; boundary=foo", 141 bytes.NewBuffer([]byte(testStr)), 142 ); err != nil { 143 t.Error(err) 144 } else if res.StatusCode != 200 { 145 t.Errorf("Wrong error code returned: %v", res.StatusCode) 146 } 147 }() 148 149 var ts types.Transaction 150 select { 151 case ts = <-h.TransactionChan(): 152 if exp, actual := 2, ts.Payload.Len(); exp != actual { 153 t.Errorf("Wrong number of parts: %v != %v", actual, exp) 154 } else if exp, actual := partOne, string(ts.Payload.Get(0).Get()); exp != actual { 155 t.Errorf("Wrong result, %v != %v", actual, exp) 156 } else if exp, actual := partTwo, string(ts.Payload.Get(1).Get()); exp != actual { 157 t.Errorf("Wrong result, %v != %v", actual, exp) 158 } 159 case <-time.After(time.Second): 160 t.Error("Timed out waiting for message") 161 } 162 select { 163 case ts.ResponseChan <- response.NewAck(): 164 case <-time.After(time.Second): 165 t.Error("Timed out waiting for response") 166 } 167 } 168 169 // Test requests without content-type 170 client := &http.Client{} 171 172 for i := 0; i < nTestLoops; i++ { 173 testStr := fmt.Sprintf("test%v", i) 174 testResponse := fmt.Sprintf("response%v", i) 175 // Send it as single part 176 go func(input, output string) { 177 req, err := http.NewRequest( 178 "POST", server.URL+"/testpost", bytes.NewBuffer([]byte(input))) 179 if err != nil { 180 t.Error(err) 181 } 182 res, err := client.Do(req) 183 if err != nil { 184 t.Error(err) 185 } else if res.StatusCode != 200 { 186 t.Errorf("Wrong error code returned: %v", res.StatusCode) 187 } 188 resBytes, err := io.ReadAll(res.Body) 189 if err != nil { 190 t.Error(err) 191 } 192 if exp, act := output, string(resBytes); exp != act { 193 t.Errorf("Wrong sync response: %v != %v", act, exp) 194 } 195 }(testStr, testResponse) 196 197 var ts types.Transaction 198 select { 199 case ts = <-h.TransactionChan(): 200 if res := string(ts.Payload.Get(0).Get()); res != testStr { 201 t.Errorf("Wrong result, %v != %v", ts.Payload, res) 202 } 203 ts.Payload.Get(0).Set([]byte(testResponse)) 204 roundtrip.SetAsResponse(ts.Payload) 205 case <-time.After(time.Second): 206 t.Error("Timed out waiting for message") 207 } 208 select { 209 case ts.ResponseChan <- response.NewAck(): 210 case <-time.After(time.Second): 211 t.Error("Timed out waiting for response") 212 } 213 } 214 215 h.CloseAsync() 216 } 217 218 func getFreePort() (int, error) { 219 addr, err := net.ResolveTCPAddr("tcp", "localhost:0") 220 if err != nil { 221 return 0, err 222 } 223 224 listener, err := net.ListenTCP("tcp", addr) 225 if err != nil { 226 return 0, err 227 } 228 defer listener.Close() 229 return listener.Addr().(*net.TCPAddr).Port, nil 230 } 231 232 func TestHTTPServerLifecycle(t *testing.T) { 233 freePort, err := getFreePort() 234 require.NoError(t, err) 235 236 apiConf := api.NewConfig() 237 apiConf.Address = fmt.Sprintf("0.0.0.0:%v", freePort) 238 apiConf.Enabled = true 239 240 testURL := fmt.Sprintf("http://localhost:%v/foo/bar", freePort) 241 242 apiImpl, err := api.New("", "", apiConf, nil, log.Noop(), metrics.Noop()) 243 require.NoError(t, err) 244 245 go func() { 246 _ = apiImpl.ListenAndServe() 247 }() 248 defer apiImpl.Shutdown(context.Background()) 249 250 mgr, err := manager.New(manager.NewConfig(), apiImpl, log.Noop(), metrics.Noop()) 251 require.NoError(t, err) 252 253 conf := input.NewConfig() 254 conf.HTTPServer.Path = "/foo/bar" 255 256 timeout := time.Second * 5 257 readNextMsg := func(in input.Type) (types.Message, error) { 258 t.Helper() 259 var tran types.Transaction 260 select { 261 case tran = <-in.TransactionChan(): 262 select { 263 case tran.ResponseChan <- response.NewAck(): 264 case <-time.After(timeout): 265 return nil, errors.New("timed out 1") 266 } 267 case <-time.After(timeout): 268 return nil, errors.New("timed out 2") 269 } 270 return tran.Payload, nil 271 } 272 273 server, err := input.NewHTTPServer(conf, mgr, log.Noop(), metrics.Noop()) 274 require.NoError(t, err) 275 276 dummyData := []byte("a bunch of jolly leprechauns await") 277 go func() { 278 resp, cerr := http.Post(testURL, "text/plain", bytes.NewReader(dummyData)) 279 if assert.NoError(t, cerr) { 280 resp.Body.Close() 281 } 282 }() 283 284 msg, err := readNextMsg(server) 285 require.NoError(t, err) 286 assert.Equal(t, dummyData, message.GetAllBytes(msg)[0]) 287 288 server.CloseAsync() 289 assert.NoError(t, server.WaitForClose(time.Second)) 290 291 res, err := http.Post(testURL, "text/plain", bytes.NewReader(dummyData)) 292 assert.NoError(t, err) 293 assert.Equal(t, 404, res.StatusCode) 294 295 serverTwo, err := input.NewHTTPServer(conf, mgr, log.Noop(), metrics.Noop()) 296 require.NoError(t, err) 297 298 go func() { 299 resp, cerr := http.Post(testURL, "text/plain", bytes.NewReader(dummyData)) 300 if assert.NoError(t, cerr) { 301 resp.Body.Close() 302 } 303 }() 304 305 msg, err = readNextMsg(serverTwo) 306 require.NoError(t, err) 307 assert.Equal(t, dummyData, message.GetAllBytes(msg)[0]) 308 309 serverTwo.CloseAsync() 310 assert.NoError(t, serverTwo.WaitForClose(time.Second)) 311 } 312 313 func TestHTTPServerMetadata(t *testing.T) { 314 reg := apiRegGorillaMutWrapper{mut: mux.NewRouter()} 315 mgr, err := manager.New(manager.NewConfig(), reg, log.Noop(), metrics.Noop()) 316 require.NoError(t, err) 317 318 conf := input.NewConfig() 319 conf.HTTPServer.Path = "/across/the/rainbow/bridge" 320 321 server, err := input.NewHTTPServer(conf, mgr, log.Noop(), metrics.Noop()) 322 require.NoError(t, err) 323 324 defer func() { 325 server.CloseAsync() 326 assert.NoError(t, server.WaitForClose(time.Second)) 327 }() 328 329 testServer := httptest.NewServer(reg.mut) 330 defer testServer.Close() 331 332 dummyPath := "/across/the/rainbow/bridge" 333 dummyQuery := url.Values{"foo": []string{"bar"}} 334 serverURL, err := url.Parse(testServer.URL) 335 require.NoError(t, err) 336 337 serverURL.Path = dummyPath 338 serverURL.RawQuery = dummyQuery.Encode() 339 340 dummyData := []byte("a bunch of jolly leprechauns await") 341 go func() { 342 resp, cerr := http.Post(serverURL.String(), "text/plain", bytes.NewReader(dummyData)) 343 require.NoError(t, cerr) 344 defer resp.Body.Close() 345 }() 346 347 timeout := time.Second * 5 348 349 readNextMsg := func() (types.Message, error) { 350 var tran types.Transaction 351 select { 352 case tran = <-server.TransactionChan(): 353 select { 354 case tran.ResponseChan <- response.NewAck(): 355 case <-time.After(timeout): 356 return nil, errors.New("timed out 1") 357 } 358 case <-time.After(timeout): 359 return nil, errors.New("timed out 2") 360 } 361 return tran.Payload, nil 362 } 363 364 msg, err := readNextMsg() 365 require.NoError(t, err) 366 assert.Equal(t, dummyData, message.GetAllBytes(msg)[0]) 367 368 meta := msg.Get(0).Metadata() 369 assert.Equal(t, dummyPath, meta.Get("http_server_request_path")) 370 assert.Equal(t, "POST", meta.Get("http_server_verb")) 371 assert.Regexp(t, "^Go-http-client/", meta.Get("http_server_user_agent")) 372 // Make sure query params are set in the metadata 373 assert.Contains(t, "bar", meta.Get("foo")) 374 } 375 376 func TestHTTPtServerPathParameters(t *testing.T) { 377 reg := apiRegGorillaMutWrapper{mut: mux.NewRouter()} 378 mgr, err := manager.New(manager.NewConfig(), reg, log.Noop(), metrics.Noop()) 379 require.NoError(t, err) 380 381 conf := input.NewConfig() 382 conf.HTTPServer.Path = "/test/{foo}/{bar}" 383 conf.HTTPServer.AllowedVerbs = append(conf.HTTPServer.AllowedVerbs, "PUT") 384 385 server, err := input.NewHTTPServer(conf, mgr, log.Noop(), metrics.Noop()) 386 require.NoError(t, err) 387 388 defer func() { 389 server.CloseAsync() 390 assert.NoError(t, server.WaitForClose(time.Second)) 391 }() 392 393 testServer := httptest.NewServer(reg.mut) 394 defer testServer.Close() 395 396 dummyPath := "/test/foo1/bar1" 397 dummyQuery := url.Values{"mylove": []string{"will go on"}} 398 serverURL, err := url.Parse(testServer.URL) 399 require.NoError(t, err) 400 401 serverURL.Path = dummyPath 402 serverURL.RawQuery = dummyQuery.Encode() 403 404 dummyData := []byte("a bunch of jolly leprechauns await") 405 go func() { 406 req, cerr := http.NewRequest("PUT", serverURL.String(), bytes.NewReader(dummyData)) 407 require.NoError(t, cerr) 408 req.Header.Set("Content-Type", "text/plain") 409 resp, cerr := http.DefaultClient.Do(req) 410 require.NoError(t, cerr) 411 defer resp.Body.Close() 412 }() 413 414 readNextMsg := func() (types.Message, error) { 415 var tran types.Transaction 416 select { 417 case tran = <-server.TransactionChan(): 418 select { 419 case tran.ResponseChan <- response.NewAck(): 420 case <-time.After(time.Second): 421 return nil, errors.New("timed out") 422 } 423 case <-time.After(time.Second): 424 return nil, errors.New("timed out") 425 } 426 return tran.Payload, nil 427 } 428 429 msg, err := readNextMsg() 430 require.NoError(t, err) 431 assert.Equal(t, dummyData, message.GetAllBytes(msg)[0]) 432 433 meta := msg.Get(0).Metadata() 434 435 assert.Equal(t, dummyPath, meta.Get("http_server_request_path")) 436 assert.Equal(t, "PUT", meta.Get("http_server_verb")) 437 assert.Equal(t, "foo1", meta.Get("foo")) 438 assert.Equal(t, "bar1", meta.Get("bar")) 439 assert.Equal(t, "will go on", meta.Get("mylove")) 440 } 441 442 func TestHTTPBadRequests(t *testing.T) { 443 t.Parallel() 444 445 reg := apiRegGorillaMutWrapper{mut: mux.NewRouter()} 446 mgr, err := manager.New(manager.NewConfig(), reg, log.Noop(), metrics.Noop()) 447 if err != nil { 448 t.Fatal(err) 449 } 450 451 conf := input.NewConfig() 452 conf.HTTPServer.Path = "/testpost" 453 454 h, err := input.NewHTTPServer(conf, mgr, log.Noop(), metrics.Noop()) 455 if err != nil { 456 t.Fatal(err) 457 } 458 459 server := httptest.NewServer(reg.mut) 460 defer server.Close() 461 462 res, err := http.Get(server.URL + "/testpost") 463 if err != nil { 464 t.Error(err) 465 return 466 } 467 if exp, act := http.StatusMethodNotAllowed, res.StatusCode; exp != act { 468 t.Errorf("unexpected HTTP response code: %v != %v", exp, act) 469 } 470 471 h.CloseAsync() 472 if err := h.WaitForClose(time.Second * 5); err != nil { 473 t.Error(err) 474 } 475 } 476 477 func TestHTTPTimeout(t *testing.T) { 478 t.Parallel() 479 480 reg := apiRegGorillaMutWrapper{mut: mux.NewRouter()} 481 mgr, err := manager.New(manager.NewConfig(), reg, log.Noop(), metrics.Noop()) 482 if err != nil { 483 t.Fatal(err) 484 } 485 486 conf := input.NewConfig() 487 conf.HTTPServer.Path = "/testpost" 488 conf.HTTPServer.Timeout = "1ms" 489 490 h, err := input.NewHTTPServer(conf, mgr, log.Noop(), metrics.Noop()) 491 if err != nil { 492 t.Fatal(err) 493 } 494 495 server := httptest.NewServer(reg.mut) 496 defer server.Close() 497 498 var res *http.Response 499 res, err = http.Post( 500 server.URL+"/testpost", 501 "application/octet-stream", 502 bytes.NewBuffer([]byte("hello world")), 503 ) 504 if err != nil { 505 t.Fatal(err) 506 } 507 if exp, act := http.StatusRequestTimeout, res.StatusCode; exp != act { 508 t.Errorf("Unexpected status code: %v != %v", exp, act) 509 } 510 511 h.CloseAsync() 512 if err := h.WaitForClose(time.Second * 5); err != nil { 513 t.Error(err) 514 } 515 } 516 517 func TestHTTPRateLimit(t *testing.T) { 518 t.Parallel() 519 520 reg := apiRegGorillaMutWrapper{mut: mux.NewRouter()} 521 522 rlConf := ratelimit.NewConfig() 523 rlConf.Type = ratelimit.TypeLocal 524 rlConf.Local.Count = 1 525 rlConf.Local.Interval = "60s" 526 527 mgrConf := manager.NewConfig() 528 mgrConf.RateLimits["foorl"] = rlConf 529 mgr, err := manager.New(mgrConf, reg, log.Noop(), metrics.Noop()) 530 if err != nil { 531 t.Fatal(err) 532 } 533 534 conf := input.NewConfig() 535 conf.HTTPServer.Path = "/testpost" 536 conf.HTTPServer.RateLimit = "foorl" 537 538 h, err := input.NewHTTPServer(conf, mgr, log.Noop(), metrics.Noop()) 539 if err != nil { 540 t.Fatal(err) 541 } 542 543 server := httptest.NewServer(reg.mut) 544 defer server.Close() 545 546 go func() { 547 var ts types.Transaction 548 select { 549 case ts = <-h.TransactionChan(): 550 case <-time.After(time.Second): 551 t.Error("Timed out waiting for message") 552 } 553 select { 554 case ts.ResponseChan <- response.NewAck(): 555 case <-time.After(time.Second): 556 t.Error("Timed out waiting for response") 557 } 558 }() 559 560 var res *http.Response 561 res, err = http.Post( 562 server.URL+"/testpost", 563 "application/octet-stream", 564 bytes.NewBuffer([]byte("hello world")), 565 ) 566 if err != nil { 567 t.Fatal(err) 568 } 569 if exp, act := http.StatusOK, res.StatusCode; exp != act { 570 t.Errorf("Unexpected status code: %v != %v", exp, act) 571 } 572 573 res, err = http.Post( 574 server.URL+"/testpost", 575 "application/octet-stream", 576 bytes.NewBuffer([]byte("hello world")), 577 ) 578 if err != nil { 579 t.Fatal(err) 580 } 581 if exp, act := http.StatusTooManyRequests, res.StatusCode; exp != act { 582 t.Errorf("Unexpected status code: %v != %v", exp, act) 583 } 584 585 h.CloseAsync() 586 if err := h.WaitForClose(time.Second * 5); err != nil { 587 t.Error(err) 588 } 589 } 590 591 func TestHTTPServerWebsockets(t *testing.T) { 592 t.Parallel() 593 594 reg := apiRegGorillaMutWrapper{mut: mux.NewRouter()} 595 596 mgr, err := manager.New(manager.NewConfig(), reg, log.Noop(), metrics.Noop()) 597 if err != nil { 598 t.Fatal(err) 599 } 600 601 conf := input.NewConfig() 602 conf.HTTPServer.WSPath = "/testws" 603 604 h, err := input.NewHTTPServer(conf, mgr, log.Noop(), metrics.Noop()) 605 if err != nil { 606 t.Fatal(err) 607 } 608 609 server := httptest.NewServer(reg.mut) 610 defer server.Close() 611 612 purl, err := url.Parse(server.URL + "/testws") 613 if err != nil { 614 t.Fatal(err) 615 } 616 purl.Scheme = "ws" 617 618 var client *websocket.Conn 619 if client, _, err = websocket.DefaultDialer.Dial(purl.String(), http.Header{}); err != nil { 620 t.Fatal(err) 621 } 622 623 wg := sync.WaitGroup{} 624 wg.Add(1) 625 go func() { 626 if clientErr := client.WriteMessage( 627 websocket.BinaryMessage, []byte("hello world 1"), 628 ); clientErr != nil { 629 t.Error(clientErr) 630 } 631 wg.Done() 632 }() 633 634 var ts types.Transaction 635 select { 636 case ts = <-h.TransactionChan(): 637 case <-time.After(time.Second): 638 t.Error("Timed out waiting for message") 639 } 640 if exp, act := `[hello world 1]`, fmt.Sprintf("%s", message.GetAllBytes(ts.Payload)); exp != act { 641 t.Errorf("Unexpected message: %v != %v", act, exp) 642 } 643 select { 644 case ts.ResponseChan <- response.NewAck(): 645 case <-time.After(time.Second): 646 t.Error("Timed out waiting for response") 647 } 648 wg.Wait() 649 650 wg.Add(1) 651 go func() { 652 if closeErr := client.WriteMessage( 653 websocket.BinaryMessage, []byte("hello world 2"), 654 ); closeErr != nil { 655 t.Error(closeErr) 656 } 657 wg.Done() 658 }() 659 660 select { 661 case ts = <-h.TransactionChan(): 662 case <-time.After(time.Second): 663 t.Error("Timed out waiting for message") 664 } 665 if exp, act := `[hello world 2]`, fmt.Sprintf("%s", message.GetAllBytes(ts.Payload)); exp != act { 666 t.Errorf("Unexpected message: %v != %v", act, exp) 667 } 668 select { 669 case ts.ResponseChan <- response.NewAck(): 670 case <-time.After(time.Second): 671 t.Error("Timed out waiting for response") 672 } 673 wg.Wait() 674 675 h.CloseAsync() 676 if err := h.WaitForClose(time.Second * 5); err != nil { 677 t.Error(err) 678 } 679 } 680 681 func TestHTTPServerWSRateLimit(t *testing.T) { 682 t.Parallel() 683 684 reg := apiRegGorillaMutWrapper{mut: mux.NewRouter()} 685 686 rlConf := ratelimit.NewConfig() 687 rlConf.Type = ratelimit.TypeLocal 688 rlConf.Local.Count = 1 689 rlConf.Local.Interval = "60s" 690 691 mgrConf := manager.NewConfig() 692 mgrConf.RateLimits["foorl"] = rlConf 693 mgr, err := manager.New(mgrConf, reg, log.Noop(), metrics.Noop()) 694 if err != nil { 695 t.Fatal(err) 696 } 697 698 conf := input.NewConfig() 699 conf.HTTPServer.WSPath = "/testws" 700 conf.HTTPServer.WSWelcomeMessage = "test welcome" 701 conf.HTTPServer.WSRateLimitMessage = "test rate limited" 702 conf.HTTPServer.RateLimit = "foorl" 703 704 h, err := input.NewHTTPServer(conf, mgr, log.Noop(), metrics.Noop()) 705 if err != nil { 706 t.Fatal(err) 707 } 708 709 server := httptest.NewServer(reg.mut) 710 defer server.Close() 711 712 purl, err := url.Parse(server.URL + "/testws") 713 if err != nil { 714 t.Fatal(err) 715 } 716 purl.Scheme = "ws" 717 718 var client *websocket.Conn 719 if client, _, err = websocket.DefaultDialer.Dial(purl.String(), http.Header{}); err != nil { 720 t.Fatal(err) 721 } 722 723 go func() { 724 var ts types.Transaction 725 select { 726 case ts = <-h.TransactionChan(): 727 case <-time.After(time.Second): 728 t.Error("Timed out waiting for message") 729 } 730 select { 731 case ts.ResponseChan <- response.NewAck(): 732 case <-time.After(time.Second): 733 t.Error("Timed out waiting for response") 734 } 735 }() 736 737 var msgBytes []byte 738 if _, msgBytes, err = client.ReadMessage(); err != nil { 739 t.Fatal(err) 740 } 741 if exp, act := "test welcome", string(msgBytes); exp != act { 742 t.Errorf("Unexpected welcome message: %v != %v", act, exp) 743 } 744 745 if err = client.WriteMessage( 746 websocket.BinaryMessage, []byte("hello world"), 747 ); err != nil { 748 t.Fatal(err) 749 } 750 751 if err = client.WriteMessage( 752 websocket.BinaryMessage, []byte("hello world"), 753 ); err != nil { 754 t.Fatal(err) 755 } 756 757 if _, msgBytes, err = client.ReadMessage(); err != nil { 758 t.Fatal(err) 759 } 760 if exp, act := "test rate limited", string(msgBytes); exp != act { 761 t.Errorf("Unexpected rate limit message: %v != %v", act, exp) 762 } 763 764 h.CloseAsync() 765 if err := h.WaitForClose(time.Second * 5); err != nil { 766 t.Error(err) 767 } 768 } 769 770 func TestHTTPSyncResponseHeaders(t *testing.T) { 771 t.Parallel() 772 773 reg := apiRegGorillaMutWrapper{mut: mux.NewRouter()} 774 mgr, err := manager.New(manager.NewConfig(), reg, log.Noop(), metrics.Noop()) 775 if err != nil { 776 t.Fatal(err) 777 } 778 779 conf := input.NewConfig() 780 conf.HTTPServer.Path = "/testpost" 781 conf.HTTPServer.Response.Headers["Content-Type"] = "application/json" 782 conf.HTTPServer.Response.Headers["foo"] = `${!json("field1")}` 783 conf.HTTPServer.Response.ExtractMetadata.IncludePrefixes = []string{"Loca"} 784 conf.HTTPServer.Response.ExtractMetadata.IncludePatterns = []string{"name"} 785 786 h, err := input.NewHTTPServer(conf, mgr, log.Noop(), metrics.Noop()) 787 if err != nil { 788 t.Fatal(err) 789 } 790 791 server := httptest.NewServer(reg.mut) 792 defer server.Close() 793 794 input := `{"foo":"test message","field1":"bar"}` 795 796 wg := sync.WaitGroup{} 797 wg.Add(1) 798 go func() { 799 defer wg.Done() 800 801 req, err := http.NewRequest(http.MethodPost, server.URL+"/testpost", bytes.NewBuffer([]byte(input))) 802 if err != nil { 803 t.Error(err) 804 } 805 req.Header.Set("Content-Type", "application/octet-stream") 806 req.Header.Set("Location", "Asgard") 807 req.Header.Set("Username", "Thor") 808 req.Header.Set("Language", "Norse") 809 res, err := http.DefaultClient.Do(req) 810 if err != nil { 811 t.Error(err) 812 } else if res.StatusCode != 200 { 813 t.Errorf("Wrong error code returned: %v", res.StatusCode) 814 } 815 resBytes, err := io.ReadAll(res.Body) 816 if err != nil { 817 t.Error(err) 818 } 819 if exp, act := input, string(resBytes); exp != act { 820 t.Errorf("Wrong sync response: %v != %v", act, exp) 821 } 822 if exp, act := "application/json", res.Header.Get("Content-Type"); exp != act { 823 t.Errorf("Wrong sync response header: %v != %v", act, exp) 824 } 825 if exp, act := "bar", res.Header.Get("foo"); exp != act { 826 t.Errorf("Wrong sync response header: %v != %v", act, exp) 827 } 828 if exp, act := "Asgard", res.Header.Get("Location"); exp != act { 829 t.Errorf("Wrong sync response header: %v != %v", act, exp) 830 } 831 if exp, act := "Thor", res.Header.Get("Username"); exp != act { 832 t.Errorf("Wrong sync response header: %v != %v", act, exp) 833 } 834 if exp, act := "", res.Header.Get("Language"); exp != act { 835 t.Errorf("Wrong sync response header: %v != %v", act, exp) 836 } 837 }() 838 839 var ts types.Transaction 840 select { 841 case ts = <-h.TransactionChan(): 842 if res := string(ts.Payload.Get(0).Get()); res != input { 843 t.Errorf("Wrong result, %v != %v", ts.Payload, res) 844 } 845 roundtrip.SetAsResponse(ts.Payload) 846 case <-time.After(time.Second): 847 t.Fatal("Timed out waiting for message") 848 } 849 select { 850 case ts.ResponseChan <- response.NewAck(): 851 case <-time.After(time.Second): 852 t.Error("Timed out waiting for response") 853 } 854 855 h.CloseAsync() 856 if err := h.WaitForClose(time.Second * 5); err != nil { 857 t.Error(err) 858 } 859 860 wg.Wait() 861 } 862 863 func createMultipart(payloads []string, contentType string) (hdr string, bodyBytes []byte, err error) { 864 body := &bytes.Buffer{} 865 writer := multipart.NewWriter(body) 866 867 for i := 0; i < len(payloads) && err == nil; i++ { 868 var part io.Writer 869 if part, err = writer.CreatePart(textproto.MIMEHeader{ 870 "Content-Type": []string{contentType}, 871 }); err == nil { 872 _, err = io.Copy(part, bytes.NewReader([]byte(payloads[i]))) 873 } 874 } 875 876 if err != nil { 877 return "", nil, err 878 } 879 880 writer.Close() 881 return writer.FormDataContentType(), body.Bytes(), nil 882 } 883 884 func readMultipart(res *http.Response) ([]string, error) { 885 var params map[string]string 886 var err error 887 if contentType := res.Header.Get("Content-Type"); len(contentType) > 0 { 888 if _, params, err = mime.ParseMediaType(contentType); err != nil { 889 return nil, err 890 } 891 } 892 893 var buffer bytes.Buffer 894 var output []string 895 896 mr := multipart.NewReader(res.Body, params["boundary"]) 897 var bufferIndex int64 898 for { 899 var p *multipart.Part 900 if p, err = mr.NextPart(); err != nil { 901 if err == io.EOF { 902 break 903 } 904 return nil, err 905 } 906 907 var bytesRead int64 908 if bytesRead, err = buffer.ReadFrom(p); err != nil { 909 return nil, err 910 } 911 912 output = append(output, string(buffer.Bytes()[bufferIndex:bufferIndex+bytesRead])) 913 bufferIndex += bytesRead 914 } 915 916 return output, nil 917 } 918 919 func TestHTTPSyncResponseMultipart(t *testing.T) { 920 t.Parallel() 921 922 reg := apiRegGorillaMutWrapper{mut: mux.NewRouter()} 923 mgr, err := manager.New(manager.NewConfig(), reg, log.Noop(), metrics.Noop()) 924 require.NoError(t, err) 925 926 conf := input.NewConfig() 927 conf.HTTPServer.Path = "/testpost" 928 conf.HTTPServer.Response.Headers["Content-Type"] = "application/json" 929 930 h, err := input.NewHTTPServer(conf, mgr, log.Noop(), metrics.Noop()) 931 require.NoError(t, err) 932 933 server := httptest.NewServer(reg.mut) 934 t.Cleanup(func() { 935 server.Close() 936 }) 937 938 input := []string{ 939 `{"foo":"test message 1","field1":"bar"}`, 940 `{"foo":"test message 2","field1":"baz"}`, 941 `{"foo":"test message 3","field1":"buz"}`, 942 } 943 output := []string{ 944 `{"foo":"test message 4","field1":"bar"}`, 945 `{"foo":"test message 5","field1":"baz"}`, 946 `{"foo":"test message 6","field1":"buz"}`, 947 } 948 949 wg := sync.WaitGroup{} 950 wg.Add(1) 951 go func() { 952 defer wg.Done() 953 954 hdr, body, err := createMultipart(input, "application/octet-stream") 955 require.NoError(t, err) 956 957 res, err := http.Post(server.URL+"/testpost", hdr, bytes.NewReader(body)) 958 require.NoError(t, err) 959 require.Equal(t, 200, res.StatusCode) 960 961 act, err := readMultipart(res) 962 require.NoError(t, err) 963 assert.Equal(t, output, act) 964 }() 965 966 var ts types.Transaction 967 select { 968 case ts = <-h.TransactionChan(): 969 for i, in := range input { 970 assert.Equal(t, in, string(ts.Payload.Get(i).Get())) 971 } 972 for i, o := range output { 973 ts.Payload.Get(i).Set([]byte(o)) 974 } 975 roundtrip.SetAsResponse(ts.Payload) 976 case <-time.After(time.Second): 977 t.Fatal("Timed out waiting for message") 978 } 979 select { 980 case ts.ResponseChan <- response.NewAck(): 981 case <-time.After(time.Second): 982 t.Error("Timed out waiting for response") 983 } 984 985 h.CloseAsync() 986 err = h.WaitForClose(time.Second * 5) 987 require.NoError(t, err) 988 989 wg.Wait() 990 } 991 992 func TestHTTPSyncResponseHeadersStatus(t *testing.T) { 993 t.Parallel() 994 995 reg := apiRegGorillaMutWrapper{mut: mux.NewRouter()} 996 mgr, err := manager.New(manager.NewConfig(), reg, log.Noop(), metrics.Noop()) 997 if err != nil { 998 t.Fatal(err) 999 } 1000 1001 conf := input.NewConfig() 1002 conf.HTTPServer.Path = "/testpost" 1003 conf.HTTPServer.Response.Status = `${! meta("status").or("200") }` 1004 conf.HTTPServer.Response.Headers["Content-Type"] = "application/json" 1005 conf.HTTPServer.Response.Headers["foo"] = `${!json("field1")}` 1006 1007 h, err := input.NewHTTPServer(conf, mgr, log.Noop(), metrics.Noop()) 1008 if err != nil { 1009 t.Fatal(err) 1010 } 1011 1012 server := httptest.NewServer(reg.mut) 1013 defer server.Close() 1014 1015 input := `{"foo":"test message","field1":"bar"}` 1016 1017 wg := sync.WaitGroup{} 1018 wg.Add(1) 1019 go func() { 1020 defer wg.Done() 1021 1022 res, err := http.Post( 1023 server.URL+"/testpost", 1024 "application/octet-stream", 1025 bytes.NewBuffer([]byte(input)), 1026 ) 1027 if err != nil { 1028 t.Error(err) 1029 } else if res.StatusCode != 200 { 1030 t.Errorf("Wrong error code returned: %v", res.StatusCode) 1031 } 1032 resBytes, err := io.ReadAll(res.Body) 1033 if err != nil { 1034 t.Error(err) 1035 } 1036 if exp, act := input, string(resBytes); exp != act { 1037 t.Errorf("Wrong sync response: %v != %v", act, exp) 1038 } 1039 if exp, act := "application/json", res.Header.Get("Content-Type"); exp != act { 1040 t.Errorf("Wrong sync response header: %v != %v", act, exp) 1041 } 1042 if exp, act := "bar", res.Header.Get("foo"); exp != act { 1043 t.Errorf("Wrong sync response header: %v != %v", act, exp) 1044 } 1045 1046 res, err = http.Post( 1047 server.URL+"/testpost", 1048 "application/octet-stream", 1049 bytes.NewBuffer([]byte(input)), 1050 ) 1051 if err != nil { 1052 t.Error(err) 1053 } else if res.StatusCode != 400 { 1054 t.Errorf("Wrong error code returned: %v", res.StatusCode) 1055 } 1056 resBytes, err = io.ReadAll(res.Body) 1057 if err != nil { 1058 t.Error(err) 1059 } 1060 if exp, act := input, string(resBytes); exp != act { 1061 t.Errorf("Wrong sync response: %v != %v", act, exp) 1062 } 1063 if exp, act := "application/json", res.Header.Get("Content-Type"); exp != act { 1064 t.Errorf("Wrong sync response header: %v != %v", act, exp) 1065 } 1066 if exp, act := "bar", res.Header.Get("foo"); exp != act { 1067 t.Errorf("Wrong sync response header: %v != %v", act, exp) 1068 } 1069 }() 1070 1071 // Non errored message 1072 var ts types.Transaction 1073 select { 1074 case ts = <-h.TransactionChan(): 1075 if res := string(ts.Payload.Get(0).Get()); res != input { 1076 t.Errorf("Wrong result, %v != %v", ts.Payload, res) 1077 } 1078 roundtrip.SetAsResponse(ts.Payload) 1079 case <-time.After(time.Second): 1080 t.Fatal("Timed out waiting for message") 1081 } 1082 select { 1083 case ts.ResponseChan <- response.NewAck(): 1084 case <-time.After(time.Second): 1085 t.Error("Timed out waiting for response") 1086 } 1087 1088 // Errored message 1089 select { 1090 case ts = <-h.TransactionChan(): 1091 if res := string(ts.Payload.Get(0).Get()); res != input { 1092 t.Errorf("Wrong result, %v != %v", ts.Payload, res) 1093 } 1094 ts.Payload.Get(0).Metadata().Set("status", "400") 1095 roundtrip.SetAsResponse(ts.Payload) 1096 case <-time.After(time.Second): 1097 t.Fatal("Timed out waiting for message") 1098 } 1099 select { 1100 case ts.ResponseChan <- response.NewAck(): 1101 case <-time.After(time.Second): 1102 t.Error("Timed out waiting for response") 1103 } 1104 1105 h.CloseAsync() 1106 if err := h.WaitForClose(time.Second * 5); err != nil { 1107 t.Error(err) 1108 } 1109 1110 wg.Wait() 1111 }