github.com/cilium/cilium@v1.16.2/pkg/datapath/sockets/sockets.go (about) 1 // SPDX-License-Identifier: Apache-2.0 2 // Copyright Authors of Cilium 3 4 package sockets 5 6 import ( 7 "encoding/binary" 8 "errors" 9 "fmt" 10 "net" 11 "syscall" 12 13 "github.com/sirupsen/logrus" 14 "github.com/vishvananda/netlink" 15 "github.com/vishvananda/netlink/nl" 16 "golang.org/x/sys/unix" 17 18 "github.com/cilium/cilium/pkg/logging" 19 "github.com/cilium/cilium/pkg/logging/logfields" 20 ) 21 22 const ( 23 sizeofSocketID = 0x30 24 sizeofSocketRequest = sizeofSocketID + 0x8 25 sizeofSocket = sizeofSocketID + 0x18 26 SOCK_DESTROY = 21 27 ) 28 29 var ( 30 log = logging.DefaultLogger.WithField(logfields.LogSubsys, "datapath-sockets") 31 native = nl.NativeEndian() 32 networkOrder = binary.BigEndian 33 ) 34 35 type SocketDestroyer interface { 36 Destroy(filter SocketFilter) error 37 } 38 39 type SocketFilter struct { 40 DestIp net.IP 41 DestPort uint16 42 Family uint8 43 Protocol uint8 44 // Optional callback function to determine whether a filtered socket needs to be destroyed 45 DestroyCB DestroySocketCB 46 } 47 48 type DestroySocketCB func(id netlink.SocketID) bool 49 50 // Destroy destroys sockets matching the passed filter parameters using the 51 // sock_diag netlink framework. 52 // 53 // Supported families in the filter: syscall.AF_INET, syscall.AF_INET6 54 // Supported protocols in the filter: unix.IPPROTO_UDP 55 func Destroy(filter SocketFilter) error { 56 family := filter.Family 57 protocol := filter.Protocol 58 59 if family != syscall.AF_INET && family != syscall.AF_INET6 { 60 return fmt.Errorf("unsupported family for socket destroy: %d", family) 61 } 62 var errs error 63 success, failed := 0, 0 64 65 // Query sockets matching the passed filter, and then destroy the filtered 66 // sockets. 67 switch protocol { 68 case unix.IPPROTO_UDP: 69 err := filterAndDestroyUDPSockets(family, func(sock netlink.SocketID, err error) { 70 if err != nil { 71 errs = errors.Join(errs, fmt.Errorf("UDP socket with filter [%v]: %w", filter, err)) 72 failed++ 73 return 74 } 75 if filter.MatchSocket(sock) { 76 log.Infof("socket %v", sock) 77 if err := destroySocket(sock, family, unix.IPPROTO_UDP); err != nil { 78 errs = errors.Join(errs, fmt.Errorf("destroying UDP socket with filter [%v]: %w", filter, err)) 79 failed++ 80 return 81 } 82 log.Debugf("Destroyed socket: %v", sock) 83 success++ 84 } 85 }) 86 if err != nil { 87 return fmt.Errorf("failed to get sockets with filter %v: %w", filter, err) 88 } 89 90 default: 91 return fmt.Errorf("unsupported protocol for socket destroy: %d", protocol) 92 } 93 if success > 0 || failed > 0 || errs != nil { 94 log.WithFields(logrus.Fields{ 95 "filter": filter, 96 "success": success, 97 "failed": failed, 98 "errors": errs, 99 }).Info("Forcefully terminated sockets") 100 } 101 102 return nil 103 } 104 105 func (f *SocketFilter) MatchSocket(socket netlink.SocketID) bool { 106 if socket.Destination.Equal(f.DestIp) && socket.DestinationPort == f.DestPort { 107 if f.DestroyCB == nil || f.DestroyCB(socket) { 108 return true 109 } 110 } 111 112 return false 113 } 114 115 func filterAndDestroyUDPSockets(family uint8, socketCB func(socket netlink.SocketID, err error)) error { 116 err := socketDiagUDPExecutor(family, func(m syscall.NetlinkMessage) error { 117 sockInfo := &socket{} 118 err := sockInfo.deserialize(m.Data) 119 socketCB(sockInfo.ID, err) 120 return nil 121 }) 122 if err != nil { 123 return err 124 } 125 return nil 126 } 127 128 // Below handlers are adapted from netlink/socket_linux.go to avoid memory allocations. 129 130 type socketRequest struct { 131 Family uint8 132 Protocol uint8 133 Ext uint8 134 pad uint8 135 States uint32 136 ID netlink.SocketID 137 } 138 139 type writeBuffer struct { 140 Bytes []byte 141 pos int 142 } 143 144 func (b *writeBuffer) write(c byte) { 145 b.Bytes[b.pos] = c 146 b.pos++ 147 } 148 149 func (b *writeBuffer) next(n int) []byte { 150 s := b.Bytes[b.pos : b.pos+n] 151 b.pos += n 152 return s 153 } 154 155 func (r *socketRequest) Serialize() []byte { 156 b := writeBuffer{Bytes: make([]byte, sizeofSocketRequest)} 157 b.write(r.Family) 158 b.write(r.Protocol) 159 b.write(r.Ext) 160 b.write(r.pad) 161 native.PutUint32(b.next(4), r.States) 162 networkOrder.PutUint16(b.next(2), r.ID.SourcePort) 163 networkOrder.PutUint16(b.next(2), r.ID.DestinationPort) 164 if r.Family == unix.AF_INET6 { 165 copy(b.next(16), r.ID.Source) 166 copy(b.next(16), r.ID.Destination) 167 } else { 168 copy(b.next(4), r.ID.Source.To4()) 169 b.next(12) 170 copy(b.next(4), r.ID.Destination.To4()) 171 b.next(12) 172 } 173 native.PutUint32(b.next(4), r.ID.Interface) 174 native.PutUint32(b.next(4), r.ID.Cookie[0]) 175 native.PutUint32(b.next(4), r.ID.Cookie[1]) 176 return b.Bytes 177 } 178 179 func (r *socketRequest) Len() int { return sizeofSocketRequest } 180 181 type readBuffer struct { 182 Bytes []byte 183 pos int 184 } 185 186 func (b *readBuffer) Read() byte { 187 c := b.Bytes[b.pos] 188 b.pos++ 189 return c 190 } 191 192 func (b *readBuffer) Next(n int) []byte { 193 s := b.Bytes[b.pos : b.pos+n] 194 b.pos += n 195 return s 196 } 197 198 type socket netlink.Socket 199 200 func (s *socket) deserialize(b []byte) error { 201 if len(b) < sizeofSocket { 202 return fmt.Errorf("socket data short read (%d); want %d", len(b), sizeofSocket) 203 } 204 rb := readBuffer{Bytes: b} 205 s.Family = rb.Read() 206 s.State = rb.Read() 207 s.Timer = rb.Read() 208 s.Retrans = rb.Read() 209 s.ID.SourcePort = networkOrder.Uint16(rb.Next(2)) 210 s.ID.DestinationPort = networkOrder.Uint16(rb.Next(2)) 211 if s.Family == unix.AF_INET6 { 212 s.ID.Source = net.IP(rb.Next(16)) 213 s.ID.Destination = net.IP(rb.Next(16)) 214 } else { 215 s.ID.Source = net.IPv4(rb.Read(), rb.Read(), rb.Read(), rb.Read()) 216 rb.Next(12) 217 s.ID.Destination = net.IPv4(rb.Read(), rb.Read(), rb.Read(), rb.Read()) 218 rb.Next(12) 219 } 220 s.ID.Interface = native.Uint32(rb.Next(4)) 221 s.ID.Cookie[0] = native.Uint32(rb.Next(4)) 222 s.ID.Cookie[1] = native.Uint32(rb.Next(4)) 223 s.Expires = native.Uint32(rb.Next(4)) 224 s.RQueue = native.Uint32(rb.Next(4)) 225 s.WQueue = native.Uint32(rb.Next(4)) 226 s.UID = native.Uint32(rb.Next(4)) 227 s.INode = native.Uint32(rb.Next(4)) 228 return nil 229 } 230 231 func destroySocket(sockId netlink.SocketID, family uint8, protocol uint8) error { 232 s, err := nl.Subscribe(unix.NETLINK_INET_DIAG) 233 if err != nil { 234 return err 235 } 236 defer s.Close() 237 238 req := nl.NewNetlinkRequest(SOCK_DESTROY, unix.NLM_F_REQUEST) 239 req.AddData(&socketRequest{ 240 Family: family, 241 Protocol: protocol, 242 States: uint32(0xfff), 243 ID: sockId, 244 }) 245 err = s.Send(req) 246 if err != nil { 247 fmt.Printf("error in destroying socket: %v", sockId) 248 } 249 return err 250 } 251 252 func socketDiagUDPExecutor(family uint8, receiver func(message syscall.NetlinkMessage) error) error { 253 s, err := nl.Subscribe(unix.NETLINK_INET_DIAG) 254 if err != nil { 255 return err 256 } 257 defer s.Close() 258 259 req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP) 260 req.AddData(&socketRequest{ 261 Family: family, 262 Protocol: unix.IPPROTO_UDP, 263 States: uint32(0xfff), 264 }) 265 s.Send(req) 266 267 loop: 268 for { 269 msgs, from, err := s.Receive() 270 if err != nil { 271 return err 272 } 273 if from.Pid != nl.PidKernel { 274 return fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, nl.PidKernel) 275 } 276 if len(msgs) == 0 { 277 return errors.New("no message nor error from netlink") 278 } 279 280 for _, m := range msgs { 281 switch m.Header.Type { 282 case unix.NLMSG_DONE: 283 break loop 284 case unix.NLMSG_ERROR: 285 error := int32(native.Uint32(m.Data[0:4])) 286 return syscall.Errno(-error) 287 } 288 if err := receiver(m); err != nil { 289 return err 290 } 291 } 292 } 293 return nil 294 }