github.com/sagernet/netlink@v0.0.0-20240612041022-b9a21c07ac6a/socket_test.go (about)

     1  //go:build linux
     2  // +build linux
     3  
     4  package netlink
     5  
     6  import (
     7  	"log"
     8  	"net"
     9  	"os/user"
    10  	"strconv"
    11  	"syscall"
    12  	"testing"
    13  )
    14  
    15  func TestSocketGet(t *testing.T) {
    16  	defer setUpNetlinkTestWithLoopback(t)()
    17  
    18  	addr, err := net.ResolveTCPAddr("tcp", "localhost:0")
    19  	if err != nil {
    20  		log.Fatal(err)
    21  	}
    22  	l, err := net.ListenTCP("tcp", addr)
    23  	if err != nil {
    24  		log.Fatal(err)
    25  	}
    26  	defer l.Close()
    27  
    28  	conn, err := net.Dial(l.Addr().Network(), l.Addr().String())
    29  	if err != nil {
    30  		t.Fatal(err)
    31  	}
    32  	defer conn.Close()
    33  
    34  	localAddr := conn.LocalAddr().(*net.TCPAddr)
    35  	remoteAddr := conn.RemoteAddr().(*net.TCPAddr)
    36  	socket, err := SocketGet(localAddr, remoteAddr)
    37  	if err != nil {
    38  		t.Fatal(err)
    39  	}
    40  
    41  	if got, want := socket.ID.Source, localAddr.IP; !got.Equal(want) {
    42  		t.Fatalf("local ip = %v, want %v", got, want)
    43  	}
    44  	if got, want := socket.ID.Destination, remoteAddr.IP; !got.Equal(want) {
    45  		t.Fatalf("remote ip = %v, want %v", got, want)
    46  	}
    47  	if got, want := int(socket.ID.SourcePort), localAddr.Port; got != want {
    48  		t.Fatalf("local port = %d, want %d", got, want)
    49  	}
    50  	if got, want := int(socket.ID.DestinationPort), remoteAddr.Port; got != want {
    51  		t.Fatalf("remote port = %d, want %d", got, want)
    52  	}
    53  	u, err := user.Current()
    54  	if err != nil {
    55  		t.Fatal(err)
    56  	}
    57  	if got, want := strconv.Itoa(int(socket.UID)), u.Uid; got != want {
    58  		t.Fatalf("UID = %s, want %s", got, want)
    59  	}
    60  }
    61  
    62  func TestSocketDiagTCPInfo(t *testing.T) {
    63  	Family4 := uint8(syscall.AF_INET)
    64  	Family6 := uint8(syscall.AF_INET6)
    65  	families := []uint8{Family4, Family6}
    66  	for _, wantFamily := range families {
    67  		res, err := SocketDiagTCPInfo(wantFamily)
    68  		if err != nil {
    69  			t.Fatal(err)
    70  		}
    71  		for _, i := range res {
    72  			gotFamily := i.InetDiagMsg.Family
    73  			if gotFamily != wantFamily {
    74  				t.Fatalf("Socket family = %d, want %d", gotFamily, wantFamily)
    75  			}
    76  		}
    77  	}
    78  }