github.com/haraldrudell/parl@v0.4.176/pnet/tcp-listener_test.go (about)

     1  /*
     2  © 2021–present Harald Rudell <harald.rudell@gmail.com> (https://haraldrudell.github.io/haraldrudell/)
     3  ISC License
     4  */
     5  
     6  package pnet
     7  
     8  import (
     9  	"context"
    10  	"errors"
    11  	"io"
    12  	"net"
    13  	"net/netip"
    14  	"strconv"
    15  	"strings"
    16  	"sync"
    17  	"sync/atomic"
    18  	"testing"
    19  
    20  	"github.com/haraldrudell/parl"
    21  	"github.com/haraldrudell/parl/perrors"
    22  )
    23  
    24  func TestSocket(t *testing.T) {
    25  	var expAddr = netip.MustParseAddr("127.0.0.1") // localhost: ensure IPv4
    26  	var socketString = expAddr.String() + ":0"     // 0 means ephemeral port
    27  
    28  	var colonPos int
    29  	var socket *TCPListener
    30  	var err error
    31  	var addr string
    32  	var port int
    33  	var addrPort netip.AddrPort
    34  
    35  	// check socketString
    36  	if colonPos = strings.Index(socketString, ":"); colonPos == -1 {
    37  		t.Fatalf("Bad socketString fixture: %q", socketString)
    38  	}
    39  
    40  	// NewTCPListener()
    41  	socket = NewTCPListener()
    42  	if socket == nil {
    43  		t.Fatalf("NewTCPListener nil")
    44  	} else if socket.TCPListener == nil {
    45  		t.Fatalf("socket.TCPListener nil")
    46  	} else if socket.SocketListener.netListener != socket.TCPListener {
    47  		t.Fatalf("socket.TCPListener and socket.SocketListener.netListener different")
    48  	}
    49  
    50  	// Listen()
    51  	if err = socket.Listen(socketString); err != nil {
    52  		t.Fatalf("ListenTCP4 error: %+v", err)
    53  	}
    54  
    55  	// Addr()
    56  	if addr = socket.Addr().String(); !strings.HasPrefix(addr, socketString[:colonPos+1]) {
    57  		t.Fatalf("Bad socket adress: %q", addr)
    58  	}
    59  	if port, err = strconv.Atoi(strings.TrimPrefix(addr, socketString[:colonPos+1])); err != nil {
    60  		t.Errorf("Bad port number: %q", addr)
    61  	} else if port < 1 || port > 65535 {
    62  		t.Errorf("Bad port numeric: %v", port)
    63  	}
    64  
    65  	// AddrPort()
    66  	if addrPort, err = socket.AddrPort(); err != nil {
    67  		t.Errorf("AddrPort err: %s", perrors.Short(err))
    68  	}
    69  	if addrPort.Addr() != expAddr {
    70  		t.Errorf("bad AddrPort addr: %q exp %q", addrPort.Addr(), expAddr)
    71  	}
    72  
    73  	// Close()
    74  	if err = socket.Close(); err != nil {
    75  		t.Errorf("socket.Close: '%v'", err)
    76  	}
    77  }
    78  
    79  type connectionHandlerFixture struct {
    80  	count int64
    81  }
    82  
    83  func (c *connectionHandlerFixture) connFunc(conn *net.TCPConn) {
    84  	if err := conn.Close(); err != nil {
    85  		panic(perrors.Errorf("conn.Close: '%w'", err))
    86  	}
    87  	atomic.AddInt64(&c.count, 1)
    88  }
    89  
    90  func (c *connectionHandlerFixture) errorListenerThread(
    91  	errs parl.Errs,
    92  	socketCloseCh <-chan struct{},
    93  	wg *sync.WaitGroup) {
    94  	defer wg.Done()
    95  
    96  	var err error
    97  	for {
    98  		select {
    99  		case <-socketCloseCh:
   100  			return
   101  		case <-errs.WaitCh():
   102  			err, _ = errs.Error()
   103  			if err == nil {
   104  				panic(perrors.New("socket error nil"))
   105  			}
   106  			panic(err)
   107  		}
   108  	}
   109  }
   110  
   111  func TestAcceptThread(t *testing.T) {
   112  	var socketString = "127.0.0.1:0" // 0 means ephemeral port
   113  	var fixture connectionHandlerFixture
   114  
   115  	var socket *TCPListener
   116  	var err error
   117  	var ctx context.Context = context.Background()
   118  	var addr net.Addr
   119  	var tcpClient net.Dialer
   120  	var netConn net.Conn
   121  	var threadWait sync.WaitGroup
   122  
   123  	// set-up socket
   124  	socket = NewTCPListener()
   125  	if err = socket.Listen(socketString); err != nil {
   126  		t.Fatalf("ListenTCP4 error: %+v", err)
   127  	}
   128  
   129  	// error listener thread
   130  	threadWait.Add(1)
   131  	go fixture.errorListenerThread(socket.Errs(), socket.WaitCh(), &threadWait)
   132  
   133  	// invoke AcceptConnections
   134  	t.Log("socket.AcceptThread…")
   135  	go socket.AcceptConnections(fixture.connFunc)
   136  
   137  	// connect to socket
   138  	t.Log("tcpClient.DialContext…")
   139  	addr = socket.Addr()
   140  	if netConn, err = tcpClient.DialContext(ctx, addr.Network(), addr.String()); err != nil {
   141  		t.Fatalf("tcpClient.DialContext: '%v'", err)
   142  	}
   143  
   144  	// read from socket
   145  	t.Log("netConn.Read…")
   146  	bytes := make([]byte, 1)
   147  	for {
   148  		var n int
   149  		n, err = netConn.Read(bytes)
   150  		if err != nil {
   151  			if errors.Is(err, io.EOF) {
   152  				err = nil
   153  				break
   154  			} else {
   155  				t.Fatalf("conn.Read: '%v'", err)
   156  			}
   157  		}
   158  		if n != 0 {
   159  			t.Fatalf("conn.Read unexpected bytes: %d", n)
   160  		}
   161  	}
   162  
   163  	// close client
   164  	t.Log("netConn.Close…")
   165  	if err := netConn.Close(); err != nil {
   166  		t.Errorf("client Close: '%v'", err)
   167  	}
   168  
   169  	// close listener
   170  	t.Log("socket.Close…")
   171  	if err := socket.Close(); err != nil {
   172  		t.Errorf("client Close: '%v'", err)
   173  	}
   174  
   175  	t.Logf("socket.Wait… %d", atomic.LoadInt64(&fixture.count))
   176  	<-socket.WaitCh()
   177  
   178  	t.Log("error listener Wait…")
   179  	threadWait.Wait()
   180  
   181  	t.Log("Completed")
   182  }