github.com/Ilhicas/nomad@v1.0.4-0.20210304152020-e86851182bc3/helper/freeport/freeport_test.go (about)

     1  package freeport
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"net"
     7  	"sync"
     8  	"testing"
     9  
    10  	"github.com/hashicorp/consul/sdk/testutil/retry"
    11  )
    12  
    13  // reset will reverse the setup from initialize() and then redo it (for tests)
    14  func reset() {
    15  	mu.Lock()
    16  	defer mu.Unlock()
    17  
    18  	logf("INFO", "resetting the freeport package state")
    19  
    20  	effectiveMaxBlocks = 0
    21  	firstPort = 0
    22  	if lockLn != nil {
    23  		lockLn.Close()
    24  		lockLn = nil
    25  	}
    26  
    27  	once = sync.Once{}
    28  
    29  	freePorts = nil
    30  	pendingPorts = nil
    31  	total = 0
    32  }
    33  
    34  // peekFree returns the next port that will be returned by Take to aid in testing.
    35  func peekFree() int {
    36  	mu.Lock()
    37  	defer mu.Unlock()
    38  	return freePorts.Front().Value.(int)
    39  }
    40  
    41  // peekAllFree returns all free ports that could be returned by Take to aid in testing.
    42  func peekAllFree() []int {
    43  	mu.Lock()
    44  	defer mu.Unlock()
    45  
    46  	var out []int
    47  	for elem := freePorts.Front(); elem != nil; elem = elem.Next() {
    48  		port := elem.Value.(int)
    49  		out = append(out, port)
    50  	}
    51  
    52  	return out
    53  }
    54  
    55  // stats returns diagnostic data to aid in testing
    56  func stats() (numTotal, numPending, numFree int) {
    57  	mu.Lock()
    58  	defer mu.Unlock()
    59  	return total, pendingPorts.Len(), freePorts.Len()
    60  }
    61  
    62  func TestTakeReturn(t *testing.T) {
    63  	// NOTE: for global var reasons this cannot execute in parallel
    64  	// t.Parallel()
    65  
    66  	// Since this test is destructive (i.e. it leaks all ports) it means that
    67  	// any other test cases in this package will not function after it runs. To
    68  	// help out we reset the global state after we run this test.
    69  	defer reset()
    70  
    71  	// OK: do a simple take/return cycle to trigger the package initialization
    72  	func() {
    73  		ports, err := Take(1)
    74  		if err != nil {
    75  			t.Fatalf("err: %v", err)
    76  		}
    77  		defer Return(ports)
    78  
    79  		if len(ports) != 1 {
    80  			t.Fatalf("expected %d but got %d ports", 1, len(ports))
    81  		}
    82  	}()
    83  
    84  	waitForStatsReset := func() (numTotal int) {
    85  		t.Helper()
    86  		numTotal, numPending, numFree := stats()
    87  		if numTotal != numFree+numPending {
    88  			t.Fatalf("expected total (%d) and free+pending (%d) ports to match", numTotal, numFree+numPending)
    89  		}
    90  		retry.Run(t, func(r *retry.R) {
    91  			numTotal, numPending, numFree = stats()
    92  			if numPending != 0 {
    93  				r.Fatalf("pending is still non zero: %d", numPending)
    94  			}
    95  			if numTotal != numFree {
    96  				r.Fatalf("total (%d) does not equal free (%d)", numTotal, numFree)
    97  			}
    98  		})
    99  		return numTotal
   100  	}
   101  
   102  	// Reset
   103  	numTotal := waitForStatsReset()
   104  
   105  	// --------------------
   106  	// OK: take the max
   107  	func() {
   108  		ports, err := Take(numTotal)
   109  		if err != nil {
   110  			t.Fatalf("err: %v", err)
   111  		}
   112  		defer Return(ports)
   113  
   114  		if len(ports) != numTotal {
   115  			t.Fatalf("expected %d but got %d ports", numTotal, len(ports))
   116  		}
   117  	}()
   118  
   119  	// Reset
   120  	numTotal = waitForStatsReset()
   121  
   122  	expectError := func(expected string, got error) {
   123  		t.Helper()
   124  		if got == nil {
   125  			t.Fatalf("expected error but was nil")
   126  		}
   127  		if got.Error() != expected {
   128  			t.Fatalf("expected error %q but got %q", expected, got.Error())
   129  		}
   130  	}
   131  
   132  	// --------------------
   133  	// ERROR: take too many ports
   134  	func() {
   135  		ports, err := Take(numTotal + 1)
   136  		defer Return(ports)
   137  		expectError("freeport: block size too small", err)
   138  	}()
   139  
   140  	// --------------------
   141  	// ERROR: invalid ports request (negative)
   142  	func() {
   143  		_, err := Take(-1)
   144  		expectError("freeport: cannot take -1 ports", err)
   145  	}()
   146  
   147  	// --------------------
   148  	// ERROR: invalid ports request (zero)
   149  	func() {
   150  		_, err := Take(0)
   151  		expectError("freeport: cannot take 0 ports", err)
   152  	}()
   153  
   154  	// --------------------
   155  	// OK: Steal a port under the covers and let freeport detect the theft and compensate
   156  	leakedPort := peekFree()
   157  	func() {
   158  		leakyListener, err := net.ListenTCP("tcp", tcpAddr("127.0.0.1", leakedPort))
   159  		if err != nil {
   160  			t.Fatalf("err: %v", err)
   161  		}
   162  		defer leakyListener.Close()
   163  
   164  		func() {
   165  			ports, err := Take(3)
   166  			if err != nil {
   167  				t.Fatalf("err: %v", err)
   168  			}
   169  			defer Return(ports)
   170  
   171  			if len(ports) != 3 {
   172  				t.Fatalf("expected %d but got %d ports", 3, len(ports))
   173  			}
   174  
   175  			for _, port := range ports {
   176  				if port == leakedPort {
   177  					t.Fatalf("did not expect for Take to return the leaked port")
   178  				}
   179  			}
   180  		}()
   181  
   182  		newNumTotal := waitForStatsReset()
   183  		if newNumTotal != numTotal-1 {
   184  			t.Fatalf("expected total to drop to %d but got %d", numTotal-1, newNumTotal)
   185  		}
   186  		numTotal = newNumTotal // update outer variable for later tests
   187  	}()
   188  
   189  	// --------------------
   190  	// OK: sequence it so that one Take must wait on another Take to Return.
   191  	func() {
   192  		mostPorts, err := Take(numTotal - 5)
   193  		if err != nil {
   194  			t.Fatalf("err: %v", err)
   195  		}
   196  
   197  		type reply struct {
   198  			ports []int
   199  			err   error
   200  		}
   201  		ch := make(chan reply, 1)
   202  		go func() {
   203  			ports, err := Take(10)
   204  			ch <- reply{ports: ports, err: err}
   205  		}()
   206  
   207  		Return(mostPorts)
   208  
   209  		r := <-ch
   210  		if r.err != nil {
   211  			t.Fatalf("err: %v", r.err)
   212  		}
   213  		defer Return(r.ports)
   214  
   215  		if len(r.ports) != 10 {
   216  			t.Fatalf("expected %d ports but got %d", 10, len(r.ports))
   217  		}
   218  	}()
   219  
   220  	// Reset
   221  	numTotal = waitForStatsReset()
   222  
   223  	// --------------------
   224  	// ERROR: Now we end on the crazy "Ocean's 11" level port theft where we
   225  	// orchestrate a situation where all ports are stolen and we don't find out
   226  	// until Take.
   227  	func() {
   228  		// 1. Grab all of the ports.
   229  		allPorts := peekAllFree()
   230  
   231  		// 2. Leak all of the ports
   232  		leaked := make([]io.Closer, 0, len(allPorts))
   233  		defer func() {
   234  			for _, c := range leaked {
   235  				c.Close()
   236  			}
   237  		}()
   238  		for _, port := range allPorts {
   239  			ln, err := net.ListenTCP("tcp", tcpAddr("127.0.0.1", port))
   240  			if err != nil {
   241  				t.Fatalf("err: %v", err)
   242  			}
   243  			leaked = append(leaked, ln)
   244  		}
   245  
   246  		// 3. Request 1 port which will detect the leaked ports and fail.
   247  		_, err := Take(1)
   248  		expectError("freeport: impossible to satisfy request; there are no actual free ports in the block anymore", err)
   249  
   250  		// 4. Wait for the block to zero out.
   251  		newNumTotal := waitForStatsReset()
   252  		if newNumTotal != 0 {
   253  			t.Fatalf("expected total to drop to %d but got %d", 0, newNumTotal)
   254  		}
   255  	}()
   256  }
   257  
   258  func TestIntervalOverlap(t *testing.T) {
   259  	cases := []struct {
   260  		min1, max1, min2, max2 int
   261  		overlap                bool
   262  	}{
   263  		{0, 0, 0, 0, true},
   264  		{1, 1, 1, 1, true},
   265  		{1, 3, 1, 3, true},  // same
   266  		{1, 3, 4, 6, false}, // serial
   267  		{1, 4, 3, 6, true},  // inner overlap
   268  		{1, 6, 3, 4, true},  // nest
   269  	}
   270  
   271  	for _, tc := range cases {
   272  		t.Run(fmt.Sprintf("%d:%d vs %d:%d", tc.min1, tc.max1, tc.min2, tc.max2), func(t *testing.T) {
   273  			if tc.overlap != intervalOverlap(tc.min1, tc.max1, tc.min2, tc.max2) { // 1 vs 2
   274  				t.Fatalf("expected %v but got %v", tc.overlap, !tc.overlap)
   275  			}
   276  			if tc.overlap != intervalOverlap(tc.min2, tc.max2, tc.min1, tc.max1) { // 2 vs 1
   277  				t.Fatalf("expected %v but got %v", tc.overlap, !tc.overlap)
   278  			}
   279  		})
   280  	}
   281  }