github.com/vishvananda/netlink@v1.3.0/socket_test.go (about)

     1  //go:build linux
     2  // +build linux
     3  
     4  package netlink
     5  
     6  import (
     7  	"fmt"
     8  	"log"
     9  	"net"
    10  	"os/user"
    11  	"strconv"
    12  	"syscall"
    13  	"testing"
    14  )
    15  
    16  func TestSocketGet(t *testing.T) {
    17  	defer setUpNetlinkTestWithLoopback(t)()
    18  
    19  	type Addr struct {
    20  		IP   net.IP
    21  		Port int
    22  	}
    23  
    24  	getAddr := func(a net.Addr) Addr {
    25  		var addr Addr
    26  		switch v := a.(type) {
    27  		case *net.UDPAddr:
    28  			addr.IP = v.IP
    29  			addr.Port = v.Port
    30  		case *net.TCPAddr:
    31  			addr.IP = v.IP
    32  			addr.Port = v.Port
    33  		}
    34  		return addr
    35  	}
    36  
    37  	checkSocket := func(t *testing.T, local, remote net.Addr) {
    38  		socket, err := SocketGet(local, remote)
    39  		if err != nil {
    40  			t.Fatal(err)
    41  		}
    42  
    43  		localAddr, remoteAddr := getAddr(local), getAddr(remote)
    44  
    45  		if got, want := socket.ID.Source, localAddr.IP; !got.Equal(want) {
    46  			t.Fatalf("local ip = %v, want %v", got, want)
    47  		}
    48  		if got, want := socket.ID.Destination, remoteAddr.IP; !got.Equal(want) {
    49  			t.Fatalf("remote ip = %v, want %v", got, want)
    50  		}
    51  		if got, want := int(socket.ID.SourcePort), localAddr.Port; got != want {
    52  			t.Fatalf("local port = %d, want %d", got, want)
    53  		}
    54  		if got, want := int(socket.ID.DestinationPort), remoteAddr.Port; got != want {
    55  			t.Fatalf("remote port = %d, want %d", got, want)
    56  		}
    57  		u, err := user.Current()
    58  		if err != nil {
    59  			t.Fatal(err)
    60  		}
    61  		if got, want := strconv.Itoa(int(socket.UID)), u.Uid; got != want {
    62  			t.Fatalf("UID = %s, want %s", got, want)
    63  		}
    64  	}
    65  
    66  	for _, v := range [...]string{"tcp4", "tcp6"} {
    67  		addr, err := net.ResolveTCPAddr(v, "localhost:0")
    68  		if err != nil {
    69  			log.Fatal(err)
    70  		}
    71  		l, err := net.ListenTCP(v, addr)
    72  		if err != nil {
    73  			log.Fatal(err)
    74  		}
    75  		defer l.Close()
    76  
    77  		conn, err := net.Dial(l.Addr().Network(), l.Addr().String())
    78  		if err != nil {
    79  			t.Fatal(err)
    80  		}
    81  		defer conn.Close()
    82  
    83  		checkSocket(t, conn.LocalAddr(), conn.RemoteAddr())
    84  	}
    85  
    86  	for _, v := range [...]string{"udp4", "udp6"} {
    87  		addr, err := net.ResolveUDPAddr(v, "localhost:0")
    88  		if err != nil {
    89  			log.Fatal(err)
    90  		}
    91  		l, err := net.ListenUDP(v, addr)
    92  		if err != nil {
    93  			log.Fatal(err)
    94  		}
    95  		defer l.Close()
    96  		conn, err := net.Dial(l.LocalAddr().Network(), l.LocalAddr().String())
    97  		if err != nil {
    98  			t.Fatal(err)
    99  		}
   100  		defer conn.Close()
   101  
   102  		checkSocket(t, conn.LocalAddr(), conn.RemoteAddr())
   103  	}
   104  }
   105  
   106  func TestSocketDestroy(t *testing.T) {
   107  	defer setUpNetlinkTestWithLoopback(t)()
   108  
   109  	addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
   110  	if err != nil {
   111  		log.Fatal(err)
   112  	}
   113  	l, err := net.ListenTCP("tcp", addr)
   114  	if err != nil {
   115  		log.Fatal(err)
   116  	}
   117  	defer l.Close()
   118  
   119  	conn, err := net.Dial(l.Addr().Network(), l.Addr().String())
   120  	if err != nil {
   121  		t.Fatal(err)
   122  	}
   123  	defer conn.Close()
   124  
   125  	localAddr := conn.LocalAddr().(*net.TCPAddr)
   126  	remoteAddr := conn.RemoteAddr().(*net.TCPAddr)
   127  	err = SocketDestroy(localAddr, remoteAddr)
   128  	if err != nil {
   129  		t.Fatal(err)
   130  	}
   131  }
   132  
   133  func TestSocketDiagTCPInfo(t *testing.T) {
   134  	Family4 := uint8(syscall.AF_INET)
   135  	Family6 := uint8(syscall.AF_INET6)
   136  	families := []uint8{Family4, Family6}
   137  	for _, wantFamily := range families {
   138  		res, err := SocketDiagTCPInfo(wantFamily)
   139  		if err != nil {
   140  			t.Fatal(err)
   141  		}
   142  		for _, i := range res {
   143  			gotFamily := i.InetDiagMsg.Family
   144  			if gotFamily != wantFamily {
   145  				t.Fatalf("Socket family = %d, want %d", gotFamily, wantFamily)
   146  			}
   147  		}
   148  	}
   149  }
   150  
   151  func TestSocketDiagUDPnfo(t *testing.T) {
   152  	for _, want := range []uint8{syscall.AF_INET, syscall.AF_INET6} {
   153  		result, err := SocketDiagUDPInfo(want)
   154  		if err != nil {
   155  			t.Fatal(err)
   156  		}
   157  
   158  		for _, r := range result {
   159  			if got := r.InetDiagMsg.Family; got != want {
   160  				t.Fatalf("protocol family = %v, want %v", got, want)
   161  			}
   162  		}
   163  	}
   164  }
   165  
   166  func TestUnixSocketDiagInfo(t *testing.T) {
   167  	want := syscall.AF_UNIX
   168  	result, err := UnixSocketDiagInfo()
   169  	if err != nil {
   170  		t.Fatal(err)
   171  	}
   172  
   173  	for i, r := range result {
   174  		fmt.Println(r.DiagMsg)
   175  		if got := r.DiagMsg.Family; got != uint8(want) {
   176  			t.Fatalf("%d: protocol family = %v, want %v", i, got, want)
   177  		}
   178  	}
   179  }