github.com/sagernet/netlink@v0.0.0-20240612041022-b9a21c07ac6a/socket_linux.go (about) 1 package netlink 2 3 import ( 4 "errors" 5 "fmt" 6 "net" 7 "syscall" 8 9 "github.com/sagernet/netlink/nl" 10 "golang.org/x/sys/unix" 11 ) 12 13 const ( 14 sizeofSocketID = 0x30 15 sizeofSocketRequest = sizeofSocketID + 0x8 16 sizeofSocket = sizeofSocketID + 0x18 17 ) 18 19 type socketRequest struct { 20 Family uint8 21 Protocol uint8 22 Ext uint8 23 pad uint8 24 States uint32 25 ID SocketID 26 } 27 28 type writeBuffer struct { 29 Bytes []byte 30 pos int 31 } 32 33 func (b *writeBuffer) Write(c byte) { 34 b.Bytes[b.pos] = c 35 b.pos++ 36 } 37 38 func (b *writeBuffer) Next(n int) []byte { 39 s := b.Bytes[b.pos : b.pos+n] 40 b.pos += n 41 return s 42 } 43 44 func (r *socketRequest) Serialize() []byte { 45 b := writeBuffer{Bytes: make([]byte, sizeofSocketRequest)} 46 b.Write(r.Family) 47 b.Write(r.Protocol) 48 b.Write(r.Ext) 49 b.Write(r.pad) 50 native.PutUint32(b.Next(4), r.States) 51 networkOrder.PutUint16(b.Next(2), r.ID.SourcePort) 52 networkOrder.PutUint16(b.Next(2), r.ID.DestinationPort) 53 if r.Family == unix.AF_INET6 { 54 copy(b.Next(16), r.ID.Source) 55 copy(b.Next(16), r.ID.Destination) 56 } else { 57 copy(b.Next(4), r.ID.Source.To4()) 58 b.Next(12) 59 copy(b.Next(4), r.ID.Destination.To4()) 60 b.Next(12) 61 } 62 native.PutUint32(b.Next(4), r.ID.Interface) 63 native.PutUint32(b.Next(4), r.ID.Cookie[0]) 64 native.PutUint32(b.Next(4), r.ID.Cookie[1]) 65 return b.Bytes 66 } 67 68 func (r *socketRequest) Len() int { return sizeofSocketRequest } 69 70 type readBuffer struct { 71 Bytes []byte 72 pos int 73 } 74 75 func (b *readBuffer) Read() byte { 76 c := b.Bytes[b.pos] 77 b.pos++ 78 return c 79 } 80 81 func (b *readBuffer) Next(n int) []byte { 82 s := b.Bytes[b.pos : b.pos+n] 83 b.pos += n 84 return s 85 } 86 87 func (s *Socket) deserialize(b []byte) error { 88 if len(b) < sizeofSocket { 89 return fmt.Errorf("socket data short read (%d); want %d", len(b), sizeofSocket) 90 } 91 rb := readBuffer{Bytes: b} 92 s.Family = rb.Read() 93 s.State = rb.Read() 94 s.Timer = rb.Read() 95 s.Retrans = rb.Read() 96 s.ID.SourcePort = networkOrder.Uint16(rb.Next(2)) 97 s.ID.DestinationPort = networkOrder.Uint16(rb.Next(2)) 98 if s.Family == unix.AF_INET6 { 99 s.ID.Source = net.IP(rb.Next(16)) 100 s.ID.Destination = net.IP(rb.Next(16)) 101 } else { 102 s.ID.Source = net.IPv4(rb.Read(), rb.Read(), rb.Read(), rb.Read()) 103 rb.Next(12) 104 s.ID.Destination = net.IPv4(rb.Read(), rb.Read(), rb.Read(), rb.Read()) 105 rb.Next(12) 106 } 107 s.ID.Interface = native.Uint32(rb.Next(4)) 108 s.ID.Cookie[0] = native.Uint32(rb.Next(4)) 109 s.ID.Cookie[1] = native.Uint32(rb.Next(4)) 110 s.Expires = native.Uint32(rb.Next(4)) 111 s.RQueue = native.Uint32(rb.Next(4)) 112 s.WQueue = native.Uint32(rb.Next(4)) 113 s.UID = native.Uint32(rb.Next(4)) 114 s.INode = native.Uint32(rb.Next(4)) 115 return nil 116 } 117 118 // SocketGet returns the Socket identified by its local and remote addresses. 119 func SocketGet(local, remote net.Addr) (*Socket, error) { 120 localTCP, ok := local.(*net.TCPAddr) 121 if !ok { 122 return nil, ErrNotImplemented 123 } 124 remoteTCP, ok := remote.(*net.TCPAddr) 125 if !ok { 126 return nil, ErrNotImplemented 127 } 128 localIP := localTCP.IP.To4() 129 if localIP == nil { 130 return nil, ErrNotImplemented 131 } 132 remoteIP := remoteTCP.IP.To4() 133 if remoteIP == nil { 134 return nil, ErrNotImplemented 135 } 136 137 s, err := nl.Subscribe(unix.NETLINK_INET_DIAG) 138 if err != nil { 139 return nil, err 140 } 141 defer s.Close() 142 req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, 0) 143 req.AddData(&socketRequest{ 144 Family: unix.AF_INET, 145 Protocol: unix.IPPROTO_TCP, 146 ID: SocketID{ 147 SourcePort: uint16(localTCP.Port), 148 DestinationPort: uint16(remoteTCP.Port), 149 Source: localIP, 150 Destination: remoteIP, 151 Cookie: [2]uint32{nl.TCPDIAG_NOCOOKIE, nl.TCPDIAG_NOCOOKIE}, 152 }, 153 }) 154 s.Send(req) 155 msgs, from, err := s.Receive() 156 if err != nil { 157 return nil, err 158 } 159 if from.Pid != nl.PidKernel { 160 return nil, fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, nl.PidKernel) 161 } 162 if len(msgs) == 0 { 163 return nil, errors.New("no message nor error from netlink") 164 } 165 if len(msgs) > 2 { 166 return nil, fmt.Errorf("multiple (%d) matching sockets", len(msgs)) 167 } 168 sock := &Socket{} 169 if err := sock.deserialize(msgs[0].Data); err != nil { 170 return nil, err 171 } 172 return sock, nil 173 } 174 175 // SocketDiagTCPInfo requests INET_DIAG_INFO for TCP protocol for specified family type and return with extension TCP info. 176 func SocketDiagTCPInfo(family uint8) ([]*InetDiagTCPInfoResp, error) { 177 var result []*InetDiagTCPInfoResp 178 err := socketDiagTCPExecutor(family, func(m syscall.NetlinkMessage) error { 179 sockInfo := &Socket{} 180 if err := sockInfo.deserialize(m.Data); err != nil { 181 return err 182 } 183 attrs, err := nl.ParseRouteAttr(m.Data[sizeofSocket:]) 184 if err != nil { 185 return err 186 } 187 188 res, err := attrsToInetDiagTCPInfoResp(attrs, sockInfo) 189 if err != nil { 190 return err 191 } 192 193 result = append(result, res) 194 return nil 195 }) 196 if err != nil { 197 return nil, err 198 } 199 return result, nil 200 } 201 202 // SocketDiagTCP requests INET_DIAG_INFO for TCP protocol for specified family type and return related socket. 203 func SocketDiagTCP(family uint8) ([]*Socket, error) { 204 var result []*Socket 205 err := socketDiagTCPExecutor(family, func(m syscall.NetlinkMessage) error { 206 sockInfo := &Socket{} 207 if err := sockInfo.deserialize(m.Data); err != nil { 208 return err 209 } 210 result = append(result, sockInfo) 211 return nil 212 }) 213 if err != nil { 214 return nil, err 215 } 216 return result, nil 217 } 218 219 // socketDiagTCPExecutor requests INET_DIAG_INFO for TCP protocol for specified family type. 220 func socketDiagTCPExecutor(family uint8, receiver func(syscall.NetlinkMessage) error) error { 221 s, err := nl.Subscribe(unix.NETLINK_INET_DIAG) 222 if err != nil { 223 return err 224 } 225 defer s.Close() 226 227 req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) 228 req.AddData(&socketRequest{ 229 Family: family, 230 Protocol: unix.IPPROTO_TCP, 231 Ext: (1 << (INET_DIAG_VEGASINFO - 1)) | (1 << (INET_DIAG_INFO - 1)), 232 States: uint32(0xfff), // All TCP states 233 }) 234 s.Send(req) 235 236 loop: 237 for { 238 msgs, from, err := s.Receive() 239 if err != nil { 240 return err 241 } 242 if from.Pid != nl.PidKernel { 243 return fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, nl.PidKernel) 244 } 245 if len(msgs) == 0 { 246 return errors.New("no message nor error from netlink") 247 } 248 249 for _, m := range msgs { 250 switch m.Header.Type { 251 case unix.NLMSG_DONE: 252 break loop 253 case unix.NLMSG_ERROR: 254 error := int32(native.Uint32(m.Data[0:4])) 255 return syscall.Errno(-error) 256 } 257 if err := receiver(m); err != nil { 258 return err 259 } 260 } 261 } 262 return nil 263 } 264 265 func attrsToInetDiagTCPInfoResp(attrs []syscall.NetlinkRouteAttr, sockInfo *Socket) (*InetDiagTCPInfoResp, error) { 266 var tcpInfo *TCPInfo 267 var tcpBBRInfo *TCPBBRInfo 268 for _, a := range attrs { 269 if a.Attr.Type == INET_DIAG_INFO { 270 tcpInfo = &TCPInfo{} 271 if err := tcpInfo.deserialize(a.Value); err != nil { 272 return nil, err 273 } 274 continue 275 } 276 277 if a.Attr.Type == INET_DIAG_BBRINFO { 278 tcpBBRInfo = &TCPBBRInfo{} 279 if err := tcpBBRInfo.deserialize(a.Value); err != nil { 280 return nil, err 281 } 282 continue 283 } 284 } 285 286 return &InetDiagTCPInfoResp{ 287 InetDiagMsg: sockInfo, 288 TCPInfo: tcpInfo, 289 TCPBBRInfo: tcpBBRInfo, 290 }, nil 291 }