github.com/Jeffail/benthos/v3@v3.65.0/lib/input/http_server_test.go (about)

     1  package input_test
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"mime"
    10  	"mime/multipart"
    11  	"net"
    12  	"net/http"
    13  	"net/http/httptest"
    14  	"net/textproto"
    15  	"net/url"
    16  	"sync"
    17  	"testing"
    18  	"time"
    19  
    20  	"github.com/Jeffail/benthos/v3/lib/api"
    21  	"github.com/Jeffail/benthos/v3/lib/input"
    22  	"github.com/Jeffail/benthos/v3/lib/log"
    23  	"github.com/Jeffail/benthos/v3/lib/manager"
    24  	"github.com/Jeffail/benthos/v3/lib/message"
    25  	"github.com/Jeffail/benthos/v3/lib/message/roundtrip"
    26  	"github.com/Jeffail/benthos/v3/lib/metrics"
    27  	"github.com/Jeffail/benthos/v3/lib/ratelimit"
    28  	"github.com/Jeffail/benthos/v3/lib/response"
    29  	"github.com/Jeffail/benthos/v3/lib/types"
    30  	"github.com/gorilla/mux"
    31  	"github.com/gorilla/websocket"
    32  	"github.com/stretchr/testify/assert"
    33  	"github.com/stretchr/testify/require"
    34  
    35  	_ "github.com/Jeffail/benthos/v3/public/components/all"
    36  )
    37  
    38  /*
    39  type apiRegGorillaMutWrapper struct {
    40  	mut *http.ServeMux
    41  }
    42  
    43  func (a apiRegGorillaMutWrapper) RegisterEndpoint(path, desc string, h http.HandlerFunc) {
    44  	a.mut.HandleFunc(path, h)
    45  }
    46  */
    47  
    48  type apiRegGorillaMutWrapper struct {
    49  	mut *mux.Router
    50  }
    51  
    52  func (a apiRegGorillaMutWrapper) RegisterEndpoint(path, desc string, h http.HandlerFunc) {
    53  	a.mut.HandleFunc(path, h)
    54  }
    55  
    56  func TestHTTPBasic(t *testing.T) {
    57  	t.Parallel()
    58  
    59  	nTestLoops := 100
    60  
    61  	reg := apiRegGorillaMutWrapper{mut: mux.NewRouter()}
    62  	mgr, err := manager.New(manager.NewConfig(), reg, log.Noop(), metrics.Noop())
    63  	if err != nil {
    64  		t.Fatal(err)
    65  	}
    66  
    67  	conf := input.NewConfig()
    68  	conf.HTTPServer.Path = "/testpost"
    69  
    70  	h, err := input.NewHTTPServer(conf, mgr, log.Noop(), metrics.Noop())
    71  	if err != nil {
    72  		t.Fatal(err)
    73  	}
    74  
    75  	server := httptest.NewServer(reg.mut)
    76  	defer server.Close()
    77  
    78  	// Test both single and multipart messages.
    79  	for i := 0; i < nTestLoops; i++ {
    80  		testStr := fmt.Sprintf("test%v", i)
    81  		testResponse := fmt.Sprintf("response%v", i)
    82  		// Send it as single part
    83  		go func(input, output string) {
    84  			res, err := http.Post(
    85  				server.URL+"/testpost",
    86  				"application/octet-stream",
    87  				bytes.NewBuffer([]byte(input)),
    88  			)
    89  			if err != nil {
    90  				t.Error(err)
    91  			} else if res.StatusCode != 200 {
    92  				t.Errorf("Wrong error code returned: %v", res.StatusCode)
    93  			}
    94  			resBytes, err := io.ReadAll(res.Body)
    95  			if err != nil {
    96  				t.Error(err)
    97  			}
    98  			if exp, act := output, string(resBytes); exp != act {
    99  				t.Errorf("Wrong sync response: %v != %v", act, exp)
   100  			}
   101  		}(testStr, testResponse)
   102  
   103  		var ts types.Transaction
   104  		select {
   105  		case ts = <-h.TransactionChan():
   106  			if res := string(ts.Payload.Get(0).Get()); res != testStr {
   107  				t.Errorf("Wrong result, %v != %v", ts.Payload, res)
   108  			}
   109  			ts.Payload.Get(0).Set([]byte(testResponse))
   110  			roundtrip.SetAsResponse(ts.Payload)
   111  		case <-time.After(time.Second):
   112  			t.Error("Timed out waiting for message")
   113  		}
   114  		select {
   115  		case ts.ResponseChan <- response.NewAck():
   116  		case <-time.After(time.Second):
   117  			t.Error("Timed out waiting for response")
   118  		}
   119  	}
   120  
   121  	// Test MIME multipart parsing, as defined in RFC 2046
   122  	for i := 0; i < nTestLoops; i++ {
   123  		partOne := fmt.Sprintf("test%v part one", i)
   124  		partTwo := fmt.Sprintf("test%v part two", i)
   125  
   126  		testStr := fmt.Sprintf(
   127  			"--foo\r\n"+
   128  				"Content-Type: application/octet-stream\r\n\r\n"+
   129  				"%v\r\n"+
   130  				"--foo\r\n"+
   131  				"Content-Type: application/octet-stream\r\n\r\n"+
   132  				"%v\r\n"+
   133  				"--foo--\r\n",
   134  			partOne, partTwo)
   135  
   136  		// Send it as multi part
   137  		go func() {
   138  			if res, err := http.Post(
   139  				server.URL+"/testpost",
   140  				"multipart/mixed; boundary=foo",
   141  				bytes.NewBuffer([]byte(testStr)),
   142  			); err != nil {
   143  				t.Error(err)
   144  			} else if res.StatusCode != 200 {
   145  				t.Errorf("Wrong error code returned: %v", res.StatusCode)
   146  			}
   147  		}()
   148  
   149  		var ts types.Transaction
   150  		select {
   151  		case ts = <-h.TransactionChan():
   152  			if exp, actual := 2, ts.Payload.Len(); exp != actual {
   153  				t.Errorf("Wrong number of parts: %v != %v", actual, exp)
   154  			} else if exp, actual := partOne, string(ts.Payload.Get(0).Get()); exp != actual {
   155  				t.Errorf("Wrong result, %v != %v", actual, exp)
   156  			} else if exp, actual := partTwo, string(ts.Payload.Get(1).Get()); exp != actual {
   157  				t.Errorf("Wrong result, %v != %v", actual, exp)
   158  			}
   159  		case <-time.After(time.Second):
   160  			t.Error("Timed out waiting for message")
   161  		}
   162  		select {
   163  		case ts.ResponseChan <- response.NewAck():
   164  		case <-time.After(time.Second):
   165  			t.Error("Timed out waiting for response")
   166  		}
   167  	}
   168  
   169  	// Test requests without content-type
   170  	client := &http.Client{}
   171  
   172  	for i := 0; i < nTestLoops; i++ {
   173  		testStr := fmt.Sprintf("test%v", i)
   174  		testResponse := fmt.Sprintf("response%v", i)
   175  		// Send it as single part
   176  		go func(input, output string) {
   177  			req, err := http.NewRequest(
   178  				"POST", server.URL+"/testpost", bytes.NewBuffer([]byte(input)))
   179  			if err != nil {
   180  				t.Error(err)
   181  			}
   182  			res, err := client.Do(req)
   183  			if err != nil {
   184  				t.Error(err)
   185  			} else if res.StatusCode != 200 {
   186  				t.Errorf("Wrong error code returned: %v", res.StatusCode)
   187  			}
   188  			resBytes, err := io.ReadAll(res.Body)
   189  			if err != nil {
   190  				t.Error(err)
   191  			}
   192  			if exp, act := output, string(resBytes); exp != act {
   193  				t.Errorf("Wrong sync response: %v != %v", act, exp)
   194  			}
   195  		}(testStr, testResponse)
   196  
   197  		var ts types.Transaction
   198  		select {
   199  		case ts = <-h.TransactionChan():
   200  			if res := string(ts.Payload.Get(0).Get()); res != testStr {
   201  				t.Errorf("Wrong result, %v != %v", ts.Payload, res)
   202  			}
   203  			ts.Payload.Get(0).Set([]byte(testResponse))
   204  			roundtrip.SetAsResponse(ts.Payload)
   205  		case <-time.After(time.Second):
   206  			t.Error("Timed out waiting for message")
   207  		}
   208  		select {
   209  		case ts.ResponseChan <- response.NewAck():
   210  		case <-time.After(time.Second):
   211  			t.Error("Timed out waiting for response")
   212  		}
   213  	}
   214  
   215  	h.CloseAsync()
   216  }
   217  
   218  func getFreePort() (int, error) {
   219  	addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
   220  	if err != nil {
   221  		return 0, err
   222  	}
   223  
   224  	listener, err := net.ListenTCP("tcp", addr)
   225  	if err != nil {
   226  		return 0, err
   227  	}
   228  	defer listener.Close()
   229  	return listener.Addr().(*net.TCPAddr).Port, nil
   230  }
   231  
   232  func TestHTTPServerLifecycle(t *testing.T) {
   233  	freePort, err := getFreePort()
   234  	require.NoError(t, err)
   235  
   236  	apiConf := api.NewConfig()
   237  	apiConf.Address = fmt.Sprintf("0.0.0.0:%v", freePort)
   238  	apiConf.Enabled = true
   239  
   240  	testURL := fmt.Sprintf("http://localhost:%v/foo/bar", freePort)
   241  
   242  	apiImpl, err := api.New("", "", apiConf, nil, log.Noop(), metrics.Noop())
   243  	require.NoError(t, err)
   244  
   245  	go func() {
   246  		_ = apiImpl.ListenAndServe()
   247  	}()
   248  	defer apiImpl.Shutdown(context.Background())
   249  
   250  	mgr, err := manager.New(manager.NewConfig(), apiImpl, log.Noop(), metrics.Noop())
   251  	require.NoError(t, err)
   252  
   253  	conf := input.NewConfig()
   254  	conf.HTTPServer.Path = "/foo/bar"
   255  
   256  	timeout := time.Second * 5
   257  	readNextMsg := func(in input.Type) (types.Message, error) {
   258  		t.Helper()
   259  		var tran types.Transaction
   260  		select {
   261  		case tran = <-in.TransactionChan():
   262  			select {
   263  			case tran.ResponseChan <- response.NewAck():
   264  			case <-time.After(timeout):
   265  				return nil, errors.New("timed out 1")
   266  			}
   267  		case <-time.After(timeout):
   268  			return nil, errors.New("timed out 2")
   269  		}
   270  		return tran.Payload, nil
   271  	}
   272  
   273  	server, err := input.NewHTTPServer(conf, mgr, log.Noop(), metrics.Noop())
   274  	require.NoError(t, err)
   275  
   276  	dummyData := []byte("a bunch of jolly leprechauns await")
   277  	go func() {
   278  		resp, cerr := http.Post(testURL, "text/plain", bytes.NewReader(dummyData))
   279  		if assert.NoError(t, cerr) {
   280  			resp.Body.Close()
   281  		}
   282  	}()
   283  
   284  	msg, err := readNextMsg(server)
   285  	require.NoError(t, err)
   286  	assert.Equal(t, dummyData, message.GetAllBytes(msg)[0])
   287  
   288  	server.CloseAsync()
   289  	assert.NoError(t, server.WaitForClose(time.Second))
   290  
   291  	res, err := http.Post(testURL, "text/plain", bytes.NewReader(dummyData))
   292  	assert.NoError(t, err)
   293  	assert.Equal(t, 404, res.StatusCode)
   294  
   295  	serverTwo, err := input.NewHTTPServer(conf, mgr, log.Noop(), metrics.Noop())
   296  	require.NoError(t, err)
   297  
   298  	go func() {
   299  		resp, cerr := http.Post(testURL, "text/plain", bytes.NewReader(dummyData))
   300  		if assert.NoError(t, cerr) {
   301  			resp.Body.Close()
   302  		}
   303  	}()
   304  
   305  	msg, err = readNextMsg(serverTwo)
   306  	require.NoError(t, err)
   307  	assert.Equal(t, dummyData, message.GetAllBytes(msg)[0])
   308  
   309  	serverTwo.CloseAsync()
   310  	assert.NoError(t, serverTwo.WaitForClose(time.Second))
   311  }
   312  
   313  func TestHTTPServerMetadata(t *testing.T) {
   314  	reg := apiRegGorillaMutWrapper{mut: mux.NewRouter()}
   315  	mgr, err := manager.New(manager.NewConfig(), reg, log.Noop(), metrics.Noop())
   316  	require.NoError(t, err)
   317  
   318  	conf := input.NewConfig()
   319  	conf.HTTPServer.Path = "/across/the/rainbow/bridge"
   320  
   321  	server, err := input.NewHTTPServer(conf, mgr, log.Noop(), metrics.Noop())
   322  	require.NoError(t, err)
   323  
   324  	defer func() {
   325  		server.CloseAsync()
   326  		assert.NoError(t, server.WaitForClose(time.Second))
   327  	}()
   328  
   329  	testServer := httptest.NewServer(reg.mut)
   330  	defer testServer.Close()
   331  
   332  	dummyPath := "/across/the/rainbow/bridge"
   333  	dummyQuery := url.Values{"foo": []string{"bar"}}
   334  	serverURL, err := url.Parse(testServer.URL)
   335  	require.NoError(t, err)
   336  
   337  	serverURL.Path = dummyPath
   338  	serverURL.RawQuery = dummyQuery.Encode()
   339  
   340  	dummyData := []byte("a bunch of jolly leprechauns await")
   341  	go func() {
   342  		resp, cerr := http.Post(serverURL.String(), "text/plain", bytes.NewReader(dummyData))
   343  		require.NoError(t, cerr)
   344  		defer resp.Body.Close()
   345  	}()
   346  
   347  	timeout := time.Second * 5
   348  
   349  	readNextMsg := func() (types.Message, error) {
   350  		var tran types.Transaction
   351  		select {
   352  		case tran = <-server.TransactionChan():
   353  			select {
   354  			case tran.ResponseChan <- response.NewAck():
   355  			case <-time.After(timeout):
   356  				return nil, errors.New("timed out 1")
   357  			}
   358  		case <-time.After(timeout):
   359  			return nil, errors.New("timed out 2")
   360  		}
   361  		return tran.Payload, nil
   362  	}
   363  
   364  	msg, err := readNextMsg()
   365  	require.NoError(t, err)
   366  	assert.Equal(t, dummyData, message.GetAllBytes(msg)[0])
   367  
   368  	meta := msg.Get(0).Metadata()
   369  	assert.Equal(t, dummyPath, meta.Get("http_server_request_path"))
   370  	assert.Equal(t, "POST", meta.Get("http_server_verb"))
   371  	assert.Regexp(t, "^Go-http-client/", meta.Get("http_server_user_agent"))
   372  	// Make sure query params are set in the metadata
   373  	assert.Contains(t, "bar", meta.Get("foo"))
   374  }
   375  
   376  func TestHTTPtServerPathParameters(t *testing.T) {
   377  	reg := apiRegGorillaMutWrapper{mut: mux.NewRouter()}
   378  	mgr, err := manager.New(manager.NewConfig(), reg, log.Noop(), metrics.Noop())
   379  	require.NoError(t, err)
   380  
   381  	conf := input.NewConfig()
   382  	conf.HTTPServer.Path = "/test/{foo}/{bar}"
   383  	conf.HTTPServer.AllowedVerbs = append(conf.HTTPServer.AllowedVerbs, "PUT")
   384  
   385  	server, err := input.NewHTTPServer(conf, mgr, log.Noop(), metrics.Noop())
   386  	require.NoError(t, err)
   387  
   388  	defer func() {
   389  		server.CloseAsync()
   390  		assert.NoError(t, server.WaitForClose(time.Second))
   391  	}()
   392  
   393  	testServer := httptest.NewServer(reg.mut)
   394  	defer testServer.Close()
   395  
   396  	dummyPath := "/test/foo1/bar1"
   397  	dummyQuery := url.Values{"mylove": []string{"will go on"}}
   398  	serverURL, err := url.Parse(testServer.URL)
   399  	require.NoError(t, err)
   400  
   401  	serverURL.Path = dummyPath
   402  	serverURL.RawQuery = dummyQuery.Encode()
   403  
   404  	dummyData := []byte("a bunch of jolly leprechauns await")
   405  	go func() {
   406  		req, cerr := http.NewRequest("PUT", serverURL.String(), bytes.NewReader(dummyData))
   407  		require.NoError(t, cerr)
   408  		req.Header.Set("Content-Type", "text/plain")
   409  		resp, cerr := http.DefaultClient.Do(req)
   410  		require.NoError(t, cerr)
   411  		defer resp.Body.Close()
   412  	}()
   413  
   414  	readNextMsg := func() (types.Message, error) {
   415  		var tran types.Transaction
   416  		select {
   417  		case tran = <-server.TransactionChan():
   418  			select {
   419  			case tran.ResponseChan <- response.NewAck():
   420  			case <-time.After(time.Second):
   421  				return nil, errors.New("timed out")
   422  			}
   423  		case <-time.After(time.Second):
   424  			return nil, errors.New("timed out")
   425  		}
   426  		return tran.Payload, nil
   427  	}
   428  
   429  	msg, err := readNextMsg()
   430  	require.NoError(t, err)
   431  	assert.Equal(t, dummyData, message.GetAllBytes(msg)[0])
   432  
   433  	meta := msg.Get(0).Metadata()
   434  
   435  	assert.Equal(t, dummyPath, meta.Get("http_server_request_path"))
   436  	assert.Equal(t, "PUT", meta.Get("http_server_verb"))
   437  	assert.Equal(t, "foo1", meta.Get("foo"))
   438  	assert.Equal(t, "bar1", meta.Get("bar"))
   439  	assert.Equal(t, "will go on", meta.Get("mylove"))
   440  }
   441  
   442  func TestHTTPBadRequests(t *testing.T) {
   443  	t.Parallel()
   444  
   445  	reg := apiRegGorillaMutWrapper{mut: mux.NewRouter()}
   446  	mgr, err := manager.New(manager.NewConfig(), reg, log.Noop(), metrics.Noop())
   447  	if err != nil {
   448  		t.Fatal(err)
   449  	}
   450  
   451  	conf := input.NewConfig()
   452  	conf.HTTPServer.Path = "/testpost"
   453  
   454  	h, err := input.NewHTTPServer(conf, mgr, log.Noop(), metrics.Noop())
   455  	if err != nil {
   456  		t.Fatal(err)
   457  	}
   458  
   459  	server := httptest.NewServer(reg.mut)
   460  	defer server.Close()
   461  
   462  	res, err := http.Get(server.URL + "/testpost")
   463  	if err != nil {
   464  		t.Error(err)
   465  		return
   466  	}
   467  	if exp, act := http.StatusMethodNotAllowed, res.StatusCode; exp != act {
   468  		t.Errorf("unexpected HTTP response code: %v != %v", exp, act)
   469  	}
   470  
   471  	h.CloseAsync()
   472  	if err := h.WaitForClose(time.Second * 5); err != nil {
   473  		t.Error(err)
   474  	}
   475  }
   476  
   477  func TestHTTPTimeout(t *testing.T) {
   478  	t.Parallel()
   479  
   480  	reg := apiRegGorillaMutWrapper{mut: mux.NewRouter()}
   481  	mgr, err := manager.New(manager.NewConfig(), reg, log.Noop(), metrics.Noop())
   482  	if err != nil {
   483  		t.Fatal(err)
   484  	}
   485  
   486  	conf := input.NewConfig()
   487  	conf.HTTPServer.Path = "/testpost"
   488  	conf.HTTPServer.Timeout = "1ms"
   489  
   490  	h, err := input.NewHTTPServer(conf, mgr, log.Noop(), metrics.Noop())
   491  	if err != nil {
   492  		t.Fatal(err)
   493  	}
   494  
   495  	server := httptest.NewServer(reg.mut)
   496  	defer server.Close()
   497  
   498  	var res *http.Response
   499  	res, err = http.Post(
   500  		server.URL+"/testpost",
   501  		"application/octet-stream",
   502  		bytes.NewBuffer([]byte("hello world")),
   503  	)
   504  	if err != nil {
   505  		t.Fatal(err)
   506  	}
   507  	if exp, act := http.StatusRequestTimeout, res.StatusCode; exp != act {
   508  		t.Errorf("Unexpected status code: %v != %v", exp, act)
   509  	}
   510  
   511  	h.CloseAsync()
   512  	if err := h.WaitForClose(time.Second * 5); err != nil {
   513  		t.Error(err)
   514  	}
   515  }
   516  
   517  func TestHTTPRateLimit(t *testing.T) {
   518  	t.Parallel()
   519  
   520  	reg := apiRegGorillaMutWrapper{mut: mux.NewRouter()}
   521  
   522  	rlConf := ratelimit.NewConfig()
   523  	rlConf.Type = ratelimit.TypeLocal
   524  	rlConf.Local.Count = 1
   525  	rlConf.Local.Interval = "60s"
   526  
   527  	mgrConf := manager.NewConfig()
   528  	mgrConf.RateLimits["foorl"] = rlConf
   529  	mgr, err := manager.New(mgrConf, reg, log.Noop(), metrics.Noop())
   530  	if err != nil {
   531  		t.Fatal(err)
   532  	}
   533  
   534  	conf := input.NewConfig()
   535  	conf.HTTPServer.Path = "/testpost"
   536  	conf.HTTPServer.RateLimit = "foorl"
   537  
   538  	h, err := input.NewHTTPServer(conf, mgr, log.Noop(), metrics.Noop())
   539  	if err != nil {
   540  		t.Fatal(err)
   541  	}
   542  
   543  	server := httptest.NewServer(reg.mut)
   544  	defer server.Close()
   545  
   546  	go func() {
   547  		var ts types.Transaction
   548  		select {
   549  		case ts = <-h.TransactionChan():
   550  		case <-time.After(time.Second):
   551  			t.Error("Timed out waiting for message")
   552  		}
   553  		select {
   554  		case ts.ResponseChan <- response.NewAck():
   555  		case <-time.After(time.Second):
   556  			t.Error("Timed out waiting for response")
   557  		}
   558  	}()
   559  
   560  	var res *http.Response
   561  	res, err = http.Post(
   562  		server.URL+"/testpost",
   563  		"application/octet-stream",
   564  		bytes.NewBuffer([]byte("hello world")),
   565  	)
   566  	if err != nil {
   567  		t.Fatal(err)
   568  	}
   569  	if exp, act := http.StatusOK, res.StatusCode; exp != act {
   570  		t.Errorf("Unexpected status code: %v != %v", exp, act)
   571  	}
   572  
   573  	res, err = http.Post(
   574  		server.URL+"/testpost",
   575  		"application/octet-stream",
   576  		bytes.NewBuffer([]byte("hello world")),
   577  	)
   578  	if err != nil {
   579  		t.Fatal(err)
   580  	}
   581  	if exp, act := http.StatusTooManyRequests, res.StatusCode; exp != act {
   582  		t.Errorf("Unexpected status code: %v != %v", exp, act)
   583  	}
   584  
   585  	h.CloseAsync()
   586  	if err := h.WaitForClose(time.Second * 5); err != nil {
   587  		t.Error(err)
   588  	}
   589  }
   590  
   591  func TestHTTPServerWebsockets(t *testing.T) {
   592  	t.Parallel()
   593  
   594  	reg := apiRegGorillaMutWrapper{mut: mux.NewRouter()}
   595  
   596  	mgr, err := manager.New(manager.NewConfig(), reg, log.Noop(), metrics.Noop())
   597  	if err != nil {
   598  		t.Fatal(err)
   599  	}
   600  
   601  	conf := input.NewConfig()
   602  	conf.HTTPServer.WSPath = "/testws"
   603  
   604  	h, err := input.NewHTTPServer(conf, mgr, log.Noop(), metrics.Noop())
   605  	if err != nil {
   606  		t.Fatal(err)
   607  	}
   608  
   609  	server := httptest.NewServer(reg.mut)
   610  	defer server.Close()
   611  
   612  	purl, err := url.Parse(server.URL + "/testws")
   613  	if err != nil {
   614  		t.Fatal(err)
   615  	}
   616  	purl.Scheme = "ws"
   617  
   618  	var client *websocket.Conn
   619  	if client, _, err = websocket.DefaultDialer.Dial(purl.String(), http.Header{}); err != nil {
   620  		t.Fatal(err)
   621  	}
   622  
   623  	wg := sync.WaitGroup{}
   624  	wg.Add(1)
   625  	go func() {
   626  		if clientErr := client.WriteMessage(
   627  			websocket.BinaryMessage, []byte("hello world 1"),
   628  		); clientErr != nil {
   629  			t.Error(clientErr)
   630  		}
   631  		wg.Done()
   632  	}()
   633  
   634  	var ts types.Transaction
   635  	select {
   636  	case ts = <-h.TransactionChan():
   637  	case <-time.After(time.Second):
   638  		t.Error("Timed out waiting for message")
   639  	}
   640  	if exp, act := `[hello world 1]`, fmt.Sprintf("%s", message.GetAllBytes(ts.Payload)); exp != act {
   641  		t.Errorf("Unexpected message: %v != %v", act, exp)
   642  	}
   643  	select {
   644  	case ts.ResponseChan <- response.NewAck():
   645  	case <-time.After(time.Second):
   646  		t.Error("Timed out waiting for response")
   647  	}
   648  	wg.Wait()
   649  
   650  	wg.Add(1)
   651  	go func() {
   652  		if closeErr := client.WriteMessage(
   653  			websocket.BinaryMessage, []byte("hello world 2"),
   654  		); closeErr != nil {
   655  			t.Error(closeErr)
   656  		}
   657  		wg.Done()
   658  	}()
   659  
   660  	select {
   661  	case ts = <-h.TransactionChan():
   662  	case <-time.After(time.Second):
   663  		t.Error("Timed out waiting for message")
   664  	}
   665  	if exp, act := `[hello world 2]`, fmt.Sprintf("%s", message.GetAllBytes(ts.Payload)); exp != act {
   666  		t.Errorf("Unexpected message: %v != %v", act, exp)
   667  	}
   668  	select {
   669  	case ts.ResponseChan <- response.NewAck():
   670  	case <-time.After(time.Second):
   671  		t.Error("Timed out waiting for response")
   672  	}
   673  	wg.Wait()
   674  
   675  	h.CloseAsync()
   676  	if err := h.WaitForClose(time.Second * 5); err != nil {
   677  		t.Error(err)
   678  	}
   679  }
   680  
   681  func TestHTTPServerWSRateLimit(t *testing.T) {
   682  	t.Parallel()
   683  
   684  	reg := apiRegGorillaMutWrapper{mut: mux.NewRouter()}
   685  
   686  	rlConf := ratelimit.NewConfig()
   687  	rlConf.Type = ratelimit.TypeLocal
   688  	rlConf.Local.Count = 1
   689  	rlConf.Local.Interval = "60s"
   690  
   691  	mgrConf := manager.NewConfig()
   692  	mgrConf.RateLimits["foorl"] = rlConf
   693  	mgr, err := manager.New(mgrConf, reg, log.Noop(), metrics.Noop())
   694  	if err != nil {
   695  		t.Fatal(err)
   696  	}
   697  
   698  	conf := input.NewConfig()
   699  	conf.HTTPServer.WSPath = "/testws"
   700  	conf.HTTPServer.WSWelcomeMessage = "test welcome"
   701  	conf.HTTPServer.WSRateLimitMessage = "test rate limited"
   702  	conf.HTTPServer.RateLimit = "foorl"
   703  
   704  	h, err := input.NewHTTPServer(conf, mgr, log.Noop(), metrics.Noop())
   705  	if err != nil {
   706  		t.Fatal(err)
   707  	}
   708  
   709  	server := httptest.NewServer(reg.mut)
   710  	defer server.Close()
   711  
   712  	purl, err := url.Parse(server.URL + "/testws")
   713  	if err != nil {
   714  		t.Fatal(err)
   715  	}
   716  	purl.Scheme = "ws"
   717  
   718  	var client *websocket.Conn
   719  	if client, _, err = websocket.DefaultDialer.Dial(purl.String(), http.Header{}); err != nil {
   720  		t.Fatal(err)
   721  	}
   722  
   723  	go func() {
   724  		var ts types.Transaction
   725  		select {
   726  		case ts = <-h.TransactionChan():
   727  		case <-time.After(time.Second):
   728  			t.Error("Timed out waiting for message")
   729  		}
   730  		select {
   731  		case ts.ResponseChan <- response.NewAck():
   732  		case <-time.After(time.Second):
   733  			t.Error("Timed out waiting for response")
   734  		}
   735  	}()
   736  
   737  	var msgBytes []byte
   738  	if _, msgBytes, err = client.ReadMessage(); err != nil {
   739  		t.Fatal(err)
   740  	}
   741  	if exp, act := "test welcome", string(msgBytes); exp != act {
   742  		t.Errorf("Unexpected welcome message: %v != %v", act, exp)
   743  	}
   744  
   745  	if err = client.WriteMessage(
   746  		websocket.BinaryMessage, []byte("hello world"),
   747  	); err != nil {
   748  		t.Fatal(err)
   749  	}
   750  
   751  	if err = client.WriteMessage(
   752  		websocket.BinaryMessage, []byte("hello world"),
   753  	); err != nil {
   754  		t.Fatal(err)
   755  	}
   756  
   757  	if _, msgBytes, err = client.ReadMessage(); err != nil {
   758  		t.Fatal(err)
   759  	}
   760  	if exp, act := "test rate limited", string(msgBytes); exp != act {
   761  		t.Errorf("Unexpected rate limit message: %v != %v", act, exp)
   762  	}
   763  
   764  	h.CloseAsync()
   765  	if err := h.WaitForClose(time.Second * 5); err != nil {
   766  		t.Error(err)
   767  	}
   768  }
   769  
   770  func TestHTTPSyncResponseHeaders(t *testing.T) {
   771  	t.Parallel()
   772  
   773  	reg := apiRegGorillaMutWrapper{mut: mux.NewRouter()}
   774  	mgr, err := manager.New(manager.NewConfig(), reg, log.Noop(), metrics.Noop())
   775  	if err != nil {
   776  		t.Fatal(err)
   777  	}
   778  
   779  	conf := input.NewConfig()
   780  	conf.HTTPServer.Path = "/testpost"
   781  	conf.HTTPServer.Response.Headers["Content-Type"] = "application/json"
   782  	conf.HTTPServer.Response.Headers["foo"] = `${!json("field1")}`
   783  	conf.HTTPServer.Response.ExtractMetadata.IncludePrefixes = []string{"Loca"}
   784  	conf.HTTPServer.Response.ExtractMetadata.IncludePatterns = []string{"name"}
   785  
   786  	h, err := input.NewHTTPServer(conf, mgr, log.Noop(), metrics.Noop())
   787  	if err != nil {
   788  		t.Fatal(err)
   789  	}
   790  
   791  	server := httptest.NewServer(reg.mut)
   792  	defer server.Close()
   793  
   794  	input := `{"foo":"test message","field1":"bar"}`
   795  
   796  	wg := sync.WaitGroup{}
   797  	wg.Add(1)
   798  	go func() {
   799  		defer wg.Done()
   800  
   801  		req, err := http.NewRequest(http.MethodPost, server.URL+"/testpost", bytes.NewBuffer([]byte(input)))
   802  		if err != nil {
   803  			t.Error(err)
   804  		}
   805  		req.Header.Set("Content-Type", "application/octet-stream")
   806  		req.Header.Set("Location", "Asgard")
   807  		req.Header.Set("Username", "Thor")
   808  		req.Header.Set("Language", "Norse")
   809  		res, err := http.DefaultClient.Do(req)
   810  		if err != nil {
   811  			t.Error(err)
   812  		} else if res.StatusCode != 200 {
   813  			t.Errorf("Wrong error code returned: %v", res.StatusCode)
   814  		}
   815  		resBytes, err := io.ReadAll(res.Body)
   816  		if err != nil {
   817  			t.Error(err)
   818  		}
   819  		if exp, act := input, string(resBytes); exp != act {
   820  			t.Errorf("Wrong sync response: %v != %v", act, exp)
   821  		}
   822  		if exp, act := "application/json", res.Header.Get("Content-Type"); exp != act {
   823  			t.Errorf("Wrong sync response header: %v != %v", act, exp)
   824  		}
   825  		if exp, act := "bar", res.Header.Get("foo"); exp != act {
   826  			t.Errorf("Wrong sync response header: %v != %v", act, exp)
   827  		}
   828  		if exp, act := "Asgard", res.Header.Get("Location"); exp != act {
   829  			t.Errorf("Wrong sync response header: %v != %v", act, exp)
   830  		}
   831  		if exp, act := "Thor", res.Header.Get("Username"); exp != act {
   832  			t.Errorf("Wrong sync response header: %v != %v", act, exp)
   833  		}
   834  		if exp, act := "", res.Header.Get("Language"); exp != act {
   835  			t.Errorf("Wrong sync response header: %v != %v", act, exp)
   836  		}
   837  	}()
   838  
   839  	var ts types.Transaction
   840  	select {
   841  	case ts = <-h.TransactionChan():
   842  		if res := string(ts.Payload.Get(0).Get()); res != input {
   843  			t.Errorf("Wrong result, %v != %v", ts.Payload, res)
   844  		}
   845  		roundtrip.SetAsResponse(ts.Payload)
   846  	case <-time.After(time.Second):
   847  		t.Fatal("Timed out waiting for message")
   848  	}
   849  	select {
   850  	case ts.ResponseChan <- response.NewAck():
   851  	case <-time.After(time.Second):
   852  		t.Error("Timed out waiting for response")
   853  	}
   854  
   855  	h.CloseAsync()
   856  	if err := h.WaitForClose(time.Second * 5); err != nil {
   857  		t.Error(err)
   858  	}
   859  
   860  	wg.Wait()
   861  }
   862  
   863  func createMultipart(payloads []string, contentType string) (hdr string, bodyBytes []byte, err error) {
   864  	body := &bytes.Buffer{}
   865  	writer := multipart.NewWriter(body)
   866  
   867  	for i := 0; i < len(payloads) && err == nil; i++ {
   868  		var part io.Writer
   869  		if part, err = writer.CreatePart(textproto.MIMEHeader{
   870  			"Content-Type": []string{contentType},
   871  		}); err == nil {
   872  			_, err = io.Copy(part, bytes.NewReader([]byte(payloads[i])))
   873  		}
   874  	}
   875  
   876  	if err != nil {
   877  		return "", nil, err
   878  	}
   879  
   880  	writer.Close()
   881  	return writer.FormDataContentType(), body.Bytes(), nil
   882  }
   883  
   884  func readMultipart(res *http.Response) ([]string, error) {
   885  	var params map[string]string
   886  	var err error
   887  	if contentType := res.Header.Get("Content-Type"); len(contentType) > 0 {
   888  		if _, params, err = mime.ParseMediaType(contentType); err != nil {
   889  			return nil, err
   890  		}
   891  	}
   892  
   893  	var buffer bytes.Buffer
   894  	var output []string
   895  
   896  	mr := multipart.NewReader(res.Body, params["boundary"])
   897  	var bufferIndex int64
   898  	for {
   899  		var p *multipart.Part
   900  		if p, err = mr.NextPart(); err != nil {
   901  			if err == io.EOF {
   902  				break
   903  			}
   904  			return nil, err
   905  		}
   906  
   907  		var bytesRead int64
   908  		if bytesRead, err = buffer.ReadFrom(p); err != nil {
   909  			return nil, err
   910  		}
   911  
   912  		output = append(output, string(buffer.Bytes()[bufferIndex:bufferIndex+bytesRead]))
   913  		bufferIndex += bytesRead
   914  	}
   915  
   916  	return output, nil
   917  }
   918  
   919  func TestHTTPSyncResponseMultipart(t *testing.T) {
   920  	t.Parallel()
   921  
   922  	reg := apiRegGorillaMutWrapper{mut: mux.NewRouter()}
   923  	mgr, err := manager.New(manager.NewConfig(), reg, log.Noop(), metrics.Noop())
   924  	require.NoError(t, err)
   925  
   926  	conf := input.NewConfig()
   927  	conf.HTTPServer.Path = "/testpost"
   928  	conf.HTTPServer.Response.Headers["Content-Type"] = "application/json"
   929  
   930  	h, err := input.NewHTTPServer(conf, mgr, log.Noop(), metrics.Noop())
   931  	require.NoError(t, err)
   932  
   933  	server := httptest.NewServer(reg.mut)
   934  	t.Cleanup(func() {
   935  		server.Close()
   936  	})
   937  
   938  	input := []string{
   939  		`{"foo":"test message 1","field1":"bar"}`,
   940  		`{"foo":"test message 2","field1":"baz"}`,
   941  		`{"foo":"test message 3","field1":"buz"}`,
   942  	}
   943  	output := []string{
   944  		`{"foo":"test message 4","field1":"bar"}`,
   945  		`{"foo":"test message 5","field1":"baz"}`,
   946  		`{"foo":"test message 6","field1":"buz"}`,
   947  	}
   948  
   949  	wg := sync.WaitGroup{}
   950  	wg.Add(1)
   951  	go func() {
   952  		defer wg.Done()
   953  
   954  		hdr, body, err := createMultipart(input, "application/octet-stream")
   955  		require.NoError(t, err)
   956  
   957  		res, err := http.Post(server.URL+"/testpost", hdr, bytes.NewReader(body))
   958  		require.NoError(t, err)
   959  		require.Equal(t, 200, res.StatusCode)
   960  
   961  		act, err := readMultipart(res)
   962  		require.NoError(t, err)
   963  		assert.Equal(t, output, act)
   964  	}()
   965  
   966  	var ts types.Transaction
   967  	select {
   968  	case ts = <-h.TransactionChan():
   969  		for i, in := range input {
   970  			assert.Equal(t, in, string(ts.Payload.Get(i).Get()))
   971  		}
   972  		for i, o := range output {
   973  			ts.Payload.Get(i).Set([]byte(o))
   974  		}
   975  		roundtrip.SetAsResponse(ts.Payload)
   976  	case <-time.After(time.Second):
   977  		t.Fatal("Timed out waiting for message")
   978  	}
   979  	select {
   980  	case ts.ResponseChan <- response.NewAck():
   981  	case <-time.After(time.Second):
   982  		t.Error("Timed out waiting for response")
   983  	}
   984  
   985  	h.CloseAsync()
   986  	err = h.WaitForClose(time.Second * 5)
   987  	require.NoError(t, err)
   988  
   989  	wg.Wait()
   990  }
   991  
   992  func TestHTTPSyncResponseHeadersStatus(t *testing.T) {
   993  	t.Parallel()
   994  
   995  	reg := apiRegGorillaMutWrapper{mut: mux.NewRouter()}
   996  	mgr, err := manager.New(manager.NewConfig(), reg, log.Noop(), metrics.Noop())
   997  	if err != nil {
   998  		t.Fatal(err)
   999  	}
  1000  
  1001  	conf := input.NewConfig()
  1002  	conf.HTTPServer.Path = "/testpost"
  1003  	conf.HTTPServer.Response.Status = `${! meta("status").or("200") }`
  1004  	conf.HTTPServer.Response.Headers["Content-Type"] = "application/json"
  1005  	conf.HTTPServer.Response.Headers["foo"] = `${!json("field1")}`
  1006  
  1007  	h, err := input.NewHTTPServer(conf, mgr, log.Noop(), metrics.Noop())
  1008  	if err != nil {
  1009  		t.Fatal(err)
  1010  	}
  1011  
  1012  	server := httptest.NewServer(reg.mut)
  1013  	defer server.Close()
  1014  
  1015  	input := `{"foo":"test message","field1":"bar"}`
  1016  
  1017  	wg := sync.WaitGroup{}
  1018  	wg.Add(1)
  1019  	go func() {
  1020  		defer wg.Done()
  1021  
  1022  		res, err := http.Post(
  1023  			server.URL+"/testpost",
  1024  			"application/octet-stream",
  1025  			bytes.NewBuffer([]byte(input)),
  1026  		)
  1027  		if err != nil {
  1028  			t.Error(err)
  1029  		} else if res.StatusCode != 200 {
  1030  			t.Errorf("Wrong error code returned: %v", res.StatusCode)
  1031  		}
  1032  		resBytes, err := io.ReadAll(res.Body)
  1033  		if err != nil {
  1034  			t.Error(err)
  1035  		}
  1036  		if exp, act := input, string(resBytes); exp != act {
  1037  			t.Errorf("Wrong sync response: %v != %v", act, exp)
  1038  		}
  1039  		if exp, act := "application/json", res.Header.Get("Content-Type"); exp != act {
  1040  			t.Errorf("Wrong sync response header: %v != %v", act, exp)
  1041  		}
  1042  		if exp, act := "bar", res.Header.Get("foo"); exp != act {
  1043  			t.Errorf("Wrong sync response header: %v != %v", act, exp)
  1044  		}
  1045  
  1046  		res, err = http.Post(
  1047  			server.URL+"/testpost",
  1048  			"application/octet-stream",
  1049  			bytes.NewBuffer([]byte(input)),
  1050  		)
  1051  		if err != nil {
  1052  			t.Error(err)
  1053  		} else if res.StatusCode != 400 {
  1054  			t.Errorf("Wrong error code returned: %v", res.StatusCode)
  1055  		}
  1056  		resBytes, err = io.ReadAll(res.Body)
  1057  		if err != nil {
  1058  			t.Error(err)
  1059  		}
  1060  		if exp, act := input, string(resBytes); exp != act {
  1061  			t.Errorf("Wrong sync response: %v != %v", act, exp)
  1062  		}
  1063  		if exp, act := "application/json", res.Header.Get("Content-Type"); exp != act {
  1064  			t.Errorf("Wrong sync response header: %v != %v", act, exp)
  1065  		}
  1066  		if exp, act := "bar", res.Header.Get("foo"); exp != act {
  1067  			t.Errorf("Wrong sync response header: %v != %v", act, exp)
  1068  		}
  1069  	}()
  1070  
  1071  	// Non errored message
  1072  	var ts types.Transaction
  1073  	select {
  1074  	case ts = <-h.TransactionChan():
  1075  		if res := string(ts.Payload.Get(0).Get()); res != input {
  1076  			t.Errorf("Wrong result, %v != %v", ts.Payload, res)
  1077  		}
  1078  		roundtrip.SetAsResponse(ts.Payload)
  1079  	case <-time.After(time.Second):
  1080  		t.Fatal("Timed out waiting for message")
  1081  	}
  1082  	select {
  1083  	case ts.ResponseChan <- response.NewAck():
  1084  	case <-time.After(time.Second):
  1085  		t.Error("Timed out waiting for response")
  1086  	}
  1087  
  1088  	// Errored message
  1089  	select {
  1090  	case ts = <-h.TransactionChan():
  1091  		if res := string(ts.Payload.Get(0).Get()); res != input {
  1092  			t.Errorf("Wrong result, %v != %v", ts.Payload, res)
  1093  		}
  1094  		ts.Payload.Get(0).Metadata().Set("status", "400")
  1095  		roundtrip.SetAsResponse(ts.Payload)
  1096  	case <-time.After(time.Second):
  1097  		t.Fatal("Timed out waiting for message")
  1098  	}
  1099  	select {
  1100  	case ts.ResponseChan <- response.NewAck():
  1101  	case <-time.After(time.Second):
  1102  		t.Error("Timed out waiting for response")
  1103  	}
  1104  
  1105  	h.CloseAsync()
  1106  	if err := h.WaitForClose(time.Second * 5); err != nil {
  1107  		t.Error(err)
  1108  	}
  1109  
  1110  	wg.Wait()
  1111  }