github.com/cilium/cilium@v1.16.2/pkg/datapath/sockets/sockets.go (about)

     1  // SPDX-License-Identifier: Apache-2.0
     2  // Copyright Authors of Cilium
     3  
     4  package sockets
     5  
     6  import (
     7  	"encoding/binary"
     8  	"errors"
     9  	"fmt"
    10  	"net"
    11  	"syscall"
    12  
    13  	"github.com/sirupsen/logrus"
    14  	"github.com/vishvananda/netlink"
    15  	"github.com/vishvananda/netlink/nl"
    16  	"golang.org/x/sys/unix"
    17  
    18  	"github.com/cilium/cilium/pkg/logging"
    19  	"github.com/cilium/cilium/pkg/logging/logfields"
    20  )
    21  
    22  const (
    23  	sizeofSocketID      = 0x30
    24  	sizeofSocketRequest = sizeofSocketID + 0x8
    25  	sizeofSocket        = sizeofSocketID + 0x18
    26  	SOCK_DESTROY        = 21
    27  )
    28  
    29  var (
    30  	log          = logging.DefaultLogger.WithField(logfields.LogSubsys, "datapath-sockets")
    31  	native       = nl.NativeEndian()
    32  	networkOrder = binary.BigEndian
    33  )
    34  
    35  type SocketDestroyer interface {
    36  	Destroy(filter SocketFilter) error
    37  }
    38  
    39  type SocketFilter struct {
    40  	DestIp   net.IP
    41  	DestPort uint16
    42  	Family   uint8
    43  	Protocol uint8
    44  	// Optional callback function to determine whether a filtered socket needs to be destroyed
    45  	DestroyCB DestroySocketCB
    46  }
    47  
    48  type DestroySocketCB func(id netlink.SocketID) bool
    49  
    50  // Destroy destroys sockets matching the passed filter parameters using the
    51  // sock_diag netlink framework.
    52  //
    53  // Supported families in the filter: syscall.AF_INET, syscall.AF_INET6
    54  // Supported protocols in the filter: unix.IPPROTO_UDP
    55  func Destroy(filter SocketFilter) error {
    56  	family := filter.Family
    57  	protocol := filter.Protocol
    58  
    59  	if family != syscall.AF_INET && family != syscall.AF_INET6 {
    60  		return fmt.Errorf("unsupported family for socket destroy: %d", family)
    61  	}
    62  	var errs error
    63  	success, failed := 0, 0
    64  
    65  	// Query sockets matching the passed filter, and then destroy the filtered
    66  	// sockets.
    67  	switch protocol {
    68  	case unix.IPPROTO_UDP:
    69  		err := filterAndDestroyUDPSockets(family, func(sock netlink.SocketID, err error) {
    70  			if err != nil {
    71  				errs = errors.Join(errs, fmt.Errorf("UDP socket with filter [%v]: %w", filter, err))
    72  				failed++
    73  				return
    74  			}
    75  			if filter.MatchSocket(sock) {
    76  				log.Infof("socket %v", sock)
    77  				if err := destroySocket(sock, family, unix.IPPROTO_UDP); err != nil {
    78  					errs = errors.Join(errs, fmt.Errorf("destroying UDP socket with filter [%v]: %w", filter, err))
    79  					failed++
    80  					return
    81  				}
    82  				log.Debugf("Destroyed socket: %v", sock)
    83  				success++
    84  			}
    85  		})
    86  		if err != nil {
    87  			return fmt.Errorf("failed to get sockets with filter %v: %w", filter, err)
    88  		}
    89  
    90  	default:
    91  		return fmt.Errorf("unsupported protocol for socket destroy: %d", protocol)
    92  	}
    93  	if success > 0 || failed > 0 || errs != nil {
    94  		log.WithFields(logrus.Fields{
    95  			"filter":  filter,
    96  			"success": success,
    97  			"failed":  failed,
    98  			"errors":  errs,
    99  		}).Info("Forcefully terminated sockets")
   100  	}
   101  
   102  	return nil
   103  }
   104  
   105  func (f *SocketFilter) MatchSocket(socket netlink.SocketID) bool {
   106  	if socket.Destination.Equal(f.DestIp) && socket.DestinationPort == f.DestPort {
   107  		if f.DestroyCB == nil || f.DestroyCB(socket) {
   108  			return true
   109  		}
   110  	}
   111  
   112  	return false
   113  }
   114  
   115  func filterAndDestroyUDPSockets(family uint8, socketCB func(socket netlink.SocketID, err error)) error {
   116  	err := socketDiagUDPExecutor(family, func(m syscall.NetlinkMessage) error {
   117  		sockInfo := &socket{}
   118  		err := sockInfo.deserialize(m.Data)
   119  		socketCB(sockInfo.ID, err)
   120  		return nil
   121  	})
   122  	if err != nil {
   123  		return err
   124  	}
   125  	return nil
   126  }
   127  
   128  // Below handlers are adapted from netlink/socket_linux.go to avoid memory allocations.
   129  
   130  type socketRequest struct {
   131  	Family   uint8
   132  	Protocol uint8
   133  	Ext      uint8
   134  	pad      uint8
   135  	States   uint32
   136  	ID       netlink.SocketID
   137  }
   138  
   139  type writeBuffer struct {
   140  	Bytes []byte
   141  	pos   int
   142  }
   143  
   144  func (b *writeBuffer) write(c byte) {
   145  	b.Bytes[b.pos] = c
   146  	b.pos++
   147  }
   148  
   149  func (b *writeBuffer) next(n int) []byte {
   150  	s := b.Bytes[b.pos : b.pos+n]
   151  	b.pos += n
   152  	return s
   153  }
   154  
   155  func (r *socketRequest) Serialize() []byte {
   156  	b := writeBuffer{Bytes: make([]byte, sizeofSocketRequest)}
   157  	b.write(r.Family)
   158  	b.write(r.Protocol)
   159  	b.write(r.Ext)
   160  	b.write(r.pad)
   161  	native.PutUint32(b.next(4), r.States)
   162  	networkOrder.PutUint16(b.next(2), r.ID.SourcePort)
   163  	networkOrder.PutUint16(b.next(2), r.ID.DestinationPort)
   164  	if r.Family == unix.AF_INET6 {
   165  		copy(b.next(16), r.ID.Source)
   166  		copy(b.next(16), r.ID.Destination)
   167  	} else {
   168  		copy(b.next(4), r.ID.Source.To4())
   169  		b.next(12)
   170  		copy(b.next(4), r.ID.Destination.To4())
   171  		b.next(12)
   172  	}
   173  	native.PutUint32(b.next(4), r.ID.Interface)
   174  	native.PutUint32(b.next(4), r.ID.Cookie[0])
   175  	native.PutUint32(b.next(4), r.ID.Cookie[1])
   176  	return b.Bytes
   177  }
   178  
   179  func (r *socketRequest) Len() int { return sizeofSocketRequest }
   180  
   181  type readBuffer struct {
   182  	Bytes []byte
   183  	pos   int
   184  }
   185  
   186  func (b *readBuffer) Read() byte {
   187  	c := b.Bytes[b.pos]
   188  	b.pos++
   189  	return c
   190  }
   191  
   192  func (b *readBuffer) Next(n int) []byte {
   193  	s := b.Bytes[b.pos : b.pos+n]
   194  	b.pos += n
   195  	return s
   196  }
   197  
   198  type socket netlink.Socket
   199  
   200  func (s *socket) deserialize(b []byte) error {
   201  	if len(b) < sizeofSocket {
   202  		return fmt.Errorf("socket data short read (%d); want %d", len(b), sizeofSocket)
   203  	}
   204  	rb := readBuffer{Bytes: b}
   205  	s.Family = rb.Read()
   206  	s.State = rb.Read()
   207  	s.Timer = rb.Read()
   208  	s.Retrans = rb.Read()
   209  	s.ID.SourcePort = networkOrder.Uint16(rb.Next(2))
   210  	s.ID.DestinationPort = networkOrder.Uint16(rb.Next(2))
   211  	if s.Family == unix.AF_INET6 {
   212  		s.ID.Source = net.IP(rb.Next(16))
   213  		s.ID.Destination = net.IP(rb.Next(16))
   214  	} else {
   215  		s.ID.Source = net.IPv4(rb.Read(), rb.Read(), rb.Read(), rb.Read())
   216  		rb.Next(12)
   217  		s.ID.Destination = net.IPv4(rb.Read(), rb.Read(), rb.Read(), rb.Read())
   218  		rb.Next(12)
   219  	}
   220  	s.ID.Interface = native.Uint32(rb.Next(4))
   221  	s.ID.Cookie[0] = native.Uint32(rb.Next(4))
   222  	s.ID.Cookie[1] = native.Uint32(rb.Next(4))
   223  	s.Expires = native.Uint32(rb.Next(4))
   224  	s.RQueue = native.Uint32(rb.Next(4))
   225  	s.WQueue = native.Uint32(rb.Next(4))
   226  	s.UID = native.Uint32(rb.Next(4))
   227  	s.INode = native.Uint32(rb.Next(4))
   228  	return nil
   229  }
   230  
   231  func destroySocket(sockId netlink.SocketID, family uint8, protocol uint8) error {
   232  	s, err := nl.Subscribe(unix.NETLINK_INET_DIAG)
   233  	if err != nil {
   234  		return err
   235  	}
   236  	defer s.Close()
   237  
   238  	req := nl.NewNetlinkRequest(SOCK_DESTROY, unix.NLM_F_REQUEST)
   239  	req.AddData(&socketRequest{
   240  		Family:   family,
   241  		Protocol: protocol,
   242  		States:   uint32(0xfff),
   243  		ID:       sockId,
   244  	})
   245  	err = s.Send(req)
   246  	if err != nil {
   247  		fmt.Printf("error in destroying socket: %v", sockId)
   248  	}
   249  	return err
   250  }
   251  
   252  func socketDiagUDPExecutor(family uint8, receiver func(message syscall.NetlinkMessage) error) error {
   253  	s, err := nl.Subscribe(unix.NETLINK_INET_DIAG)
   254  	if err != nil {
   255  		return err
   256  	}
   257  	defer s.Close()
   258  
   259  	req := nl.NewNetlinkRequest(nl.SOCK_DIAG_BY_FAMILY, unix.NLM_F_DUMP)
   260  	req.AddData(&socketRequest{
   261  		Family:   family,
   262  		Protocol: unix.IPPROTO_UDP,
   263  		States:   uint32(0xfff),
   264  	})
   265  	s.Send(req)
   266  
   267  loop:
   268  	for {
   269  		msgs, from, err := s.Receive()
   270  		if err != nil {
   271  			return err
   272  		}
   273  		if from.Pid != nl.PidKernel {
   274  			return fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, nl.PidKernel)
   275  		}
   276  		if len(msgs) == 0 {
   277  			return errors.New("no message nor error from netlink")
   278  		}
   279  
   280  		for _, m := range msgs {
   281  			switch m.Header.Type {
   282  			case unix.NLMSG_DONE:
   283  				break loop
   284  			case unix.NLMSG_ERROR:
   285  				error := int32(native.Uint32(m.Data[0:4]))
   286  				return syscall.Errno(-error)
   287  			}
   288  			if err := receiver(m); err != nil {
   289  				return err
   290  			}
   291  		}
   292  	}
   293  	return nil
   294  }