github.com/m-lab/tcp-info@v1.9.0/collector/collector_linux_test.go (about)

     1  package collector_test
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"log"
     7  	"net"
     8  	"sync"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/m-lab/go/rtx"
    13  	"github.com/m-lab/tcp-info/collector"
    14  	"github.com/m-lab/tcp-info/netlink"
    15  )
    16  
    17  func init() {
    18  	// Always prepend the filename and line number.
    19  	log.SetFlags(log.LstdFlags | log.Lshortfile)
    20  }
    21  
    22  func testFatal(t *testing.T, err error) {
    23  	if err != nil {
    24  		t.Fatal(err)
    25  	}
    26  }
    27  
    28  type testCacheLogger struct{}
    29  
    30  func (t *testCacheLogger) LogCacheStats(_, _ int) {}
    31  
    32  // This opens a local connection and streams data through it.
    33  func runTest(t *testing.T, ctx context.Context, port int) {
    34  	// Open a server socket, connect to it, send data to it until the context is canceled.
    35  	address := fmt.Sprintf("localhost:%d", port)
    36  	t.Log("Listening on", address)
    37  	localAddr, err := net.ResolveTCPAddr("tcp", address)
    38  	rtx.Must(err, "No localhost")
    39  	listener, err := net.ListenTCP("tcp", localAddr)
    40  	rtx.Must(err, "Could not make TCP listener")
    41  	local, err := net.Dial("tcp", address)
    42  	defer local.Close()
    43  	rtx.Must(err, "Could not connect to myself")
    44  	conn, err := listener.AcceptTCP()
    45  	rtx.Must(err, "Could not accept conn")
    46  	go func() {
    47  		for ctx.Err() == nil {
    48  			conn.Write([]byte("hello"))
    49  		}
    50  	}()
    51  	buff := make([]byte, 1024)
    52  	for ctx.Err() == nil {
    53  		local.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
    54  		local.Read(buff)
    55  	}
    56  }
    57  
    58  func findPort() int {
    59  	portFinder, err := net.Listen("tcp", ":0")
    60  	rtx.Must(err, "Could not open server to discover open ports")
    61  	port := portFinder.Addr().(*net.TCPAddr).Port
    62  	portFinder.Close()
    63  	return port
    64  }
    65  
    66  func TestRun(t *testing.T) {
    67  	ctx, cancel := context.WithCancel(context.Background())
    68  	defer cancel()
    69  
    70  	port := findPort()
    71  
    72  	// A nice big buffer on the channel
    73  	msgChan := make(chan netlink.MessageBlock, 10000)
    74  	var wg sync.WaitGroup
    75  	wg.Add(3)
    76  
    77  	go func() {
    78  		defer wg.Done()
    79  		collector.Run(ctx, 0, msgChan, &testCacheLogger{}, false)
    80  		t.Log("Run done.")
    81  	}()
    82  
    83  	go func() {
    84  		defer wg.Done()
    85  		runTest(t, ctx, port)
    86  		t.Log("runTest done.")
    87  	}()
    88  
    89  	go func() {
    90  		defer wg.Done()
    91  		select {
    92  		case <-ctx.Done():
    93  			t.Log("ctx.Done")
    94  			return
    95  		case <-time.NewTimer(10 * time.Second).C:
    96  			t.Log("Time out")
    97  			cancel()
    98  			close(msgChan)
    99  			t.Error("It should not take 10 seconds to get enough messages. Something is wrong.")
   100  			return
   101  		}
   102  	}()
   103  
   104  	// Make sure we receive multiple different messages regarding the open port
   105  	count := 0
   106  	var prev *netlink.ArchivalRecord
   107  	for msgs := range msgChan {
   108  		changed := false
   109  		for _, v4 := range msgs.V4Messages {
   110  			if v4 == nil {
   111  				continue
   112  			}
   113  			m, err := netlink.MakeArchivalRecord(v4, nil)
   114  			testFatal(t, err)
   115  			idm, err := m.RawIDM.Parse()
   116  			testFatal(t, err)
   117  			if idm != nil && idm.ID.SPort() == uint16(port) {
   118  				change, err := m.Compare(prev)
   119  				if err != nil {
   120  					t.Log(err)
   121  				} else if change > netlink.NoMajorChange {
   122  					prev = m
   123  					changed = true
   124  				}
   125  			}
   126  		}
   127  		for _, v6 := range msgs.V6Messages {
   128  			if v6 == nil {
   129  				continue
   130  			}
   131  			m, err := netlink.MakeArchivalRecord(v6, nil)
   132  			testFatal(t, err)
   133  			idm, err := m.RawIDM.Parse()
   134  			testFatal(t, err)
   135  			if idm != nil && idm.ID.SPort() == uint16(port) {
   136  				change, err := m.Compare(prev)
   137  				if err != nil {
   138  					t.Log(err)
   139  				} else if change > netlink.NoMajorChange {
   140  					prev = m
   141  					changed = true
   142  				}
   143  			}
   144  		}
   145  		if changed {
   146  			count++
   147  		}
   148  		if count > 10 {
   149  			cancel()
   150  			break
   151  		}
   152  	}
   153  
   154  	t.Log("Waiting for goroutines to exit")
   155  	wg.Wait()
   156  }