github.com/aristanetworks/goarista@v0.0.0-20240514173732-cca2755bbd44/dscp/dial_test.go (about) 1 // Copyright (c) 2017 Arista Networks, Inc. 2 // Use of this source code is governed by the Apache License 2.0 3 // that can be found in the COPYING file. 4 5 package dscp_test 6 7 import ( 8 "fmt" 9 "net" 10 "strings" 11 "testing" 12 "time" 13 14 "github.com/aristanetworks/goarista/dscp" 15 ) 16 17 func TestDialTCPWithTOS(t *testing.T) { 18 addr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0} 19 listen, err := net.ListenTCP("tcp", addr) 20 if err != nil { 21 t.Fatal(err) 22 } 23 defer listen.Close() 24 25 done := make(chan struct{}) 26 go func() { 27 conn, err := listen.Accept() 28 if err != nil { 29 t.Error(err) 30 } 31 defer conn.Close() 32 buf := []byte{'!'} 33 conn.Write(buf) 34 n, err := conn.Read(buf) 35 if n != 1 || err != nil { 36 t.Errorf("Read returned %d / %s", n, err) 37 } else if buf[0] != '!' { 38 t.Errorf("Expected to read '!' but got %q", buf) 39 } 40 close(done) 41 }() 42 conn, err := dscp.DialTCPWithTOS(nil, listen.Addr().(*net.TCPAddr), 40) 43 if err != nil { 44 t.Fatal("Connection failed:", err) 45 } 46 defer conn.Close() 47 buf := make([]byte, 1) 48 n, err := conn.Read(buf) 49 if n != 1 || err != nil { 50 t.Fatalf("Read returned %d / %s", n, err) 51 } else if buf[0] != '!' { 52 t.Fatalf("Expected to read '!' but got %q", buf) 53 } 54 conn.Write(buf) 55 <-done 56 } 57 58 func TestDialTCPTimeoutWithTOS(t *testing.T) { 59 raddr := &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0} 60 for name, td := range map[string]*net.TCPAddr{ 61 "ipNoPort": &net.TCPAddr{ 62 IP: net.ParseIP("127.0.0.42"), Port: 0, 63 }, 64 "ipWithPort": &net.TCPAddr{ 65 IP: net.ParseIP("127.0.0.42"), Port: 10001, 66 }, 67 } { 68 t.Run(name, func(t *testing.T) { 69 l, err := net.ListenTCP("tcp", raddr) 70 if err != nil { 71 t.Fatal(err) 72 } 73 defer l.Close() 74 75 var srcAddr net.Addr 76 done := make(chan struct{}) 77 go func() { 78 conn, err := l.Accept() 79 if err != nil { 80 t.Error(err) 81 } 82 defer conn.Close() 83 srcAddr = conn.RemoteAddr() 84 close(done) 85 }() 86 87 conn, err := dscp.DialTCPTimeoutWithTOS(td, l.Addr().(*net.TCPAddr), 40, 5*time.Second) 88 if err != nil { 89 t.Fatal("Connection failed:", err) 90 } 91 defer conn.Close() 92 93 pfx := td.IP.String() + ":" 94 if td.Port > 0 { 95 pfx = fmt.Sprintf("%s%d", pfx, td.Port) 96 } 97 <-done 98 if !strings.HasPrefix(srcAddr.String(), pfx) { 99 t.Fatalf("DialTCPTimeoutWithTOS wrong address: %q instead of %q", srcAddr, pfx) 100 } 101 }) 102 } 103 }