github.com/Jeffail/benthos/v3@v3.65.0/lib/input/socket_server_test.go (about)

     1  package input
     2  
     3  import (
     4  	"errors"
     5  	"net"
     6  	"path/filepath"
     7  	"sort"
     8  	"sync"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/Jeffail/benthos/v3/lib/log"
    13  	"github.com/Jeffail/benthos/v3/lib/message"
    14  	"github.com/Jeffail/benthos/v3/lib/metrics"
    15  	"github.com/Jeffail/benthos/v3/lib/response"
    16  	"github.com/Jeffail/benthos/v3/lib/types"
    17  	"github.com/stretchr/testify/assert"
    18  	"github.com/stretchr/testify/require"
    19  )
    20  
    21  func TestSocketServerBasic(t *testing.T) {
    22  	tmpDir := t.TempDir()
    23  
    24  	conf := NewConfig()
    25  	conf.SocketServer.Network = "unix"
    26  	conf.SocketServer.Address = filepath.Join(tmpDir, "benthos.sock")
    27  
    28  	rdr, err := NewSocketServer(conf, nil, log.Noop(), metrics.Noop())
    29  	require.NoError(t, err)
    30  
    31  	defer func() {
    32  		rdr.CloseAsync()
    33  		assert.NoError(t, rdr.WaitForClose(time.Second))
    34  	}()
    35  
    36  	conn, err := net.Dial("unix", conf.SocketServer.Address)
    37  	require.NoError(t, err)
    38  
    39  	wg := sync.WaitGroup{}
    40  	wg.Add(1)
    41  	go func() {
    42  		conn.SetWriteDeadline(time.Now().Add(time.Second * 5))
    43  		_, cerr := conn.Write([]byte("foo\n"))
    44  		require.NoError(t, cerr)
    45  
    46  		_, cerr = conn.Write([]byte("bar\n"))
    47  		require.NoError(t, cerr)
    48  
    49  		_, cerr = conn.Write([]byte("baz\n"))
    50  		require.NoError(t, cerr)
    51  		wg.Done()
    52  	}()
    53  
    54  	readNextMsg := func() (types.Message, error) {
    55  		var tran types.Transaction
    56  		select {
    57  		case tran = <-rdr.TransactionChan():
    58  			select {
    59  			case tran.ResponseChan <- response.NewAck():
    60  			case <-time.After(time.Second):
    61  				return nil, errors.New("timed out")
    62  			}
    63  		case <-time.After(time.Second):
    64  			return nil, errors.New("timed out")
    65  		}
    66  		return tran.Payload, nil
    67  	}
    68  
    69  	exp := [][]byte{[]byte("foo")}
    70  	msg, err := readNextMsg()
    71  	require.NoError(t, err)
    72  	assert.Equal(t, exp, message.GetAllBytes(msg))
    73  
    74  	exp = [][]byte{[]byte("bar")}
    75  	msg, err = readNextMsg()
    76  	require.NoError(t, err)
    77  	assert.Equal(t, exp, message.GetAllBytes(msg))
    78  
    79  	exp = [][]byte{[]byte("baz")}
    80  	msg, err = readNextMsg()
    81  	require.NoError(t, err)
    82  	assert.Equal(t, exp, message.GetAllBytes(msg))
    83  
    84  	wg.Wait()
    85  	conn.Close()
    86  }
    87  
    88  func TestSocketServerRetries(t *testing.T) {
    89  	tmpDir := t.TempDir()
    90  
    91  	conf := NewConfig()
    92  	conf.SocketServer.Network = "unix"
    93  	conf.SocketServer.Address = filepath.Join(tmpDir, "benthos.sock")
    94  
    95  	rdr, err := NewSocketServer(conf, nil, log.Noop(), metrics.Noop())
    96  	require.NoError(t, err)
    97  
    98  	defer func() {
    99  		rdr.CloseAsync()
   100  		assert.NoError(t, rdr.WaitForClose(time.Second))
   101  	}()
   102  
   103  	conn, err := net.Dial("unix", conf.SocketServer.Address)
   104  	require.NoError(t, err)
   105  
   106  	wg := sync.WaitGroup{}
   107  	wg.Add(1)
   108  	go func() {
   109  		conn.SetWriteDeadline(time.Now().Add(time.Second * 5))
   110  		_, cerr := conn.Write([]byte("foo\n"))
   111  		require.NoError(t, cerr)
   112  
   113  		_, cerr = conn.Write([]byte("bar\n"))
   114  		require.NoError(t, cerr)
   115  
   116  		_, cerr = conn.Write([]byte("baz\n"))
   117  		require.NoError(t, cerr)
   118  		wg.Done()
   119  	}()
   120  
   121  	readNextMsg := func(reject bool) (types.Message, error) {
   122  		var tran types.Transaction
   123  		select {
   124  		case tran = <-rdr.TransactionChan():
   125  			var res types.Response = response.NewAck()
   126  			if reject {
   127  				res = response.NewError(errors.New("test err"))
   128  			}
   129  			select {
   130  			case tran.ResponseChan <- res:
   131  			case <-time.After(time.Second * 5):
   132  				return nil, errors.New("timed out")
   133  			}
   134  		case <-time.After(time.Second * 5):
   135  			return nil, errors.New("timed out")
   136  		}
   137  		return tran.Payload, nil
   138  	}
   139  
   140  	exp := [][]byte{[]byte("foo")}
   141  	msg, err := readNextMsg(false)
   142  	require.NoError(t, err)
   143  	assert.Equal(t, exp, message.GetAllBytes(msg))
   144  
   145  	exp = [][]byte{[]byte("bar")}
   146  	msg, err = readNextMsg(true)
   147  	require.NoError(t, err)
   148  	assert.Equal(t, exp, message.GetAllBytes(msg))
   149  
   150  	expRemaining := []string{"bar", "baz"}
   151  	actRemaining := []string{}
   152  
   153  	msg, err = readNextMsg(false)
   154  	require.NoError(t, err)
   155  	require.Equal(t, 1, msg.Len())
   156  	actRemaining = append(actRemaining, string(msg.Get(0).Get()))
   157  
   158  	msg, err = readNextMsg(false)
   159  	require.NoError(t, err)
   160  	require.Equal(t, 1, msg.Len())
   161  	actRemaining = append(actRemaining, string(msg.Get(0).Get()))
   162  
   163  	sort.Strings(actRemaining)
   164  	assert.Equal(t, expRemaining, actRemaining)
   165  
   166  	wg.Wait()
   167  	conn.Close()
   168  }
   169  
   170  func TestSocketServerWriteClosed(t *testing.T) {
   171  	tmpDir := t.TempDir()
   172  
   173  	conf := NewConfig()
   174  	conf.SocketServer.Network = "unix"
   175  	conf.SocketServer.Address = filepath.Join(tmpDir, "b.sock")
   176  
   177  	rdr, err := NewSocketServer(conf, nil, log.Noop(), metrics.Noop())
   178  	require.NoError(t, err)
   179  
   180  	conn, err := net.Dial("unix", conf.SocketServer.Address)
   181  	require.NoError(t, err)
   182  
   183  	conn.SetWriteDeadline(time.Now().Add(time.Second * 5))
   184  
   185  	rdr.CloseAsync()
   186  	assert.NoError(t, rdr.WaitForClose(time.Second*3))
   187  
   188  	_, cerr := conn.Write([]byte("bar\n"))
   189  	require.Error(t, cerr)
   190  
   191  	_, open := <-rdr.TransactionChan()
   192  	assert.False(t, open)
   193  
   194  	conn.Close()
   195  }
   196  
   197  func TestSocketServerRecon(t *testing.T) {
   198  	tmpDir := t.TempDir()
   199  
   200  	conf := NewConfig()
   201  	conf.SocketServer.Network = "unix"
   202  	conf.SocketServer.Address = filepath.Join(tmpDir, "benthos.sock")
   203  
   204  	rdr, err := NewSocketServer(conf, nil, log.Noop(), metrics.Noop())
   205  	require.NoError(t, err)
   206  
   207  	addr := rdr.(*SocketServer).Addr()
   208  
   209  	defer func() {
   210  		rdr.CloseAsync()
   211  		assert.NoError(t, rdr.WaitForClose(time.Second))
   212  	}()
   213  
   214  	conn, err := net.Dial("unix", addr.String())
   215  	require.NoError(t, err)
   216  
   217  	wg := sync.WaitGroup{}
   218  	wg.Add(1)
   219  	go func() {
   220  		conn.SetWriteDeadline(time.Now().Add(time.Second * 5))
   221  		_, cerr := conn.Write([]byte("foo\n"))
   222  		require.NoError(t, cerr)
   223  
   224  		conn.Close()
   225  		conn, cerr = net.Dial("unix", addr.String())
   226  		require.NoError(t, cerr)
   227  
   228  		_, cerr = conn.Write([]byte("bar\n"))
   229  		require.NoError(t, cerr)
   230  
   231  		_, cerr = conn.Write([]byte("baz\n"))
   232  		require.NoError(t, cerr)
   233  
   234  		wg.Done()
   235  	}()
   236  
   237  	readNextMsg := func() (types.Message, error) {
   238  		var tran types.Transaction
   239  		select {
   240  		case tran = <-rdr.TransactionChan():
   241  			select {
   242  			case tran.ResponseChan <- response.NewAck():
   243  			case <-time.After(time.Second):
   244  				return nil, errors.New("timed out")
   245  			}
   246  		case <-time.After(time.Second):
   247  			return nil, errors.New("timed out")
   248  		}
   249  		return tran.Payload, nil
   250  	}
   251  
   252  	expMsgs := map[string]struct{}{
   253  		"foo": {},
   254  		"bar": {},
   255  		"baz": {},
   256  	}
   257  
   258  	for i := 0; i < 3; i++ {
   259  		msg, err := readNextMsg()
   260  		require.NoError(t, err)
   261  
   262  		act := string(msg.Get(0).Get())
   263  		assert.Contains(t, expMsgs, act)
   264  
   265  		delete(expMsgs, act)
   266  	}
   267  
   268  	wg.Wait()
   269  	conn.Close()
   270  }
   271  
   272  func TestSocketServerMpart(t *testing.T) {
   273  	tmpDir := t.TempDir()
   274  
   275  	conf := NewConfig()
   276  	conf.SocketServer.Network = "unix"
   277  	conf.SocketServer.Address = filepath.Join(tmpDir, "benthos.sock")
   278  	conf.SocketServer.Multipart = true
   279  
   280  	rdr, err := NewSocketServer(conf, nil, log.Noop(), metrics.Noop())
   281  	require.NoError(t, err)
   282  
   283  	defer func() {
   284  		rdr.CloseAsync()
   285  		assert.NoError(t, rdr.WaitForClose(time.Second))
   286  	}()
   287  
   288  	conn, err := net.Dial("unix", conf.SocketServer.Address)
   289  	require.NoError(t, err)
   290  
   291  	wg := sync.WaitGroup{}
   292  	wg.Add(1)
   293  	go func() {
   294  		conn.SetWriteDeadline(time.Now().Add(time.Second * 5))
   295  		_, cerr := conn.Write([]byte("foo\n"))
   296  		require.NoError(t, cerr)
   297  
   298  		_, cerr = conn.Write([]byte("bar\n"))
   299  		require.NoError(t, cerr)
   300  
   301  		_, cerr = conn.Write([]byte("\n"))
   302  		require.NoError(t, cerr)
   303  
   304  		_, cerr = conn.Write([]byte("baz\n\n"))
   305  		require.NoError(t, cerr)
   306  
   307  		wg.Done()
   308  	}()
   309  
   310  	readNextMsg := func() (types.Message, error) {
   311  		var tran types.Transaction
   312  		select {
   313  		case tran = <-rdr.TransactionChan():
   314  			select {
   315  			case tran.ResponseChan <- response.NewAck():
   316  			case <-time.After(time.Second):
   317  				return nil, errors.New("timed out")
   318  			}
   319  		case <-time.After(time.Second):
   320  			return nil, errors.New("timed out")
   321  		}
   322  		return tran.Payload, nil
   323  	}
   324  
   325  	exp := [][]byte{[]byte("foo"), []byte("bar")}
   326  	msg, err := readNextMsg()
   327  	require.NoError(t, err)
   328  	assert.Equal(t, exp, message.GetAllBytes(msg))
   329  
   330  	exp = [][]byte{[]byte("baz")}
   331  	msg, err = readNextMsg()
   332  	require.NoError(t, err)
   333  	assert.Equal(t, exp, message.GetAllBytes(msg))
   334  
   335  	wg.Wait()
   336  	conn.Close()
   337  }
   338  
   339  func TestSocketServerMpartCDelim(t *testing.T) {
   340  	tmpDir := t.TempDir()
   341  
   342  	conf := NewConfig()
   343  	conf.SocketServer.Network = "unix"
   344  	conf.SocketServer.Address = filepath.Join(tmpDir, "b.sock")
   345  	conf.SocketServer.Multipart = true
   346  	conf.SocketServer.Delim = "@"
   347  
   348  	rdr, err := NewSocketServer(conf, nil, log.Noop(), metrics.Noop())
   349  	require.NoError(t, err)
   350  
   351  	defer func() {
   352  		rdr.CloseAsync()
   353  		assert.NoError(t, rdr.WaitForClose(time.Second))
   354  	}()
   355  
   356  	conn, err := net.Dial("unix", conf.SocketServer.Address)
   357  	require.NoError(t, err)
   358  
   359  	wg := sync.WaitGroup{}
   360  	wg.Add(1)
   361  	go func() {
   362  		conn.SetWriteDeadline(time.Now().Add(time.Second * 5))
   363  		_, cerr := conn.Write([]byte("foo@"))
   364  		require.NoError(t, cerr)
   365  
   366  		_, cerr = conn.Write([]byte("bar@"))
   367  		require.NoError(t, cerr)
   368  
   369  		_, cerr = conn.Write([]byte("@"))
   370  		require.NoError(t, cerr)
   371  
   372  		_, cerr = conn.Write([]byte("baz\n@@"))
   373  		require.NoError(t, cerr)
   374  
   375  		wg.Done()
   376  	}()
   377  
   378  	readNextMsg := func() (types.Message, error) {
   379  		var tran types.Transaction
   380  		select {
   381  		case tran = <-rdr.TransactionChan():
   382  			select {
   383  			case tran.ResponseChan <- response.NewAck():
   384  			case <-time.After(time.Second):
   385  				return nil, errors.New("timed out")
   386  			}
   387  		case <-time.After(time.Second):
   388  			return nil, errors.New("timed out")
   389  		}
   390  		return tran.Payload, nil
   391  	}
   392  
   393  	exp := [][]byte{[]byte("foo"), []byte("bar")}
   394  	msg, err := readNextMsg()
   395  	require.NoError(t, err)
   396  	assert.Equal(t, exp, message.GetAllBytes(msg))
   397  
   398  	exp = [][]byte{[]byte("baz\n")}
   399  	msg, err = readNextMsg()
   400  	require.NoError(t, err)
   401  	assert.Equal(t, exp, message.GetAllBytes(msg))
   402  
   403  	wg.Wait()
   404  	conn.Close()
   405  }
   406  
   407  func TestSocketServerMpartSdown(t *testing.T) {
   408  	tmpDir := t.TempDir()
   409  
   410  	conf := NewConfig()
   411  	conf.SocketServer.Network = "unix"
   412  	conf.SocketServer.Address = filepath.Join(tmpDir, "b.sock")
   413  	conf.SocketServer.Multipart = true
   414  
   415  	rdr, err := NewSocketServer(conf, nil, log.Noop(), metrics.Noop())
   416  	require.NoError(t, err)
   417  
   418  	defer func() {
   419  		rdr.CloseAsync()
   420  		assert.NoError(t, rdr.WaitForClose(time.Second))
   421  	}()
   422  
   423  	conn, err := net.Dial("unix", conf.SocketServer.Address)
   424  	require.NoError(t, err)
   425  
   426  	wg := sync.WaitGroup{}
   427  	wg.Add(1)
   428  	go func() {
   429  		conn.SetWriteDeadline(time.Now().Add(time.Second * 5))
   430  
   431  		_, cerr := conn.Write([]byte("foo\n"))
   432  		require.NoError(t, cerr)
   433  
   434  		_, cerr = conn.Write([]byte("bar\n"))
   435  		require.NoError(t, cerr)
   436  
   437  		_, cerr = conn.Write([]byte("\n"))
   438  		require.NoError(t, cerr)
   439  
   440  		_, cerr = conn.Write([]byte("baz\n"))
   441  		require.NoError(t, cerr)
   442  
   443  		conn.Close()
   444  		wg.Done()
   445  	}()
   446  
   447  	readNextMsg := func() (types.Message, error) {
   448  		var tran types.Transaction
   449  		select {
   450  		case tran = <-rdr.TransactionChan():
   451  			select {
   452  			case tran.ResponseChan <- response.NewAck():
   453  			case <-time.After(time.Second):
   454  				return nil, errors.New("timed out")
   455  			}
   456  		case <-time.After(time.Second):
   457  			return nil, errors.New("timed out")
   458  		}
   459  		return tran.Payload, nil
   460  	}
   461  
   462  	exp := [][]byte{[]byte("foo"), []byte("bar")}
   463  	msg, err := readNextMsg()
   464  	require.NoError(t, err)
   465  	assert.Equal(t, exp, message.GetAllBytes(msg))
   466  
   467  	exp = [][]byte{[]byte("baz")}
   468  	msg, err = readNextMsg()
   469  	require.NoError(t, err)
   470  	assert.Equal(t, exp, message.GetAllBytes(msg))
   471  
   472  	wg.Wait()
   473  }
   474  
   475  func TestSocketUDPServerBasic(t *testing.T) {
   476  	conf := NewConfig()
   477  	conf.SocketServer.Network = "udp"
   478  	conf.SocketServer.Address = "127.0.0.1:0"
   479  
   480  	rdr, err := NewSocketServer(conf, nil, log.Noop(), metrics.Noop())
   481  	require.NoError(t, err)
   482  
   483  	addr := rdr.(*SocketServer).Addr()
   484  
   485  	defer func() {
   486  		rdr.CloseAsync()
   487  		assert.NoError(t, rdr.WaitForClose(time.Second))
   488  	}()
   489  
   490  	conn, err := net.Dial("udp", addr.String())
   491  	require.NoError(t, err)
   492  
   493  	wg := sync.WaitGroup{}
   494  	wg.Add(1)
   495  	go func() {
   496  		conn.SetWriteDeadline(time.Now().Add(time.Second * 5))
   497  
   498  		_, cerr := conn.Write([]byte("foo\n"))
   499  		require.NoError(t, cerr)
   500  
   501  		_, cerr = conn.Write([]byte("bar\n"))
   502  		require.NoError(t, cerr)
   503  
   504  		_, cerr = conn.Write([]byte("baz\n"))
   505  		require.NoError(t, cerr)
   506  
   507  		wg.Done()
   508  	}()
   509  
   510  	readNextMsg := func() (types.Message, error) {
   511  		var tran types.Transaction
   512  		select {
   513  		case tran = <-rdr.TransactionChan():
   514  			select {
   515  			case tran.ResponseChan <- response.NewAck():
   516  			case <-time.After(time.Second):
   517  				return nil, errors.New("timed out")
   518  			}
   519  		case <-time.After(time.Second):
   520  			return nil, errors.New("timed out")
   521  		}
   522  		return tran.Payload, nil
   523  	}
   524  
   525  	exp := [][]byte{[]byte("foo")}
   526  	msg, err := readNextMsg()
   527  	require.NoError(t, err)
   528  	assert.Equal(t, exp, message.GetAllBytes(msg))
   529  
   530  	exp = [][]byte{[]byte("bar")}
   531  	msg, err = readNextMsg()
   532  	require.NoError(t, err)
   533  	assert.Equal(t, exp, message.GetAllBytes(msg))
   534  
   535  	exp = [][]byte{[]byte("baz")}
   536  	msg, err = readNextMsg()
   537  	require.NoError(t, err)
   538  	assert.Equal(t, exp, message.GetAllBytes(msg))
   539  
   540  	wg.Wait()
   541  	conn.Close()
   542  }
   543  
   544  func TestSocketUDPServerRetries(t *testing.T) {
   545  	conf := NewConfig()
   546  	conf.SocketServer.Network = "udp"
   547  	conf.SocketServer.Address = "127.0.0.1:0"
   548  
   549  	rdr, err := NewSocketServer(conf, nil, log.Noop(), metrics.Noop())
   550  	require.NoError(t, err)
   551  
   552  	addr := rdr.(*SocketServer).Addr()
   553  
   554  	defer func() {
   555  		rdr.CloseAsync()
   556  		assert.NoError(t, rdr.WaitForClose(time.Second))
   557  	}()
   558  
   559  	conn, err := net.Dial("udp", addr.String())
   560  	require.NoError(t, err)
   561  
   562  	wg := sync.WaitGroup{}
   563  	wg.Add(1)
   564  	go func() {
   565  		conn.SetWriteDeadline(time.Now().Add(time.Second * 5))
   566  
   567  		_, cerr := conn.Write([]byte("foo\n"))
   568  		require.NoError(t, cerr)
   569  
   570  		_, cerr = conn.Write([]byte("bar\n"))
   571  		require.NoError(t, cerr)
   572  
   573  		_, cerr = conn.Write([]byte("baz\n"))
   574  		require.NoError(t, cerr)
   575  
   576  		wg.Done()
   577  	}()
   578  
   579  	readNextMsg := func(reject bool) (types.Message, error) {
   580  		var tran types.Transaction
   581  		select {
   582  		case tran = <-rdr.TransactionChan():
   583  			var res types.Response = response.NewAck()
   584  			if reject {
   585  				res = response.NewError(errors.New("test err"))
   586  			}
   587  			select {
   588  			case tran.ResponseChan <- res:
   589  			case <-time.After(time.Second * 5):
   590  				return nil, errors.New("timed out")
   591  			}
   592  		case <-time.After(time.Second * 5):
   593  			return nil, errors.New("timed out")
   594  		}
   595  		return tran.Payload, nil
   596  	}
   597  
   598  	exp := [][]byte{[]byte("foo")}
   599  	msg, err := readNextMsg(false)
   600  	require.NoError(t, err)
   601  	assert.Equal(t, exp, message.GetAllBytes(msg))
   602  
   603  	exp = [][]byte{[]byte("bar")}
   604  	msg, err = readNextMsg(true)
   605  	require.NoError(t, err)
   606  	assert.Equal(t, exp, message.GetAllBytes(msg))
   607  
   608  	expRemaining := []string{"bar", "baz"}
   609  	actRemaining := []string{}
   610  
   611  	msg, err = readNextMsg(false)
   612  	require.NoError(t, err)
   613  	require.Equal(t, 1, msg.Len())
   614  	actRemaining = append(actRemaining, string(msg.Get(0).Get()))
   615  
   616  	msg, err = readNextMsg(false)
   617  	require.NoError(t, err)
   618  	require.Equal(t, 1, msg.Len())
   619  	actRemaining = append(actRemaining, string(msg.Get(0).Get()))
   620  
   621  	sort.Strings(actRemaining)
   622  	assert.Equal(t, expRemaining, actRemaining)
   623  
   624  	wg.Wait()
   625  	conn.Close()
   626  }
   627  
   628  func TestUDPServerWriteToClosed(t *testing.T) {
   629  	conf := NewConfig()
   630  	conf.SocketServer.Network = "udp"
   631  	conf.SocketServer.Address = "127.0.0.1:0"
   632  
   633  	rdr, err := NewSocketServer(conf, nil, log.Noop(), metrics.Noop())
   634  	require.NoError(t, err)
   635  
   636  	addr := rdr.(*SocketServer).Addr()
   637  
   638  	conn, err := net.Dial("udp", addr.String())
   639  	require.NoError(t, err)
   640  
   641  	conn.SetWriteDeadline(time.Now().Add(time.Second * 5))
   642  
   643  	rdr.CloseAsync()
   644  	assert.NoError(t, rdr.WaitForClose(time.Second*3))
   645  
   646  	// Just make sure data written doesn't panic
   647  	_, _ = conn.Write([]byte("bar\n"))
   648  
   649  	_, open := <-rdr.TransactionChan()
   650  	assert.False(t, open)
   651  
   652  	conn.Close()
   653  }
   654  
   655  func TestSocketUDPServerReconnect(t *testing.T) {
   656  	conf := NewConfig()
   657  	conf.SocketServer.Network = "udp"
   658  	conf.SocketServer.Address = "127.0.0.1:0"
   659  
   660  	rdr, err := NewSocketServer(conf, nil, log.Noop(), metrics.Noop())
   661  	require.NoError(t, err)
   662  
   663  	addr := rdr.(*SocketServer).Addr()
   664  
   665  	defer func() {
   666  		rdr.CloseAsync()
   667  		assert.NoError(t, rdr.WaitForClose(time.Second))
   668  	}()
   669  
   670  	conn, err := net.Dial("udp", addr.String())
   671  	require.NoError(t, err)
   672  
   673  	wg := sync.WaitGroup{}
   674  	wg.Add(1)
   675  	go func() {
   676  		conn.SetWriteDeadline(time.Now().Add(time.Second * 5))
   677  		_, cerr := conn.Write([]byte("foo\n"))
   678  		require.NoError(t, cerr)
   679  
   680  		conn.Close()
   681  
   682  		conn, cerr = net.Dial("udp", addr.String())
   683  		require.NoError(t, cerr)
   684  
   685  		_, cerr = conn.Write([]byte("bar\n"))
   686  		require.NoError(t, cerr)
   687  
   688  		_, cerr = conn.Write([]byte("baz\n"))
   689  		require.NoError(t, cerr)
   690  
   691  		wg.Done()
   692  	}()
   693  
   694  	readNextMsg := func() (types.Message, error) {
   695  		var tran types.Transaction
   696  		select {
   697  		case tran = <-rdr.TransactionChan():
   698  			select {
   699  			case tran.ResponseChan <- response.NewAck():
   700  			case <-time.After(time.Second):
   701  				return nil, errors.New("timed out")
   702  			}
   703  		case <-time.After(time.Second):
   704  			return nil, errors.New("timed out")
   705  		}
   706  		return tran.Payload, nil
   707  	}
   708  
   709  	exp := [][]byte{[]byte("foo")}
   710  	msg, err := readNextMsg()
   711  	require.NoError(t, err)
   712  	assert.Equal(t, exp, message.GetAllBytes(msg))
   713  
   714  	exp = [][]byte{[]byte("bar")}
   715  	msg, err = readNextMsg()
   716  	require.NoError(t, err)
   717  	assert.Equal(t, exp, message.GetAllBytes(msg))
   718  
   719  	exp = [][]byte{[]byte("baz")}
   720  	msg, err = readNextMsg()
   721  	require.NoError(t, err)
   722  	assert.Equal(t, exp, message.GetAllBytes(msg))
   723  
   724  	wg.Wait()
   725  	conn.Close()
   726  }
   727  
   728  func TestSocketUDPServerCustomDelim(t *testing.T) {
   729  	conf := NewConfig()
   730  	conf.SocketServer.Network = "udp"
   731  	conf.SocketServer.Address = "127.0.0.1:0"
   732  	conf.SocketServer.Delim = "@"
   733  
   734  	rdr, err := NewSocketServer(conf, nil, log.Noop(), metrics.Noop())
   735  	require.NoError(t, err)
   736  
   737  	addr := rdr.(*SocketServer).Addr()
   738  
   739  	defer func() {
   740  		rdr.CloseAsync()
   741  		assert.NoError(t, rdr.WaitForClose(time.Second))
   742  	}()
   743  
   744  	conn, err := net.Dial("udp", addr.String())
   745  	require.NoError(t, err)
   746  
   747  	wg := sync.WaitGroup{}
   748  	wg.Add(1)
   749  	go func() {
   750  		conn.SetWriteDeadline(time.Now().Add(time.Second * 5))
   751  
   752  		_, cerr := conn.Write([]byte("foo@"))
   753  		require.NoError(t, cerr)
   754  
   755  		_, cerr = conn.Write([]byte("bar@"))
   756  		require.NoError(t, cerr)
   757  
   758  		_, cerr = conn.Write([]byte("@"))
   759  		require.NoError(t, cerr)
   760  
   761  		_, cerr = conn.Write([]byte("baz\n@@"))
   762  		require.NoError(t, cerr)
   763  
   764  		wg.Done()
   765  	}()
   766  
   767  	readNextMsg := func() (types.Message, error) {
   768  		var tran types.Transaction
   769  		select {
   770  		case tran = <-rdr.TransactionChan():
   771  			select {
   772  			case tran.ResponseChan <- response.NewAck():
   773  			case <-time.After(time.Second):
   774  				return nil, errors.New("timed out")
   775  			}
   776  		case <-time.After(time.Second):
   777  			return nil, errors.New("timed out")
   778  		}
   779  		return tran.Payload, nil
   780  	}
   781  
   782  	exp := [][]byte{[]byte("foo")}
   783  	msg, err := readNextMsg()
   784  	require.NoError(t, err)
   785  	assert.Equal(t, exp, message.GetAllBytes(msg))
   786  
   787  	exp = [][]byte{[]byte("bar")}
   788  	msg, err = readNextMsg()
   789  	require.NoError(t, err)
   790  	assert.Equal(t, exp, message.GetAllBytes(msg))
   791  
   792  	exp = [][]byte{[]byte("")}
   793  	msg, err = readNextMsg()
   794  	require.NoError(t, err)
   795  	assert.Equal(t, exp, message.GetAllBytes(msg))
   796  
   797  	exp = [][]byte{[]byte("baz\n")}
   798  	msg, err = readNextMsg()
   799  	require.NoError(t, err)
   800  	assert.Equal(t, exp, message.GetAllBytes(msg))
   801  
   802  	wg.Wait()
   803  	conn.Close()
   804  }
   805  
   806  func TestSocketUDPServerShutdown(t *testing.T) {
   807  	conf := NewConfig()
   808  	conf.SocketServer.Network = "udp"
   809  	conf.SocketServer.Address = "127.0.0.1:0"
   810  
   811  	rdr, err := NewSocketServer(conf, nil, log.Noop(), metrics.Noop())
   812  	require.NoError(t, err)
   813  
   814  	addr := rdr.(*SocketServer).Addr()
   815  
   816  	defer func() {
   817  		rdr.CloseAsync()
   818  		assert.NoError(t, rdr.WaitForClose(time.Second))
   819  	}()
   820  
   821  	conn, err := net.Dial("udp", addr.String())
   822  	require.NoError(t, err)
   823  
   824  	wg := sync.WaitGroup{}
   825  	wg.Add(1)
   826  	go func() {
   827  		conn.SetWriteDeadline(time.Now().Add(time.Second * 5))
   828  
   829  		_, cerr := conn.Write([]byte("foo\n"))
   830  		require.NoError(t, cerr)
   831  
   832  		_, cerr = conn.Write([]byte("bar\n"))
   833  		require.NoError(t, cerr)
   834  
   835  		_, cerr = conn.Write([]byte("\n"))
   836  		require.NoError(t, cerr)
   837  
   838  		_, cerr = conn.Write([]byte("baz\n"))
   839  		require.NoError(t, cerr)
   840  
   841  		conn.Close()
   842  		wg.Done()
   843  	}()
   844  
   845  	readNextMsg := func() (types.Message, error) {
   846  		var tran types.Transaction
   847  		select {
   848  		case tran = <-rdr.TransactionChan():
   849  			select {
   850  			case tran.ResponseChan <- response.NewAck():
   851  			case <-time.After(time.Second):
   852  				return nil, errors.New("timed out")
   853  			}
   854  		case <-time.After(time.Second):
   855  			return nil, errors.New("timed out")
   856  		}
   857  		return tran.Payload, nil
   858  	}
   859  
   860  	exp := [][]byte{[]byte("foo")}
   861  	msg, err := readNextMsg()
   862  	require.NoError(t, err)
   863  	assert.Equal(t, exp, message.GetAllBytes(msg))
   864  
   865  	exp = [][]byte{[]byte("bar")}
   866  	msg, err = readNextMsg()
   867  	require.NoError(t, err)
   868  	assert.Equal(t, exp, message.GetAllBytes(msg))
   869  
   870  	exp = [][]byte{[]byte("")}
   871  	msg, err = readNextMsg()
   872  	require.NoError(t, err)
   873  	assert.Equal(t, exp, message.GetAllBytes(msg))
   874  
   875  	exp = [][]byte{[]byte("baz")}
   876  	msg, err = readNextMsg()
   877  	require.NoError(t, err)
   878  	assert.Equal(t, exp, message.GetAllBytes(msg))
   879  
   880  	wg.Wait()
   881  }
   882  
   883  func TestTCPSocketServerBasic(t *testing.T) {
   884  	conf := NewConfig()
   885  	conf.SocketServer.Network = "tcp"
   886  	conf.SocketServer.Address = "127.0.0.1:0"
   887  
   888  	rdr, err := NewSocketServer(conf, nil, log.Noop(), metrics.Noop())
   889  	require.NoError(t, err)
   890  
   891  	addr := rdr.(*SocketServer).Addr()
   892  
   893  	defer func() {
   894  		rdr.CloseAsync()
   895  		assert.NoError(t, rdr.WaitForClose(time.Second))
   896  	}()
   897  
   898  	conn, err := net.Dial("tcp", addr.String())
   899  	require.NoError(t, err)
   900  
   901  	wg := sync.WaitGroup{}
   902  	wg.Add(1)
   903  	go func() {
   904  		conn.SetWriteDeadline(time.Now().Add(time.Second * 5))
   905  
   906  		_, cerr := conn.Write([]byte("foo\n"))
   907  		require.NoError(t, cerr)
   908  
   909  		_, cerr = conn.Write([]byte("bar\n"))
   910  		require.NoError(t, cerr)
   911  
   912  		_, cerr = conn.Write([]byte("baz\n"))
   913  		require.NoError(t, cerr)
   914  
   915  		wg.Done()
   916  	}()
   917  
   918  	readNextMsg := func() (types.Message, error) {
   919  		var tran types.Transaction
   920  		select {
   921  		case tran = <-rdr.TransactionChan():
   922  			select {
   923  			case tran.ResponseChan <- response.NewAck():
   924  			case <-time.After(time.Second):
   925  				return nil, errors.New("timed out")
   926  			}
   927  		case <-time.After(time.Second):
   928  			return nil, errors.New("timed out")
   929  		}
   930  		return tran.Payload, nil
   931  	}
   932  
   933  	exp := [][]byte{[]byte("foo")}
   934  	msg, err := readNextMsg()
   935  	require.NoError(t, err)
   936  	assert.Equal(t, exp, message.GetAllBytes(msg))
   937  
   938  	exp = [][]byte{[]byte("bar")}
   939  	msg, err = readNextMsg()
   940  	require.NoError(t, err)
   941  	assert.Equal(t, exp, message.GetAllBytes(msg))
   942  
   943  	exp = [][]byte{[]byte("baz")}
   944  	msg, err = readNextMsg()
   945  	require.NoError(t, err)
   946  	assert.Equal(t, exp, message.GetAllBytes(msg))
   947  
   948  	wg.Wait()
   949  	conn.Close()
   950  }
   951  
   952  func TestTCPSocketServerReconnect(t *testing.T) {
   953  	conf := NewConfig()
   954  	conf.SocketServer.Network = "tcp"
   955  	conf.SocketServer.Address = "127.0.0.1:0"
   956  
   957  	rdr, err := NewSocketServer(conf, nil, log.Noop(), metrics.Noop())
   958  	require.NoError(t, err)
   959  
   960  	addr := rdr.(*SocketServer).Addr()
   961  
   962  	defer func() {
   963  		rdr.CloseAsync()
   964  		assert.NoError(t, rdr.WaitForClose(time.Second))
   965  	}()
   966  
   967  	conn, err := net.Dial("tcp", addr.String())
   968  	require.NoError(t, err)
   969  
   970  	wg := sync.WaitGroup{}
   971  	wg.Add(1)
   972  	go func() {
   973  		conn.SetWriteDeadline(time.Now().Add(time.Second * 5))
   974  
   975  		_, cerr := conn.Write([]byte("foo\n"))
   976  		require.NoError(t, cerr)
   977  
   978  		conn.Close()
   979  
   980  		conn, cerr = net.Dial("tcp", addr.String())
   981  		require.NoError(t, cerr)
   982  
   983  		_, cerr = conn.Write([]byte("bar\n"))
   984  		require.NoError(t, cerr)
   985  
   986  		_, cerr = conn.Write([]byte("baz\n"))
   987  		require.NoError(t, cerr)
   988  
   989  		wg.Done()
   990  	}()
   991  
   992  	readNextMsg := func() (types.Message, error) {
   993  		var tran types.Transaction
   994  		select {
   995  		case tran = <-rdr.TransactionChan():
   996  			select {
   997  			case tran.ResponseChan <- response.NewAck():
   998  			case <-time.After(time.Second):
   999  				return nil, errors.New("timed out")
  1000  			}
  1001  		case <-time.After(time.Second):
  1002  			return nil, errors.New("timed out")
  1003  		}
  1004  		return tran.Payload, nil
  1005  	}
  1006  
  1007  	expMsgs := map[string]struct{}{
  1008  		"foo": {},
  1009  		"bar": {},
  1010  		"baz": {},
  1011  	}
  1012  
  1013  	for i := 0; i < 3; i++ {
  1014  		msg, err := readNextMsg()
  1015  		require.NoError(t, err)
  1016  
  1017  		act := string(msg.Get(0).Get())
  1018  		assert.Contains(t, expMsgs, act)
  1019  		delete(expMsgs, act)
  1020  	}
  1021  
  1022  	wg.Wait()
  1023  	conn.Close()
  1024  }
  1025  
  1026  func TestTCPSocketServerMultipart(t *testing.T) {
  1027  	conf := NewConfig()
  1028  	conf.SocketServer.Network = "tcp"
  1029  	conf.SocketServer.Address = "127.0.0.1:0"
  1030  	conf.SocketServer.Multipart = true
  1031  
  1032  	rdr, err := NewSocketServer(conf, nil, log.Noop(), metrics.Noop())
  1033  	require.NoError(t, err)
  1034  
  1035  	addr := rdr.(*SocketServer).Addr()
  1036  
  1037  	defer func() {
  1038  		rdr.CloseAsync()
  1039  		assert.NoError(t, rdr.WaitForClose(time.Second))
  1040  	}()
  1041  
  1042  	conn, err := net.Dial("tcp", addr.String())
  1043  	require.NoError(t, err)
  1044  
  1045  	wg := sync.WaitGroup{}
  1046  	wg.Add(1)
  1047  	go func() {
  1048  		conn.SetWriteDeadline(time.Now().Add(time.Second * 5))
  1049  
  1050  		_, cerr := conn.Write([]byte("foo\n"))
  1051  		require.NoError(t, cerr)
  1052  
  1053  		_, cerr = conn.Write([]byte("bar\n"))
  1054  		require.NoError(t, cerr)
  1055  
  1056  		_, cerr = conn.Write([]byte("\n"))
  1057  		require.NoError(t, cerr)
  1058  
  1059  		_, cerr = conn.Write([]byte("baz\n\n"))
  1060  		require.NoError(t, cerr)
  1061  
  1062  		wg.Done()
  1063  	}()
  1064  
  1065  	readNextMsg := func() (types.Message, error) {
  1066  		var tran types.Transaction
  1067  		select {
  1068  		case tran = <-rdr.TransactionChan():
  1069  			select {
  1070  			case tran.ResponseChan <- response.NewAck():
  1071  			case <-time.After(time.Second):
  1072  				return nil, errors.New("timed out")
  1073  			}
  1074  		case <-time.After(time.Second):
  1075  			return nil, errors.New("timed out")
  1076  		}
  1077  		return tran.Payload, nil
  1078  	}
  1079  
  1080  	exp := [][]byte{[]byte("foo"), []byte("bar")}
  1081  	msg, err := readNextMsg()
  1082  	require.NoError(t, err)
  1083  	assert.Equal(t, exp, message.GetAllBytes(msg))
  1084  
  1085  	exp = [][]byte{[]byte("baz")}
  1086  	msg, err = readNextMsg()
  1087  	require.NoError(t, err)
  1088  	assert.Equal(t, exp, message.GetAllBytes(msg))
  1089  
  1090  	wg.Wait()
  1091  	conn.Close()
  1092  }
  1093  
  1094  func TestTCPSocketServerMultipartCustomDelim(t *testing.T) {
  1095  	conf := NewConfig()
  1096  	conf.SocketServer.Network = "tcp"
  1097  	conf.SocketServer.Address = "127.0.0.1:0"
  1098  	conf.SocketServer.Multipart = true
  1099  	conf.SocketServer.Delim = "@"
  1100  
  1101  	rdr, err := NewSocketServer(conf, nil, log.Noop(), metrics.Noop())
  1102  	require.NoError(t, err)
  1103  
  1104  	addr := rdr.(*SocketServer).Addr()
  1105  
  1106  	defer func() {
  1107  		rdr.CloseAsync()
  1108  		assert.NoError(t, rdr.WaitForClose(time.Second))
  1109  	}()
  1110  
  1111  	conn, err := net.Dial("tcp", addr.String())
  1112  	require.NoError(t, err)
  1113  
  1114  	wg := sync.WaitGroup{}
  1115  	wg.Add(1)
  1116  	go func() {
  1117  		conn.SetWriteDeadline(time.Now().Add(time.Second * 5))
  1118  
  1119  		_, cerr := conn.Write([]byte("foo@"))
  1120  		require.NoError(t, cerr)
  1121  
  1122  		_, cerr = conn.Write([]byte("bar@"))
  1123  		require.NoError(t, cerr)
  1124  
  1125  		_, cerr = conn.Write([]byte("@"))
  1126  		require.NoError(t, cerr)
  1127  
  1128  		_, cerr = conn.Write([]byte("baz\n@@"))
  1129  		require.NoError(t, cerr)
  1130  
  1131  		wg.Done()
  1132  	}()
  1133  
  1134  	readNextMsg := func() (types.Message, error) {
  1135  		var tran types.Transaction
  1136  		select {
  1137  		case tran = <-rdr.TransactionChan():
  1138  			select {
  1139  			case tran.ResponseChan <- response.NewAck():
  1140  			case <-time.After(time.Second):
  1141  				return nil, errors.New("timed out")
  1142  			}
  1143  		case <-time.After(time.Second):
  1144  			return nil, errors.New("timed out")
  1145  		}
  1146  		return tran.Payload, nil
  1147  	}
  1148  
  1149  	exp := [][]byte{[]byte("foo"), []byte("bar")}
  1150  	msg, err := readNextMsg()
  1151  	require.NoError(t, err)
  1152  	assert.Equal(t, exp, message.GetAllBytes(msg))
  1153  
  1154  	exp = [][]byte{[]byte("baz\n")}
  1155  	msg, err = readNextMsg()
  1156  	require.NoError(t, err)
  1157  	assert.Equal(t, exp, message.GetAllBytes(msg))
  1158  
  1159  	wg.Wait()
  1160  	conn.Close()
  1161  }
  1162  
  1163  func TestTCPSocketServerMultipartShutdown(t *testing.T) {
  1164  	conf := NewConfig()
  1165  	conf.SocketServer.Network = "tcp"
  1166  	conf.SocketServer.Address = "127.0.0.1:0"
  1167  	conf.SocketServer.Multipart = true
  1168  
  1169  	rdr, err := NewSocketServer(conf, nil, log.Noop(), metrics.Noop())
  1170  	require.NoError(t, err)
  1171  
  1172  	addr := rdr.(*SocketServer).Addr()
  1173  
  1174  	defer func() {
  1175  		rdr.CloseAsync()
  1176  		assert.NoError(t, rdr.WaitForClose(time.Second))
  1177  	}()
  1178  
  1179  	conn, err := net.Dial("tcp", addr.String())
  1180  	require.NoError(t, err)
  1181  
  1182  	wg := sync.WaitGroup{}
  1183  	wg.Add(1)
  1184  	go func() {
  1185  		conn.SetWriteDeadline(time.Now().Add(time.Second * 5))
  1186  
  1187  		_, cerr := conn.Write([]byte("foo\n"))
  1188  		require.NoError(t, cerr)
  1189  
  1190  		_, cerr = conn.Write([]byte("bar\n"))
  1191  		require.NoError(t, cerr)
  1192  
  1193  		_, cerr = conn.Write([]byte("\n"))
  1194  		require.NoError(t, cerr)
  1195  
  1196  		_, cerr = conn.Write([]byte("baz\n"))
  1197  		require.NoError(t, cerr)
  1198  
  1199  		conn.Close()
  1200  		wg.Done()
  1201  	}()
  1202  
  1203  	readNextMsg := func() (types.Message, error) {
  1204  		var tran types.Transaction
  1205  		select {
  1206  		case tran = <-rdr.TransactionChan():
  1207  			select {
  1208  			case tran.ResponseChan <- response.NewAck():
  1209  			case <-time.After(time.Second):
  1210  				return nil, errors.New("timed out")
  1211  			}
  1212  		case <-time.After(time.Second):
  1213  			return nil, errors.New("timed out")
  1214  		}
  1215  		return tran.Payload, nil
  1216  	}
  1217  
  1218  	exp := [][]byte{[]byte("foo"), []byte("bar")}
  1219  	msg, err := readNextMsg()
  1220  	require.NoError(t, err)
  1221  	assert.Equal(t, exp, message.GetAllBytes(msg))
  1222  
  1223  	exp = [][]byte{[]byte("baz")}
  1224  	msg, err = readNextMsg()
  1225  	require.NoError(t, err)
  1226  	assert.Equal(t, exp, message.GetAllBytes(msg))
  1227  
  1228  	wg.Wait()
  1229  }