github.com/gramework/gramework@v1.8.1-0.20231027140105-82555c9057f5/x/testutils/choosePort.go (about)

     1  package testutils
     2  
     3  import (
     4  	"fmt"
     5  	"math/rand"
     6  	"net"
     7  	"sync"
     8  )
     9  
    10  var portsRegister = map[uint16]struct{}{}
    11  var portsRegisterMu = &sync.Mutex{}
    12  
    13  // PortChooser does not store any port information.
    14  // It was created only for API purposes.
    15  type PortChooser struct {
    16  	nonRoot *bool
    17  	unused  *bool
    18  }
    19  
    20  // Port creates a chaining API structure
    21  func Port() *PortChooser {
    22  	return &PortChooser{}
    23  }
    24  
    25  // NonRoot enables the non-root port requirement: min port will be 1025 to ensure anything is ok.
    26  func (pc *PortChooser) NonRoot() *PortChooser {
    27  	pc.nonRoot = new(bool)
    28  	*pc.nonRoot = true
    29  	return pc
    30  }
    31  
    32  // Unused enables a check that port is free.
    33  func (pc *PortChooser) Unused() *PortChooser {
    34  	pc.unused = new(bool)
    35  	*pc.unused = true
    36  	return pc
    37  }
    38  
    39  // Root enables root-only port requirement: max port will be 1024.
    40  func (pc *PortChooser) Root() *PortChooser {
    41  	pc.nonRoot = new(bool)
    42  	return pc
    43  }
    44  
    45  // Used enables a check that port is not free.
    46  func (pc *PortChooser) Used() *PortChooser {
    47  	pc.unused = new(bool)
    48  	*pc.unused = false
    49  	return pc
    50  }
    51  
    52  // Acquire applies all filters defined before and returns a port number.
    53  func (pc *PortChooser) Acquire() int {
    54  	_, port := pc.determinePort()
    55  
    56  	portsRegisterMu.Lock()
    57  	portsRegister[uint16(port)] = struct{}{}
    58  	portsRegisterMu.Unlock()
    59  	return port
    60  }
    61  
    62  func (pc *PortChooser) AcquireListener() (net.Listener, int) {
    63  	ln, port := pc.determinePort()
    64  
    65  	portsRegisterMu.Lock()
    66  	portsRegister[uint16(port)] = struct{}{}
    67  	portsRegisterMu.Unlock()
    68  
    69  	return ln, port
    70  }
    71  
    72  func (pc *PortChooser) determinePort() (net.Listener, int) {
    73  	minPort := 1
    74  	maxPort := 65535
    75  
    76  	if pc.nonRoot != nil {
    77  		if *pc.nonRoot {
    78  			minPort = 1025
    79  		} else {
    80  			maxPort = 1024
    81  		}
    82  	}
    83  
    84  	chosenPort := 0
    85  	if pc.unused != nil && !*pc.unused {
    86  		portsRegisterMu.Lock()
    87  		for port := range portsRegister {
    88  			if int(port) > minPort && int(port) < maxPort {
    89  				chosenPort = int(port)
    90  			}
    91  		}
    92  		portsRegisterMu.Unlock()
    93  		if chosenPort != 0 {
    94  			return nil, chosenPort
    95  		}
    96  
    97  		chosenPort = rand.Intn(maxPort-minPort) + minPort
    98  		_, err := net.Listen("tcp4", fmt.Sprintf(":%d", chosenPort))
    99  		_ = err // fixes linter warning
   100  		return nil, chosenPort
   101  	}
   102  
   103  	for {
   104  		chosenPort = rand.Intn(maxPort-minPort) + minPort
   105  		ln, err := net.Listen("tcp4", fmt.Sprintf(":%d", chosenPort))
   106  		if err == nil {
   107  			return ln, chosenPort
   108  		}
   109  	}
   110  }