github.com/inspektor-gadget/inspektor-gadget@v0.28.1/pkg/socketenricher/bpf/sockets-iter.bpf.c (about)

     1  // SPDX-License-Identifier: (GPL-2.0 WITH Linux-syscall-note) OR Apache-2.0
     2  /* Copyright (c) 2023 The Inspektor Gadget authors */
     3  
     4  #include <vmlinux.h>
     5  #include <bpf/bpf_helpers.h>
     6  #include <bpf/bpf_core_read.h>
     7  #include <bpf/bpf_tracing.h>
     8  #include <bpf/bpf_endian.h>
     9  
    10  #include <bpf/bpf_helpers.h>
    11  
    12  #include <gadget/sockets-map.h>
    13  #include "socket-enricher-helpers.h"
    14  
// Kernel address of socket_file_ops, supplied by userspace because
// cilium/ebpf cannot resolve .ksyms (see the TODO in ig_sockets_it).
// A value of 0 means "unknown": ig_sockets_it then ignores every file.
// NOTE(review): kept as an initialized `const volatile` global so the loader
// can rewrite it as a load-time constant — confirm before changing the form
// of this declaration.
const volatile __u64 socket_file_ops_addr = 0;
    16  
    17  static __always_inline void insert_socket_from_iter(struct sock *sock,
    18  						    struct task_struct *task)
    19  {
    20  	struct sockets_key socket_key = {
    21  		0,
    22  	};
    23  	prepare_socket_key(&socket_key, sock);
    24  
    25  	struct sockets_value socket_value = {
    26  		0,
    27  	};
    28  	// use given task
    29  	socket_value.pid_tgid = ((u64)task->tgid) << 32 | task->pid;
    30  	// The VFS code might temporary substitute task->cred by other creds during overlayfs
    31  	// copyup. In this case, we want the real creds of the process, not the creds temporarily
    32  	// substituted by VFS overlayfs copyup.
    33  	// https://kernel.org/doc/html/v6.2-rc8/security/credentials.html#overriding-the-vfs-s-use-of-credentials
    34  	socket_value.uid_gid = ((u64)task->real_cred->gid.val) << 32 |
    35  			       task->real_cred->uid.val;
    36  	__builtin_memcpy(&socket_value.task, task->comm,
    37  			 sizeof(socket_value.task));
    38  	socket_value.mntns = (u64)task->nsproxy->mnt_ns->ns.inum;
    39  	socket_value.sock = (__u64)sock;
    40  	socket_value.ipv6only =
    41  		BPF_CORE_READ_BITFIELD_PROBED(sock, __sk_common.skc_ipv6only);
    42  
    43  	// If the endpoint was not present, add it and we're done.
    44  	struct sockets_value *old_socket_value =
    45  		(struct sockets_value *)bpf_map_lookup_elem(&gadget_sockets,
    46  							    &socket_key);
    47  	if (!old_socket_value) {
    48  		// Use BPF_NOEXIST: if an entry was inserted just after the check, this
    49  		// is because the bpf iterator for initial sockets runs in
    50  		// parallel to other kprobes and we prefer the information from the
    51  		// other kprobes because their data is more accurate (e.g. correct
    52  		// thread).
    53  		bpf_map_update_elem(&gadget_sockets, &socket_key, &socket_value,
    54  				    BPF_NOEXIST);
    55  		return;
    56  	}
    57  
    58  	// At this point, the endpoint was already present, we need to determine
    59  	// the best entry between the existing one and the new one.
    60  
    61  	// When iterating on initial sockets, we get both passive and active
    62  	// sockets (server side). We want the passive socket because we don't
    63  	// want the endpoint to be removed from the map when just one
    64  	// connection is terminated. We cannot determine if an active socket
    65  	// is server side or client side, so we add active socket anyway on the
    66  	// chance that it is client side. It will be fine for server side too,
    67  	// because the passive socket will be added later, overwriting the
    68  	// active socket.
    69  	if (BPF_CORE_READ(sock, __sk_common.skc_state) == TCP_LISTEN)
    70  		bpf_map_update_elem(&gadget_sockets, &socket_key, &socket_value,
    71  				    BPF_ANY);
    72  }
    73  
    74  // This iterates on all the sockets (from all tasks) and updates the sockets
    75  // map. This is useful to get the initial sockets that were already opened
    76  // before the socket enricher was attached.
    77  SEC("iter/task_file")
    78  int ig_sockets_it(struct bpf_iter__task_file *ctx)
    79  {
    80  	struct file *file = ctx->file;
    81  	struct task_struct *task = ctx->task;
    82  
    83  	if (!file || !task)
    84  		return 0;
    85  
    86  	// Check that the file descriptor is a socket.
    87  	// TODO: cilium/ebpf doesn't support .ksyms, so we get the address of
    88  	// socket_file_ops from userspace.
    89  	// See: https://github.com/cilium/ebpf/issues/761
    90  	if (socket_file_ops_addr == 0 ||
    91  	    (__u64)(file->f_op) != socket_file_ops_addr)
    92  		return 0;
    93  
    94  	// file->private_data is a struct socket because we checked f_op.
    95  	struct socket *socket = (struct socket *)file->private_data;
    96  	struct sock *sock = BPF_CORE_READ(socket, sk);
    97  	__u16 family = BPF_CORE_READ(sock, __sk_common.skc_family);
    98  	if (family != AF_INET && family != AF_INET6)
    99  		return 0;
   100  
   101  	// Since the iterator is not executed from the context of the process that
   102  	// opened the socket, we need to pass the task_struct to the map.
   103  	insert_socket_from_iter(sock, task);
   104  	return 0;
   105  }
   106  
   107  // This iterator is called from a Go Ticker to remove expired sockets
   108  SEC("iter/bpf_map_elem")
   109  int ig_sk_cleanup(struct bpf_iter__bpf_map_elem *ctx)
   110  {
   111  	struct seq_file *seq = ctx->meta->seq;
   112  	__u32 seq_num = ctx->meta->seq_num;
   113  	struct bpf_map *map = ctx->map;
   114  	struct sockets_key *socket_key = ctx->key;
   115  	struct sockets_key tmp_key;
   116  	struct sockets_value *socket_value = ctx->value;
   117  
   118  	if (!socket_key || !socket_value)
   119  		return 0;
   120  
   121  	__u64 now = bpf_ktime_get_ns();
   122  	__u64 deletion_timestamp = socket_value->deletion_timestamp;
   123  	__u64 socket_expiration_ns =
   124  		1000ULL * 1000ULL * 1000ULL * 5ULL; // 5 seconds
   125  
   126  	if (deletion_timestamp != 0 &&
   127  	    deletion_timestamp + socket_expiration_ns < now) {
   128  		// The socket is expired, remove it from the map.
   129  		__builtin_memcpy(&tmp_key, socket_key,
   130  				 sizeof(struct sockets_key));
   131  		bpf_map_delete_elem(&gadget_sockets, &tmp_key);
   132  		return 0;
   133  	}
   134  
   135  	return 0;
   136  }
   137  
// License string placed in the "license" section of the BPF object.
// NOTE(review): the SPDX header allows "(GPL-2.0 WITH Linux-syscall-note) OR
// Apache-2.0", while this string is "GPL"; presumably chosen so the kernel
// treats the program as GPL-compatible — confirm before changing it.
char _license[] SEC("license") = "GPL";