github.com/vishvananda/netlink@v1.3.0/socket_test.go (about) 1 //go:build linux 2 // +build linux 3 4 package netlink 5 6 import ( 7 "fmt" 8 "log" 9 "net" 10 "os/user" 11 "strconv" 12 "syscall" 13 "testing" 14 ) 15 16 func TestSocketGet(t *testing.T) { 17 defer setUpNetlinkTestWithLoopback(t)() 18 19 type Addr struct { 20 IP net.IP 21 Port int 22 } 23 24 getAddr := func(a net.Addr) Addr { 25 var addr Addr 26 switch v := a.(type) { 27 case *net.UDPAddr: 28 addr.IP = v.IP 29 addr.Port = v.Port 30 case *net.TCPAddr: 31 addr.IP = v.IP 32 addr.Port = v.Port 33 } 34 return addr 35 } 36 37 checkSocket := func(t *testing.T, local, remote net.Addr) { 38 socket, err := SocketGet(local, remote) 39 if err != nil { 40 t.Fatal(err) 41 } 42 43 localAddr, remoteAddr := getAddr(local), getAddr(remote) 44 45 if got, want := socket.ID.Source, localAddr.IP; !got.Equal(want) { 46 t.Fatalf("local ip = %v, want %v", got, want) 47 } 48 if got, want := socket.ID.Destination, remoteAddr.IP; !got.Equal(want) { 49 t.Fatalf("remote ip = %v, want %v", got, want) 50 } 51 if got, want := int(socket.ID.SourcePort), localAddr.Port; got != want { 52 t.Fatalf("local port = %d, want %d", got, want) 53 } 54 if got, want := int(socket.ID.DestinationPort), remoteAddr.Port; got != want { 55 t.Fatalf("remote port = %d, want %d", got, want) 56 } 57 u, err := user.Current() 58 if err != nil { 59 t.Fatal(err) 60 } 61 if got, want := strconv.Itoa(int(socket.UID)), u.Uid; got != want { 62 t.Fatalf("UID = %s, want %s", got, want) 63 } 64 } 65 66 for _, v := range [...]string{"tcp4", "tcp6"} { 67 addr, err := net.ResolveTCPAddr(v, "localhost:0") 68 if err != nil { 69 log.Fatal(err) 70 } 71 l, err := net.ListenTCP(v, addr) 72 if err != nil { 73 log.Fatal(err) 74 } 75 defer l.Close() 76 77 conn, err := net.Dial(l.Addr().Network(), l.Addr().String()) 78 if err != nil { 79 t.Fatal(err) 80 } 81 defer conn.Close() 82 83 checkSocket(t, conn.LocalAddr(), conn.RemoteAddr()) 84 } 85 86 for _, v := range [...]string{"udp4", "udp6"} { 87 addr, err := net.ResolveUDPAddr(v, "localhost:0") 88 if err != nil { 89 log.Fatal(err) 90 } 91 l, err := net.ListenUDP(v, addr) 92 if err != nil { 93 log.Fatal(err) 94 } 95 defer l.Close() 96 conn, err := net.Dial(l.LocalAddr().Network(), l.LocalAddr().String()) 97 if err != nil { 98 t.Fatal(err) 99 } 100 defer conn.Close() 101 102 checkSocket(t, conn.LocalAddr(), conn.RemoteAddr()) 103 } 104 } 105 106 func TestSocketDestroy(t *testing.T) { 107 defer setUpNetlinkTestWithLoopback(t)() 108 109 addr, err := net.ResolveTCPAddr("tcp", "localhost:0") 110 if err != nil { 111 log.Fatal(err) 112 } 113 l, err := net.ListenTCP("tcp", addr) 114 if err != nil { 115 log.Fatal(err) 116 } 117 defer l.Close() 118 119 conn, err := net.Dial(l.Addr().Network(), l.Addr().String()) 120 if err != nil { 121 t.Fatal(err) 122 } 123 defer conn.Close() 124 125 localAddr := conn.LocalAddr().(*net.TCPAddr) 126 remoteAddr := conn.RemoteAddr().(*net.TCPAddr) 127 err = SocketDestroy(localAddr, remoteAddr) 128 if err != nil { 129 t.Fatal(err) 130 } 131 } 132 133 func TestSocketDiagTCPInfo(t *testing.T) { 134 Family4 := uint8(syscall.AF_INET) 135 Family6 := uint8(syscall.AF_INET6) 136 families := []uint8{Family4, Family6} 137 for _, wantFamily := range families { 138 res, err := SocketDiagTCPInfo(wantFamily) 139 if err != nil { 140 t.Fatal(err) 141 } 142 for _, i := range res { 143 gotFamily := i.InetDiagMsg.Family 144 if gotFamily != wantFamily { 145 t.Fatalf("Socket family = %d, want %d", gotFamily, wantFamily) 146 } 147 } 148 } 149 } 150 151 func TestSocketDiagUDPnfo(t *testing.T) { 152 for _, want := range []uint8{syscall.AF_INET, syscall.AF_INET6} { 153 result, err := SocketDiagUDPInfo(want) 154 if err != nil { 155 t.Fatal(err) 156 } 157 158 for _, r := range result { 159 if got := r.InetDiagMsg.Family; got != want { 160 t.Fatalf("protocol family = %v, want %v", got, want) 161 } 162 } 163 } 164 } 165 166 func TestUnixSocketDiagInfo(t *testing.T) { 167 want := syscall.AF_UNIX 168 result, err := UnixSocketDiagInfo() 169 if err != nil { 170 t.Fatal(err) 171 } 172 173 for i, r := range result { 174 fmt.Println(r.DiagMsg) 175 if got := r.DiagMsg.Family; got != uint8(want) { 176 t.Fatalf("%d: protocol family = %v, want %v", i, got, want) 177 } 178 } 179 }