github.com/Jeffail/benthos/v3@v3.65.0/lib/input/reader/mqtt_test.go (about)

     1  package reader
     2  
     3  import (
     4  	"fmt"
     5  	"sync"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/Jeffail/benthos/v3/lib/log"
    10  	"github.com/Jeffail/benthos/v3/lib/metrics"
    11  	"github.com/Jeffail/benthos/v3/lib/types"
    12  	mqtt "github.com/eclipse/paho.mqtt.golang"
    13  	"github.com/ory/dockertest/v3"
    14  )
    15  
    16  func getMQTTConn(urls []string) (mqtt.Client, error) {
    17  	inConf := mqtt.NewClientOptions().
    18  		SetClientID("UNIT_TEST")
    19  	for _, u := range urls {
    20  		inConf = inConf.AddBroker(u)
    21  	}
    22  
    23  	mIn := mqtt.NewClient(inConf)
    24  	tok := mIn.Connect()
    25  	tok.Wait()
    26  	if cErr := tok.Error(); cErr != nil {
    27  		return nil, cErr
    28  	}
    29  
    30  	return mIn, nil
    31  }
    32  
    33  func sendMQTTMsg(c mqtt.Client, topic, msg string) error {
    34  	mtok := c.Publish(topic, 2, false, msg)
    35  	mtok.Wait()
    36  	return mtok.Error()
    37  }
    38  
    39  func TestMQTTIntegration(t *testing.T) {
    40  	if testing.Short() {
    41  		t.Skip("Skipping integration test in short mode")
    42  	}
    43  	t.Skip("Skipping MQTT tests because the library crashes on shutdown")
    44  
    45  	pool, err := dockertest.NewPool("")
    46  	if err != nil {
    47  		t.Skipf("Could not connect to docker: %s", err)
    48  	}
    49  	pool.MaxWait = time.Second * 30
    50  
    51  	resource, err := pool.Run("ncarlier/mqtt", "latest", nil)
    52  	if err != nil {
    53  		t.Fatalf("Could not start resource: %s", err)
    54  	}
    55  
    56  	urls := []string{fmt.Sprintf("tcp://localhost:%v", resource.GetPort("1883/tcp"))}
    57  
    58  	if err = pool.Retry(func() error {
    59  		client, err := getMQTTConn(urls)
    60  		if err == nil {
    61  			client.Disconnect(0)
    62  		}
    63  		return err
    64  	}); err != nil {
    65  		t.Fatalf("Could not connect to docker resource: %s", err)
    66  	}
    67  
    68  	defer func() {
    69  		if err = pool.Purge(resource); err != nil {
    70  			t.Logf("Failed to clean up docker resource: %v", err)
    71  		}
    72  	}()
    73  
    74  	t.Run("TestMQTTConnect", func(te *testing.T) {
    75  		testMQTTConnect(urls, te)
    76  	})
    77  	t.Run("TestMQTTDisconnect", func(te *testing.T) {
    78  		testMQTTDisconnect(urls, te)
    79  	})
    80  }
    81  
    82  func testMQTTConnect(urls []string, t *testing.T) {
    83  	conf := NewMQTTConfig()
    84  	conf.ClientID = "foo"
    85  	conf.Topics = []string{"test_input_1"}
    86  	conf.URLs = urls
    87  
    88  	m, err := NewMQTT(conf, log.Noop(), metrics.Noop())
    89  	if err != nil {
    90  		t.Fatal(err)
    91  	}
    92  
    93  	if err = m.Connect(); err != nil {
    94  		t.Fatal(err)
    95  	}
    96  
    97  	defer func() {
    98  		m.CloseAsync()
    99  		if cErr := m.WaitForClose(time.Second); cErr != nil {
   100  			t.Error(cErr)
   101  		}
   102  	}()
   103  
   104  	var mIn mqtt.Client
   105  	if mIn, err = getMQTTConn(urls); err != nil {
   106  		t.Fatal(err)
   107  	}
   108  
   109  	defer mIn.Disconnect(0)
   110  
   111  	N := 10
   112  
   113  	wg := sync.WaitGroup{}
   114  	wg.Add(N)
   115  
   116  	testMsgs := map[string]struct{}{}
   117  	for i := 0; i < N; i++ {
   118  		str := fmt.Sprintf("hello world: %v", i)
   119  		testMsgs[str] = struct{}{}
   120  		go func(testStr string) {
   121  			if sErr := sendMQTTMsg(mIn, "test_input_1", testStr); sErr != nil {
   122  				t.Error(err)
   123  			}
   124  			wg.Done()
   125  		}(str)
   126  	}
   127  
   128  	lMsgs := len(testMsgs)
   129  	for lMsgs > 0 {
   130  		var actM types.Message
   131  		actM, err = m.Read()
   132  		if err != nil {
   133  			t.Error(err)
   134  		} else {
   135  			act := string(actM.Get(0).Get())
   136  			if _, exists := testMsgs[act]; !exists {
   137  				t.Errorf("Unexpected message: %v", act)
   138  			}
   139  			delete(testMsgs, act)
   140  		}
   141  		lMsgs = len(testMsgs)
   142  	}
   143  
   144  	wg.Wait()
   145  }
   146  
   147  func testMQTTDisconnect(urls []string, t *testing.T) {
   148  	conf := NewMQTTConfig()
   149  	conf.ClientID = "foo"
   150  	conf.Topics = []string{"test_input_1"}
   151  	conf.URLs = urls
   152  
   153  	m, err := NewMQTT(conf, log.Noop(), metrics.Noop())
   154  	if err != nil {
   155  		t.Fatal(err)
   156  	}
   157  
   158  	if err = m.Connect(); err != nil {
   159  		t.Fatal(err)
   160  	}
   161  
   162  	wg := sync.WaitGroup{}
   163  	wg.Add(1)
   164  	go func() {
   165  		m.CloseAsync()
   166  		if cErr := m.WaitForClose(time.Second); cErr != nil {
   167  			t.Error(cErr)
   168  		}
   169  		wg.Done()
   170  	}()
   171  
   172  	if _, err = m.Read(); err != types.ErrTypeClosed {
   173  		t.Errorf("Wrong error: %v != %v", err, types.ErrTypeClosed)
   174  	}
   175  
   176  	wg.Wait()
   177  }