github.com/aristanetworks/goarista@v0.0.0-20240514173732-cca2755bbd44/dscp/dial_test.go (about)

     1  // Copyright (c) 2017 Arista Networks, Inc.
     2  // Use of this source code is governed by the Apache License 2.0
     3  // that can be found in the COPYING file.
     4  
     5  package dscp_test
     6  
     7  import (
     8  	"fmt"
     9  	"net"
    10  	"strings"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/aristanetworks/goarista/dscp"
    15  )
    16  
    17  func TestDialTCPWithTOS(t *testing.T) {
    18  	addr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}
    19  	listen, err := net.ListenTCP("tcp", addr)
    20  	if err != nil {
    21  		t.Fatal(err)
    22  	}
    23  	defer listen.Close()
    24  
    25  	done := make(chan struct{})
    26  	go func() {
    27  		conn, err := listen.Accept()
    28  		if err != nil {
    29  			t.Error(err)
    30  		}
    31  		defer conn.Close()
    32  		buf := []byte{'!'}
    33  		conn.Write(buf)
    34  		n, err := conn.Read(buf)
    35  		if n != 1 || err != nil {
    36  			t.Errorf("Read returned %d / %s", n, err)
    37  		} else if buf[0] != '!' {
    38  			t.Errorf("Expected to read '!' but got %q", buf)
    39  		}
    40  		close(done)
    41  	}()
    42  	conn, err := dscp.DialTCPWithTOS(nil, listen.Addr().(*net.TCPAddr), 40)
    43  	if err != nil {
    44  		t.Fatal("Connection failed:", err)
    45  	}
    46  	defer conn.Close()
    47  	buf := make([]byte, 1)
    48  	n, err := conn.Read(buf)
    49  	if n != 1 || err != nil {
    50  		t.Fatalf("Read returned %d / %s", n, err)
    51  	} else if buf[0] != '!' {
    52  		t.Fatalf("Expected to read '!' but got %q", buf)
    53  	}
    54  	conn.Write(buf)
    55  	<-done
    56  }
    57  
    58  func TestDialTCPTimeoutWithTOS(t *testing.T) {
    59  	raddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}
    60  	for name, td := range map[string]*net.TCPAddr{
    61  		"ipNoPort": &net.TCPAddr{
    62  			IP: net.ParseIP("127.0.0.42"), Port: 0,
    63  		},
    64  		"ipWithPort": &net.TCPAddr{
    65  			IP: net.ParseIP("127.0.0.42"), Port: 10001,
    66  		},
    67  	} {
    68  		t.Run(name, func(t *testing.T) {
    69  			l, err := net.ListenTCP("tcp", raddr)
    70  			if err != nil {
    71  				t.Fatal(err)
    72  			}
    73  			defer l.Close()
    74  
    75  			var srcAddr net.Addr
    76  			done := make(chan struct{})
    77  			go func() {
    78  				conn, err := l.Accept()
    79  				if err != nil {
    80  					t.Error(err)
    81  				}
    82  				defer conn.Close()
    83  				srcAddr = conn.RemoteAddr()
    84  				close(done)
    85  			}()
    86  
    87  			conn, err := dscp.DialTCPTimeoutWithTOS(td, l.Addr().(*net.TCPAddr), 40, 5*time.Second)
    88  			if err != nil {
    89  				t.Fatal("Connection failed:", err)
    90  			}
    91  			defer conn.Close()
    92  
    93  			pfx := td.IP.String() + ":"
    94  			if td.Port > 0 {
    95  				pfx = fmt.Sprintf("%s%d", pfx, td.Port)
    96  			}
    97  			<-done
    98  			if !strings.HasPrefix(srcAddr.String(), pfx) {
    99  				t.Fatalf("DialTCPTimeoutWithTOS wrong address: %q instead of %q", srcAddr, pfx)
   100  			}
   101  		})
   102  	}
   103  }