github.com/Jeffail/benthos/v3@v3.65.0/lib/input/tcp_server_test.go (about)

     1  package input
     2  
     3  import (
     4  	"errors"
     5  	"net"
     6  	"reflect"
     7  	"sync"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/Jeffail/benthos/v3/lib/log"
    12  	"github.com/Jeffail/benthos/v3/lib/message"
    13  	"github.com/Jeffail/benthos/v3/lib/metrics"
    14  	"github.com/Jeffail/benthos/v3/lib/response"
    15  	"github.com/Jeffail/benthos/v3/lib/types"
    16  )
    17  
    18  func TestTCPServerBasic(t *testing.T) {
    19  	conf := NewConfig()
    20  	conf.TCPServer.Address = "127.0.0.1:0"
    21  
    22  	rdr, err := NewTCPServer(conf, nil, log.Noop(), metrics.Noop())
    23  	if err != nil {
    24  		t.Fatal(err)
    25  	}
    26  	addr := rdr.(*TCPServer).Addr()
    27  
    28  	defer func() {
    29  		rdr.CloseAsync()
    30  		if err := rdr.WaitForClose(time.Second); err != nil {
    31  			t.Error(err)
    32  		}
    33  	}()
    34  
    35  	conn, err := net.Dial("tcp", addr.String())
    36  	if err != nil {
    37  		t.Fatal(err)
    38  	}
    39  
    40  	wg := sync.WaitGroup{}
    41  	wg.Add(1)
    42  	go func() {
    43  		conn.SetWriteDeadline(time.Now().Add(time.Second * 5))
    44  		if _, cerr := conn.Write([]byte("foo\n")); cerr != nil {
    45  			t.Error(cerr)
    46  		}
    47  		if _, cerr := conn.Write([]byte("bar\n")); cerr != nil {
    48  			t.Error(cerr)
    49  		}
    50  		if _, cerr := conn.Write([]byte("baz\n")); cerr != nil {
    51  			t.Error(cerr)
    52  		}
    53  		wg.Done()
    54  	}()
    55  
    56  	readNextMsg := func() (types.Message, error) {
    57  		var tran types.Transaction
    58  		select {
    59  		case tran = <-rdr.TransactionChan():
    60  			select {
    61  			case tran.ResponseChan <- response.NewAck():
    62  			case <-time.After(time.Second):
    63  				return nil, errors.New("timed out")
    64  			}
    65  		case <-time.After(time.Second):
    66  			return nil, errors.New("timed out")
    67  		}
    68  		return tran.Payload, nil
    69  	}
    70  
    71  	exp := [][]byte{[]byte("foo")}
    72  	msg, err := readNextMsg()
    73  	if err != nil {
    74  		t.Fatal(err)
    75  	}
    76  	if act := message.GetAllBytes(msg); !reflect.DeepEqual(exp, act) {
    77  		t.Errorf("Wrong message contents: %s != %s", act, exp)
    78  	}
    79  
    80  	exp = [][]byte{[]byte("bar")}
    81  	if msg, err = readNextMsg(); err != nil {
    82  		t.Fatal(err)
    83  	}
    84  	if act := message.GetAllBytes(msg); !reflect.DeepEqual(exp, act) {
    85  		t.Errorf("Wrong message contents: %s != %s", act, exp)
    86  	}
    87  
    88  	exp = [][]byte{[]byte("baz")}
    89  	if msg, err = readNextMsg(); err != nil {
    90  		t.Fatal(err)
    91  	}
    92  	if act := message.GetAllBytes(msg); !reflect.DeepEqual(exp, act) {
    93  		t.Errorf("Wrong message contents: %s != %s", act, exp)
    94  	}
    95  
    96  	wg.Wait()
    97  	conn.Close()
    98  }
    99  
   100  func TestTCPServerReconnect(t *testing.T) {
   101  	conf := NewConfig()
   102  	conf.TCPServer.Address = "127.0.0.1:0"
   103  
   104  	rdr, err := NewTCPServer(conf, nil, log.Noop(), metrics.Noop())
   105  	if err != nil {
   106  		t.Fatal(err)
   107  	}
   108  	addr := rdr.(*TCPServer).Addr()
   109  
   110  	defer func() {
   111  		rdr.CloseAsync()
   112  		if err := rdr.WaitForClose(time.Second); err != nil {
   113  			t.Error(err)
   114  		}
   115  	}()
   116  
   117  	conn, err := net.Dial("tcp", addr.String())
   118  	if err != nil {
   119  		t.Fatal(err)
   120  	}
   121  
   122  	wg := sync.WaitGroup{}
   123  	wg.Add(1)
   124  	go func() {
   125  		conn.SetWriteDeadline(time.Now().Add(time.Second * 5))
   126  		_, cerr := conn.Write([]byte("foo\n"))
   127  		if cerr != nil {
   128  			t.Error(cerr)
   129  		}
   130  		conn.Close()
   131  		conn, cerr = net.Dial("tcp", addr.String())
   132  		if cerr != nil {
   133  			t.Error(cerr)
   134  		}
   135  		if _, cerr := conn.Write([]byte("bar\n")); cerr != nil {
   136  			t.Error(cerr)
   137  		}
   138  		if _, cerr := conn.Write([]byte("baz\n")); cerr != nil {
   139  			t.Error(cerr)
   140  		}
   141  
   142  		wg.Done()
   143  	}()
   144  
   145  	readNextMsg := func() (types.Message, error) {
   146  		var tran types.Transaction
   147  		select {
   148  		case tran = <-rdr.TransactionChan():
   149  			select {
   150  			case tran.ResponseChan <- response.NewAck():
   151  			case <-time.After(time.Second):
   152  				return nil, errors.New("timed out")
   153  			}
   154  		case <-time.After(time.Second):
   155  			return nil, errors.New("timed out")
   156  		}
   157  		return tran.Payload, nil
   158  	}
   159  
   160  	expMsgs := map[string]struct{}{
   161  		"foo": {},
   162  		"bar": {},
   163  		"baz": {},
   164  	}
   165  
   166  	for i := 0; i < 3; i++ {
   167  		msg, err := readNextMsg()
   168  		if err != nil {
   169  			t.Fatal(err)
   170  		}
   171  		act := string(msg.Get(0).Get())
   172  		if _, exists := expMsgs[act]; !exists {
   173  			t.Errorf("Unexpected message: %v", act)
   174  		}
   175  		delete(expMsgs, act)
   176  	}
   177  
   178  	wg.Wait()
   179  	conn.Close()
   180  }
   181  
   182  func TestTCPServerMultipart(t *testing.T) {
   183  	conf := NewConfig()
   184  	conf.TCPServer.Address = "127.0.0.1:0"
   185  	conf.TCPServer.Multipart = true
   186  
   187  	rdr, err := NewTCPServer(conf, nil, log.Noop(), metrics.Noop())
   188  	if err != nil {
   189  		t.Fatal(err)
   190  	}
   191  	addr := rdr.(*TCPServer).Addr()
   192  
   193  	defer func() {
   194  		rdr.CloseAsync()
   195  		if err := rdr.WaitForClose(time.Second); err != nil {
   196  			t.Error(err)
   197  		}
   198  	}()
   199  
   200  	conn, err := net.Dial("tcp", addr.String())
   201  	if err != nil {
   202  		t.Fatal(err)
   203  	}
   204  
   205  	wg := sync.WaitGroup{}
   206  	wg.Add(1)
   207  	go func() {
   208  		conn.SetWriteDeadline(time.Now().Add(time.Second * 5))
   209  		if _, cerr := conn.Write([]byte("foo\n")); cerr != nil {
   210  			t.Error(cerr)
   211  		}
   212  		if _, cerr := conn.Write([]byte("bar\n")); cerr != nil {
   213  			t.Error(cerr)
   214  		}
   215  		if _, cerr := conn.Write([]byte("\n")); cerr != nil {
   216  			t.Error(cerr)
   217  		}
   218  		if _, cerr := conn.Write([]byte("baz\n\n")); cerr != nil {
   219  			t.Error(cerr)
   220  		}
   221  		wg.Done()
   222  	}()
   223  
   224  	readNextMsg := func() (types.Message, error) {
   225  		var tran types.Transaction
   226  		select {
   227  		case tran = <-rdr.TransactionChan():
   228  			select {
   229  			case tran.ResponseChan <- response.NewAck():
   230  			case <-time.After(time.Second):
   231  				return nil, errors.New("timed out")
   232  			}
   233  		case <-time.After(time.Second):
   234  			return nil, errors.New("timed out")
   235  		}
   236  		return tran.Payload, nil
   237  	}
   238  
   239  	exp := [][]byte{[]byte("foo"), []byte("bar")}
   240  	msg, err := readNextMsg()
   241  	if err != nil {
   242  		t.Fatal(err)
   243  	}
   244  	if act := message.GetAllBytes(msg); !reflect.DeepEqual(exp, act) {
   245  		t.Errorf("Wrong message contents: %s != %s", act, exp)
   246  	}
   247  
   248  	exp = [][]byte{[]byte("baz")}
   249  	if msg, err = readNextMsg(); err != nil {
   250  		t.Fatal(err)
   251  	}
   252  	if act := message.GetAllBytes(msg); !reflect.DeepEqual(exp, act) {
   253  		t.Errorf("Wrong message contents: %s != %s", act, exp)
   254  	}
   255  
   256  	wg.Wait()
   257  	conn.Close()
   258  }
   259  
   260  func TestTCPServerMultipartCustomDelim(t *testing.T) {
   261  	conf := NewConfig()
   262  	conf.TCPServer.Address = "127.0.0.1:0"
   263  	conf.TCPServer.Multipart = true
   264  	conf.TCPServer.Delim = "@"
   265  
   266  	rdr, err := NewTCPServer(conf, nil, log.Noop(), metrics.Noop())
   267  	if err != nil {
   268  		t.Fatal(err)
   269  	}
   270  	addr := rdr.(*TCPServer).Addr()
   271  
   272  	defer func() {
   273  		rdr.CloseAsync()
   274  		if err := rdr.WaitForClose(time.Second); err != nil {
   275  			t.Error(err)
   276  		}
   277  	}()
   278  
   279  	conn, err := net.Dial("tcp", addr.String())
   280  	if err != nil {
   281  		t.Fatal(err)
   282  	}
   283  
   284  	wg := sync.WaitGroup{}
   285  	wg.Add(1)
   286  	go func() {
   287  		conn.SetWriteDeadline(time.Now().Add(time.Second * 5))
   288  		if _, cerr := conn.Write([]byte("foo@")); cerr != nil {
   289  			t.Error(cerr)
   290  		}
   291  		if _, cerr := conn.Write([]byte("bar@")); cerr != nil {
   292  			t.Error(cerr)
   293  		}
   294  		if _, cerr := conn.Write([]byte("@")); cerr != nil {
   295  			t.Error(cerr)
   296  		}
   297  		if _, cerr := conn.Write([]byte("baz\n@@")); cerr != nil {
   298  			t.Error(cerr)
   299  		}
   300  		wg.Done()
   301  	}()
   302  
   303  	readNextMsg := func() (types.Message, error) {
   304  		var tran types.Transaction
   305  		select {
   306  		case tran = <-rdr.TransactionChan():
   307  			select {
   308  			case tran.ResponseChan <- response.NewAck():
   309  			case <-time.After(time.Second):
   310  				return nil, errors.New("timed out")
   311  			}
   312  		case <-time.After(time.Second):
   313  			return nil, errors.New("timed out")
   314  		}
   315  		return tran.Payload, nil
   316  	}
   317  
   318  	exp := [][]byte{[]byte("foo"), []byte("bar")}
   319  	msg, err := readNextMsg()
   320  	if err != nil {
   321  		t.Fatal(err)
   322  	}
   323  	if act := message.GetAllBytes(msg); !reflect.DeepEqual(exp, act) {
   324  		t.Errorf("Wrong message contents: %s != %s", act, exp)
   325  	}
   326  
   327  	exp = [][]byte{[]byte("baz\n")}
   328  	if msg, err = readNextMsg(); err != nil {
   329  		t.Fatal(err)
   330  	}
   331  	if act := message.GetAllBytes(msg); !reflect.DeepEqual(exp, act) {
   332  		t.Errorf("Wrong message contents: %s != %s", act, exp)
   333  	}
   334  
   335  	wg.Wait()
   336  	conn.Close()
   337  }
   338  
   339  func TestTCPServerMultipartShutdown(t *testing.T) {
   340  	conf := NewConfig()
   341  	conf.TCPServer.Address = "127.0.0.1:0"
   342  	conf.TCPServer.Multipart = true
   343  
   344  	rdr, err := NewTCPServer(conf, nil, log.Noop(), metrics.Noop())
   345  	if err != nil {
   346  		t.Fatal(err)
   347  	}
   348  	addr := rdr.(*TCPServer).Addr()
   349  
   350  	defer func() {
   351  		rdr.CloseAsync()
   352  		if err := rdr.WaitForClose(time.Second); err != nil {
   353  			t.Error(err)
   354  		}
   355  	}()
   356  
   357  	conn, err := net.Dial("tcp", addr.String())
   358  	if err != nil {
   359  		t.Fatal(err)
   360  	}
   361  
   362  	wg := sync.WaitGroup{}
   363  	wg.Add(1)
   364  	go func() {
   365  		conn.SetWriteDeadline(time.Now().Add(time.Second * 5))
   366  		if _, cerr := conn.Write([]byte("foo\n")); cerr != nil {
   367  			t.Error(cerr)
   368  		}
   369  		if _, cerr := conn.Write([]byte("bar\n")); cerr != nil {
   370  			t.Error(cerr)
   371  		}
   372  		if _, cerr := conn.Write([]byte("\n")); cerr != nil {
   373  			t.Error(cerr)
   374  		}
   375  		if _, cerr := conn.Write([]byte("baz\n")); cerr != nil {
   376  			t.Error(cerr)
   377  		}
   378  		conn.Close()
   379  		wg.Done()
   380  	}()
   381  
   382  	readNextMsg := func() (types.Message, error) {
   383  		var tran types.Transaction
   384  		select {
   385  		case tran = <-rdr.TransactionChan():
   386  			select {
   387  			case tran.ResponseChan <- response.NewAck():
   388  			case <-time.After(time.Second):
   389  				return nil, errors.New("timed out")
   390  			}
   391  		case <-time.After(time.Second):
   392  			return nil, errors.New("timed out")
   393  		}
   394  		return tran.Payload, nil
   395  	}
   396  
   397  	exp := [][]byte{[]byte("foo"), []byte("bar")}
   398  	msg, err := readNextMsg()
   399  	if err != nil {
   400  		t.Fatal(err)
   401  	}
   402  	if act := message.GetAllBytes(msg); !reflect.DeepEqual(exp, act) {
   403  		t.Errorf("Wrong message contents: %s != %s", act, exp)
   404  	}
   405  
   406  	exp = [][]byte{[]byte("baz")}
   407  	if msg, err = readNextMsg(); err != nil {
   408  		t.Fatal(err)
   409  	}
   410  	if act := message.GetAllBytes(msg); !reflect.DeepEqual(exp, act) {
   411  		t.Errorf("Wrong message contents: %s != %s", act, exp)
   412  	}
   413  
   414  	wg.Wait()
   415  }