github.com/sagernet/netlink@v0.0.0-20240612041022-b9a21c07ac6a/socket_linux.go (about)

     1  package netlink
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"net"
     7  	"syscall"
     8  
     9  	"github.com/sagernet/netlink/nl"
    10  	"golang.org/x/sys/unix"
    11  )
    12  
    13  const (
    14  	sizeofSocketID      = 0x30
    15  	sizeofSocketRequest = sizeofSocketID + 0x8
    16  	sizeofSocket        = sizeofSocketID + 0x18
    17  )
    18  
    19  type socketRequest struct {
    20  	Family   uint8
    21  	Protocol uint8
    22  	Ext      uint8
    23  	pad      uint8
    24  	States   uint32
    25  	ID       SocketID
    26  }
    27  
    28  type writeBuffer struct {
    29  	Bytes []byte
    30  	pos   int
    31  }
    32  
    33  func (b *writeBuffer) Write(c byte) {
    34  	b.Bytes[b.pos] = c
    35  	b.pos++
    36  }
    37  
    38  func (b *writeBuffer) Next(n int) []byte {
    39  	s := b.Bytes[b.pos : b.pos+n]
    40  	b.pos += n
    41  	return s
    42  }
    43  
    44  func (r *socketRequest) Serialize() []byte {
    45  	b := writeBuffer{Bytes: make([]byte, sizeofSocketRequest)}
    46  	b.Write(r.Family)
    47  	b.Write(r.Protocol)
    48  	b.Write(r.Ext)
    49  	b.Write(r.pad)
    50  	native.PutUint32(b.Next(4), r.States)
    51  	networkOrder.PutUint16(b.Next(2), r.ID.SourcePort)
    52  	networkOrder.PutUint16(b.Next(2), r.ID.DestinationPort)
    53  	if r.Family == unix.AF_INET6 {
    54  		copy(b.Next(16), r.ID.Source)
    55  		copy(b.Next(16), r.ID.Destination)
    56  	} else {
    57  		copy(b.Next(4), r.ID.Source.To4())
    58  		b.Next(12)
    59  		copy(b.Next(4), r.ID.Destination.To4())
    60  		b.Next(12)
    61  	}
    62  	native.PutUint32(b.Next(4), r.ID.Interface)
    63  	native.PutUint32(b.Next(4), r.ID.Cookie[0])
    64  	native.PutUint32(b.Next(4), r.ID.Cookie[1])
    65  	return b.Bytes
    66  }
    67  
    68  func (r *socketRequest) Len() int { return sizeofSocketRequest }
    69  
    70  type readBuffer struct {
    71  	Bytes []byte
    72  	pos   int
    73  }
    74  
    75  func (b *readBuffer) Read() byte {
    76  	c := b.Bytes[b.pos]
    77  	b.pos++
    78  	return c
    79  }
    80  
    81  func (b *readBuffer) Next(n int) []byte {
    82  	s := b.Bytes[b.pos : b.pos+n]
    83  	b.pos += n
    84  	return s
    85  }
    86  
    87  func (s *Socket) deserialize(b []byte) error {
    88  	if len(b) < sizeofSocket {
    89  		return fmt.Errorf("socket data short read (%d); want %d", len(b), sizeofSocket)
    90  	}
    91  	rb := readBuffer{Bytes: b}
    92  	s.Family = rb.Read()
    93  	s.State = rb.Read()
    94  	s.Timer = rb.Read()
    95  	s.Retrans = rb.Read()
    96  	s.ID.SourcePort = networkOrder.Uint16(rb.Next(2))
    97  	s.ID.DestinationPort = networkOrder.Uint16(rb.Next(2))
    98  	if s.Family == unix.AF_INET6 {
    99  		s.ID.Source = net.IP(rb.Next(16))
   100  		s.ID.Destination = net.IP(rb.Next(16))
   101  	} else {
   102  		s.ID.Source = net.IPv4(rb.Read(), rb.Read(), rb.Read(), rb.Read())
   103  		rb.Next(12)
   104  		s.ID.Destination = net.IPv4(rb.Read(), rb.Read(), rb.Read(), rb.Read())
   105  		rb.Next(12)
   106  	}
   107  	s.ID.Interface = native.Uint32(rb.Next(4))
   108  	s.ID.Cookie[0] = native.Uint32(rb.Next(4))
   109  	s.ID.Cookie[1] = native.Uint32(rb.Next(4))
   110  	s.Expires = native.Uint32(rb.Next(4))
   111  	s.RQueue = native.Uint32(rb.Next(4))
   112  	s.WQueue = native.Uint32(rb.Next(4))
   113  	s.UID = native.Uint32(rb.Next(4))
   114  	s.INode = native.Uint32(rb.Next(4))
   115  	return nil
   116  }
   117  
   118  // SocketGet returns the Socket identified by its local and remote addresses.
   119  func SocketGet(local, remote net.Addr) (*Socket, error) {
   120  	localTCP, ok := local.(*net.TCPAddr)
   121  	if !ok {
   122  		return nil, ErrNotImplemented
   123  	}
   124  	remoteTCP, ok := remote.(*net.TCPAddr)
   125  	if !ok {
   126  		return nil, ErrNotImplemented
   127  	}
   128  	localIP := localTCP.IP.To4()
   129  	if localIP == nil {
   130  		return nil, ErrNotImplemented
   131  	}
   132  	remoteIP := remoteTCP.IP.To4()
   133  	if remoteIP == nil {
   134  		return nil, ErrNotImplemented
   135  	}
   136  
   137  	s, err := nl.Subscribe(unix.NETLINK_INET_DIAG)
   138  	if err != nil {
   139  		return nil, err
   140  	}
   141  	defer s.Close()
   142  	req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, 0)
   143  	req.AddData(&socketRequest{
   144  		Family:   unix.AF_INET,
   145  		Protocol: unix.IPPROTO_TCP,
   146  		ID: SocketID{
   147  			SourcePort:      uint16(localTCP.Port),
   148  			DestinationPort: uint16(remoteTCP.Port),
   149  			Source:          localIP,
   150  			Destination:     remoteIP,
   151  			Cookie:          [2]uint32{nl.TCPDIAG_NOCOOKIE, nl.TCPDIAG_NOCOOKIE},
   152  		},
   153  	})
   154  	s.Send(req)
   155  	msgs, from, err := s.Receive()
   156  	if err != nil {
   157  		return nil, err
   158  	}
   159  	if from.Pid != nl.PidKernel {
   160  		return nil, fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, nl.PidKernel)
   161  	}
   162  	if len(msgs) == 0 {
   163  		return nil, errors.New("no message nor error from netlink")
   164  	}
   165  	if len(msgs) > 2 {
   166  		return nil, fmt.Errorf("multiple (%d) matching sockets", len(msgs))
   167  	}
   168  	sock := &Socket{}
   169  	if err := sock.deserialize(msgs[0].Data); err != nil {
   170  		return nil, err
   171  	}
   172  	return sock, nil
   173  }
   174  
   175  // SocketDiagTCPInfo requests INET_DIAG_INFO for TCP protocol for specified family type and return with extension TCP info.
   176  func SocketDiagTCPInfo(family uint8) ([]*InetDiagTCPInfoResp, error) {
   177  	var result []*InetDiagTCPInfoResp
   178  	err := socketDiagTCPExecutor(family, func(m syscall.NetlinkMessage) error {
   179  		sockInfo := &Socket{}
   180  		if err := sockInfo.deserialize(m.Data); err != nil {
   181  			return err
   182  		}
   183  		attrs, err := nl.ParseRouteAttr(m.Data[sizeofSocket:])
   184  		if err != nil {
   185  			return err
   186  		}
   187  
   188  		res, err := attrsToInetDiagTCPInfoResp(attrs, sockInfo)
   189  		if err != nil {
   190  			return err
   191  		}
   192  
   193  		result = append(result, res)
   194  		return nil
   195  	})
   196  	if err != nil {
   197  		return nil, err
   198  	}
   199  	return result, nil
   200  }
   201  
   202  // SocketDiagTCP requests INET_DIAG_INFO for TCP protocol for specified family type and return related socket.
   203  func SocketDiagTCP(family uint8) ([]*Socket, error) {
   204  	var result []*Socket
   205  	err := socketDiagTCPExecutor(family, func(m syscall.NetlinkMessage) error {
   206  		sockInfo := &Socket{}
   207  		if err := sockInfo.deserialize(m.Data); err != nil {
   208  			return err
   209  		}
   210  		result = append(result, sockInfo)
   211  		return nil
   212  	})
   213  	if err != nil {
   214  		return nil, err
   215  	}
   216  	return result, nil
   217  }
   218  
   219  // socketDiagTCPExecutor requests INET_DIAG_INFO for TCP protocol for specified family type.
   220  func socketDiagTCPExecutor(family uint8, receiver func(syscall.NetlinkMessage) error) error {
   221  	s, err := nl.Subscribe(unix.NETLINK_INET_DIAG)
   222  	if err != nil {
   223  		return err
   224  	}
   225  	defer s.Close()
   226  
   227  	req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP)
   228  	req.AddData(&socketRequest{
   229  		Family:   family,
   230  		Protocol: unix.IPPROTO_TCP,
   231  		Ext:      (1 << (INET_DIAG_VEGASINFO - 1)) | (1 << (INET_DIAG_INFO - 1)),
   232  		States:   uint32(0xfff), // All TCP states
   233  	})
   234  	s.Send(req)
   235  
   236  loop:
   237  	for {
   238  		msgs, from, err := s.Receive()
   239  		if err != nil {
   240  			return err
   241  		}
   242  		if from.Pid != nl.PidKernel {
   243  			return fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, nl.PidKernel)
   244  		}
   245  		if len(msgs) == 0 {
   246  			return errors.New("no message nor error from netlink")
   247  		}
   248  
   249  		for _, m := range msgs {
   250  			switch m.Header.Type {
   251  			case unix.NLMSG_DONE:
   252  				break loop
   253  			case unix.NLMSG_ERROR:
   254  				error := int32(native.Uint32(m.Data[0:4]))
   255  				return syscall.Errno(-error)
   256  			}
   257  			if err := receiver(m); err != nil {
   258  				return err
   259  			}
   260  		}
   261  	}
   262  	return nil
   263  }
   264  
   265  func attrsToInetDiagTCPInfoResp(attrs []syscall.NetlinkRouteAttr, sockInfo *Socket) (*InetDiagTCPInfoResp, error) {
   266  	var tcpInfo *TCPInfo
   267  	var tcpBBRInfo *TCPBBRInfo
   268  	for _, a := range attrs {
   269  		if a.Attr.Type == INET_DIAG_INFO {
   270  			tcpInfo = &TCPInfo{}
   271  			if err := tcpInfo.deserialize(a.Value); err != nil {
   272  				return nil, err
   273  			}
   274  			continue
   275  		}
   276  
   277  		if a.Attr.Type == INET_DIAG_BBRINFO {
   278  			tcpBBRInfo = &TCPBBRInfo{}
   279  			if err := tcpBBRInfo.deserialize(a.Value); err != nil {
   280  				return nil, err
   281  			}
   282  			continue
   283  		}
   284  	}
   285  
   286  	return &InetDiagTCPInfoResp{
   287  		InetDiagMsg: sockInfo,
   288  		TCPInfo:     tcpInfo,
   289  		TCPBBRInfo:  tcpBBRInfo,
   290  	}, nil
   291  }