github.com/inspektor-gadget/inspektor-gadget@v0.28.1/pkg/socketenricher/tracer_test.go (about)

     1  //go:build linux
     2  // +build linux
     3  
     4  // Copyright 2023 The Inspektor Gadget authors
     5  //
     6  // Licensed under the Apache License, Version 2.0 (the "License");
     7  // you may not use this file except in compliance with the License.
     8  // You may obtain a copy of the License at
     9  //
    10  //     http://www.apache.org/licenses/LICENSE-2.0
    11  //
    12  // Unless required by applicable law or agreed to in writing, software
    13  // distributed under the License is distributed on an "AS IS" BASIS,
    14  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    15  // See the License for the specific language governing permissions and
    16  // limitations under the License.
    17  
    18  package socketenricher
    19  
    20  import (
    21  	"fmt"
    22  	"net"
    23  	"reflect"
    24  	"testing"
    25  	"unsafe"
    26  
    27  	"golang.org/x/sys/unix"
    28  
    29  	utilstest "github.com/inspektor-gadget/inspektor-gadget/internal/test"
    30  )
    31  
    32  func TestSocketEnricherCreate(t *testing.T) {
    33  	t.Parallel()
    34  
    35  	utilstest.RequireRoot(t)
    36  	utilstest.HostInit(t)
    37  
    38  	tracer, err := NewSocketEnricher()
    39  	if err != nil {
    40  		t.Fatal(err)
    41  	}
    42  	if tracer == nil {
    43  		t.Fatal("Returned tracer was nil")
    44  	}
    45  }
    46  
    47  func TestSocketEnricherStopIdempotent(t *testing.T) {
    48  	t.Parallel()
    49  
    50  	utilstest.RequireRoot(t)
    51  	utilstest.HostInit(t)
    52  
    53  	tracer, _ := NewSocketEnricher()
    54  
    55  	// Check that a double stop doesn't cause issues
    56  	tracer.Close()
    57  	tracer.Close()
    58  }
    59  
    60  type sockOpt struct {
    61  	level int
    62  	opt   int
    63  	value int
    64  }
    65  
    66  type socketEnricherMapEntry struct {
    67  	Key   socketenricherSocketsKey
    68  	Value socketenricherSocketsValue
    69  }
    70  
    71  func TestSocketEnricherBind(t *testing.T) {
    72  	t.Parallel()
    73  
    74  	utilstest.RequireRoot(t)
    75  	utilstest.HostInit(t)
    76  
    77  	type testDefinition struct {
    78  		runnerConfig  *utilstest.RunnerConfig
    79  		generateEvent func() (uint16, int, error)
    80  		expectedEvent func(info *utilstest.RunnerInfo, port uint16) *socketEnricherMapEntry
    81  	}
    82  
    83  	stringToSlice := func(s string) (ret [16]int8) {
    84  		for i := 0; i < 16; i++ {
    85  			if i >= len(s) {
    86  				break
    87  			}
    88  			ret[i] = int8(s[i])
    89  		}
    90  		return
    91  	}
    92  
    93  	for name, test := range map[string]testDefinition{
    94  		"udp": {
    95  			generateEvent: bindSocketFn("127.0.0.1", unix.AF_INET, unix.SOCK_DGRAM, 0),
    96  			expectedEvent: func(info *utilstest.RunnerInfo, port uint16) *socketEnricherMapEntry {
    97  				return &socketEnricherMapEntry{
    98  					Key: socketenricherSocketsKey{
    99  						Netns:  uint32(info.NetworkNsID),
   100  						Family: unix.AF_INET,
   101  						Proto:  unix.IPPROTO_UDP,
   102  						Port:   port,
   103  					},
   104  					Value: socketenricherSocketsValue{
   105  						Mntns:   info.MountNsID,
   106  						PidTgid: uint64(uint32(info.Pid))<<32 + uint64(info.Tid),
   107  						Task:    stringToSlice("socketenricher."),
   108  					},
   109  				}
   110  			},
   111  		},
   112  		"udp6": {
   113  			generateEvent: bindSocketFn("::", unix.AF_INET6, unix.SOCK_DGRAM, 0),
   114  			expectedEvent: func(info *utilstest.RunnerInfo, port uint16) *socketEnricherMapEntry {
   115  				return &socketEnricherMapEntry{
   116  					Key: socketenricherSocketsKey{
   117  						Netns:  uint32(info.NetworkNsID),
   118  						Family: unix.AF_INET6,
   119  						Proto:  unix.IPPROTO_UDP,
   120  						Port:   port,
   121  					},
   122  					Value: socketenricherSocketsValue{
   123  						Mntns:   info.MountNsID,
   124  						PidTgid: uint64(uint32(info.Pid))<<32 + uint64(info.Tid),
   125  						Task:    stringToSlice("socketenricher."),
   126  					},
   127  				}
   128  			},
   129  		},
   130  		"udp6-only": {
   131  			generateEvent: func() (uint16, int, error) {
   132  				opts := []sockOpt{
   133  					{unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1},
   134  				}
   135  				return bindSocketWithOpts("::", unix.AF_INET6, unix.SOCK_DGRAM, 0, opts)
   136  			},
   137  			expectedEvent: func(info *utilstest.RunnerInfo, port uint16) *socketEnricherMapEntry {
   138  				return &socketEnricherMapEntry{
   139  					Key: socketenricherSocketsKey{
   140  						Netns:  uint32(info.NetworkNsID),
   141  						Family: unix.AF_INET6,
   142  						Proto:  unix.IPPROTO_UDP,
   143  						Port:   port,
   144  					},
   145  					Value: socketenricherSocketsValue{
   146  						Mntns:    info.MountNsID,
   147  						PidTgid:  uint64(uint32(info.Pid))<<32 + uint64(info.Tid),
   148  						Task:     stringToSlice("socketenricher."),
   149  						Ipv6only: int8(1),
   150  					},
   151  				}
   152  			},
   153  		},
   154  		"tcp": {
   155  			generateEvent: bindSocketFn("127.0.0.1", unix.AF_INET, unix.SOCK_STREAM, 0),
   156  			expectedEvent: func(info *utilstest.RunnerInfo, port uint16) *socketEnricherMapEntry {
   157  				return &socketEnricherMapEntry{
   158  					Key: socketenricherSocketsKey{
   159  						Netns:  uint32(info.NetworkNsID),
   160  						Family: unix.AF_INET,
   161  						Proto:  unix.IPPROTO_TCP,
   162  						Port:   port,
   163  					},
   164  					Value: socketenricherSocketsValue{
   165  						Mntns:   info.MountNsID,
   166  						PidTgid: uint64(uint32(info.Pid))<<32 + uint64(info.Tid),
   167  						Task:    stringToSlice("socketenricher."),
   168  					},
   169  				}
   170  			},
   171  		},
   172  		"tcp6": {
   173  			generateEvent: bindSocketFn("::", unix.AF_INET6, unix.SOCK_STREAM, 0),
   174  			expectedEvent: func(info *utilstest.RunnerInfo, port uint16) *socketEnricherMapEntry {
   175  				return &socketEnricherMapEntry{
   176  					Key: socketenricherSocketsKey{
   177  						Netns:  uint32(info.NetworkNsID),
   178  						Family: unix.AF_INET6,
   179  						Proto:  unix.IPPROTO_TCP,
   180  						Port:   port,
   181  					},
   182  					Value: socketenricherSocketsValue{
   183  						Mntns:   info.MountNsID,
   184  						PidTgid: uint64(uint32(info.Pid))<<32 + uint64(info.Tid),
   185  						Task:    stringToSlice("socketenricher."),
   186  					},
   187  				}
   188  			},
   189  		},
   190  		"tcp6-only": {
   191  			generateEvent: func() (uint16, int, error) {
   192  				opts := []sockOpt{
   193  					{unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1},
   194  				}
   195  				return bindSocketWithOpts("::", unix.AF_INET6, unix.SOCK_STREAM, 0, opts)
   196  			},
   197  			expectedEvent: func(info *utilstest.RunnerInfo, port uint16) *socketEnricherMapEntry {
   198  				return &socketEnricherMapEntry{
   199  					Key: socketenricherSocketsKey{
   200  						Netns:  uint32(info.NetworkNsID),
   201  						Family: unix.AF_INET6,
   202  						Proto:  unix.IPPROTO_TCP,
   203  						Port:   port,
   204  					},
   205  					Value: socketenricherSocketsValue{
   206  						Mntns:    info.MountNsID,
   207  						PidTgid:  uint64(uint32(info.Pid))<<32 + uint64(info.Tid),
   208  						Task:     stringToSlice("socketenricher."),
   209  						Ipv6only: int8(1),
   210  					},
   211  				}
   212  			},
   213  		},
   214  		"tcp_uid_gid": {
   215  			runnerConfig:  &utilstest.RunnerConfig{Uid: 1000, Gid: 1111},
   216  			generateEvent: bindSocketFn("127.0.0.1", unix.AF_INET, unix.SOCK_STREAM, 0),
   217  			expectedEvent: func(info *utilstest.RunnerInfo, port uint16) *socketEnricherMapEntry {
   218  				return &socketEnricherMapEntry{
   219  					Key: socketenricherSocketsKey{
   220  						Netns:  uint32(info.NetworkNsID),
   221  						Family: unix.AF_INET,
   222  						Proto:  unix.IPPROTO_TCP,
   223  						Port:   port,
   224  					},
   225  					Value: socketenricherSocketsValue{
   226  						Mntns:   info.MountNsID,
   227  						PidTgid: uint64(uint32(info.Pid))<<32 + uint64(info.Tid),
   228  						UidGid:  uint64(1111)<<32 + uint64(1000),
   229  						Task:    stringToSlice("socketenricher."),
   230  					},
   231  				}
   232  			},
   233  		},
   234  	} {
   235  		test := test
   236  
   237  		t.Run(name, func(t *testing.T) {
   238  			t.Parallel()
   239  
   240  			runner := utilstest.NewRunnerWithTest(t, test.runnerConfig)
   241  
   242  			// We will test 2 scenarios with 2 different tracers:
   243  			// 1. earlyTracer will be started before the event is generated
   244  			// 2. lateTracer will be started after the event is generated
   245  			earlyTracer, err := NewSocketEnricher()
   246  			if err != nil {
   247  				t.Fatal(err)
   248  			}
   249  			t.Cleanup(earlyTracer.Close)
   250  
   251  			// Generate the event in the fake container
   252  			var port uint16
   253  			var fd int
   254  			utilstest.RunWithRunner(t, runner, func() error {
   255  				var err error
   256  				port, fd, err = test.generateEvent()
   257  
   258  				t.Cleanup(func() {
   259  					// cleanup only if it has not been already closed
   260  					if fd != -1 {
   261  						unix.Close(fd)
   262  					}
   263  				})
   264  
   265  				return err
   266  			})
   267  
   268  			// Start the late tracer after the event has been generated
   269  			lateTracer, err := NewSocketEnricher()
   270  			if err != nil {
   271  				t.Fatal(err)
   272  			}
   273  			t.Cleanup(lateTracer.Close)
   274  
   275  			earlyNormalize := func(entry *socketEnricherMapEntry) {
   276  				entry.Value.Sock = 0
   277  			}
   278  			lateNormalize := func(entry *socketEnricherMapEntry) {
   279  				earlyNormalize(entry)
   280  
   281  				// Remove tid: the late tracer cannot distinguish between threads
   282  				entry.Value.PidTgid = 0xffffffff00000000 & entry.Value.PidTgid
   283  
   284  				// Our fake container is just a thread in a different MountNsID
   285  				// But the late tracer cannot distinguish threads.
   286  				if entry.Value.Mntns > 0 {
   287  					entry.Value.Mntns = 1
   288  				}
   289  
   290  				// We're not able to test uid and gid in the late tracer because our
   291  				// fake container is just another thread running on the same process
   292  				// and that tracer cannot distinguish threads.
   293  				entry.Value.UidGid = 0
   294  			}
   295  
   296  			t.Logf("Testing if early tracer noticed the event")
   297  			entries := socketsMapEntries(t, earlyTracer, earlyNormalize, nil)
   298  			utilstest.ExpectAtLeastOneEvent(test.expectedEvent)(t, runner.Info, port, entries)
   299  
   300  			t.Logf("Testing if late tracer noticed the event")
   301  			entries2 := socketsMapEntries(t, lateTracer, lateNormalize, nil)
   302  			expectedEvent2 := func(info *utilstest.RunnerInfo, port uint16) *socketEnricherMapEntry {
   303  				e := test.expectedEvent(info, port)
   304  				lateNormalize(e)
   305  				return e
   306  			}
   307  			utilstest.ExpectAtLeastOneEvent(expectedEvent2)(t, runner.Info, port, entries2)
   308  
   309  			t.Logf("Close socket in order to check for cleanup")
   310  			if fd != -1 {
   311  				unix.Close(fd)
   312  				// Disable t.Cleanup() above
   313  				fd = -1
   314  			}
   315  
   316  			filter := func(e *socketEnricherMapEntry) bool {
   317  				expected := test.expectedEvent(runner.Info, port)
   318  				return !reflect.DeepEqual(expected, e)
   319  			}
   320  
   321  			t.Logf("Testing if entry is cleaned properly in early tracer")
   322  			entries = socketsMapEntries(t, earlyTracer, earlyNormalize, filter)
   323  			if len(entries) != 0 {
   324  				t.Fatalf("Entry not cleaned properly: %+v", entries)
   325  			}
   326  
   327  			t.Logf("Testing if entry is cleaned properly in late tracer")
   328  			entries2 = socketsMapEntries(t, lateTracer, lateNormalize, filter)
   329  			if len(entries2) != 0 {
   330  				t.Fatalf("Entry for late tracer not cleaned properly: %+v", entries2)
   331  			}
   332  		})
   333  	}
   334  }
   335  
   336  func socketsMapEntries(
   337  	t *testing.T,
   338  	tracer *SocketEnricher,
   339  	normalize func(entry *socketEnricherMapEntry),
   340  	filter func(*socketEnricherMapEntry) bool,
   341  ) (entries []socketEnricherMapEntry) {
   342  	iter := tracer.SocketsMap().Iterate()
   343  	var key socketenricherSocketsKey
   344  	var value socketenricherSocketsValue
   345  	for iter.Next(&key, &value) {
   346  		entry := socketEnricherMapEntry{
   347  			Key:   key,
   348  			Value: value,
   349  		}
   350  
   351  		normalize(&entry)
   352  
   353  		if filter != nil && filter(&entry) {
   354  			continue
   355  		}
   356  		entries = append(entries, entry)
   357  	}
   358  	if err := iter.Err(); err != nil {
   359  		t.Fatal("Cannot iterate over socket enricher map:", err)
   360  	}
   361  	return entries
   362  }
   363  
   364  // bindSocketFn returns a function that creates a socket, binds it and
   365  // returns the port the socket was bound to.
   366  func bindSocketFn(ipStr string, domain, typ int, port int) func() (uint16, int, error) {
   367  	return func() (uint16, int, error) {
   368  		return bindSocket(ipStr, domain, typ, port)
   369  	}
   370  }
   371  
   372  func bindSocket(ipStr string, domain, typ int, port int) (uint16, int, error) {
   373  	return bindSocketWithOpts(ipStr, domain, typ, port, nil)
   374  }
   375  
   376  func setProcessName(name string) error {
   377  	bytes := append([]byte(name), 0)
   378  	return unix.Prctl(unix.PR_SET_NAME, uintptr(unsafe.Pointer(&bytes[0])), 0, 0, 0)
   379  }
   380  
   381  func bindSocketWithOpts(ipStr string, domain, typ int, port int, opts []sockOpt) (uint16, int, error) {
   382  	// The process name is usually based on the package name
   383  	// ("socketenricher.") but it could be changed (e.g. running tests in the
   384  	// Goland IDE environment). Make sure the tests work regardless of the
   385  	// environment.
   386  	//
   387  	// Example how to test this:
   388  	//
   389  	//	$ go test -c ./pkg/socketenricher/...
   390  	//	$ sudo ./socketenricher.test
   391  	//	PASS
   392  	//	$ mv socketenricher.test se.test
   393  	//	$ sudo ./se.test
   394  	//	FAIL
   395  	err := setProcessName("socketenricher.")
   396  	if err != nil {
   397  		return 0, -1, fmt.Errorf("setProcessName: %w", err)
   398  	}
   399  
   400  	fd, err := unix.Socket(domain, typ, 0)
   401  	if err != nil {
   402  		return 0, -1, err
   403  	}
   404  
   405  	for _, opt := range opts {
   406  		if err := unix.SetsockoptInt(fd, opt.level, opt.opt, opt.value); err != nil {
   407  			return 0, -1, fmt.Errorf("SetsockoptInt: %w", err)
   408  		}
   409  	}
   410  
   411  	var sa unix.Sockaddr
   412  
   413  	ip := net.ParseIP(ipStr)
   414  
   415  	if ip.To4() != nil {
   416  		sa4 := &unix.SockaddrInet4{Port: port}
   417  		copy(sa4.Addr[:], ip.To4())
   418  		sa = sa4
   419  	} else if ip.To16() != nil {
   420  		sa6 := &unix.SockaddrInet6{Port: port}
   421  		copy(sa6.Addr[:], ip.To16())
   422  		sa = sa6
   423  	} else {
   424  		return 0, -1, fmt.Errorf("invalid IP address")
   425  	}
   426  
   427  	if err := unix.Bind(fd, sa); err != nil {
   428  		return 0, -1, fmt.Errorf("Bind: %w", err)
   429  	}
   430  
   431  	sa2, err := unix.Getsockname(fd)
   432  	if err != nil {
   433  		return 0, fd, fmt.Errorf("Getsockname: %w", err)
   434  	}
   435  
   436  	if ip.To4() != nil {
   437  		return uint16(sa2.(*unix.SockaddrInet4).Port), fd, nil
   438  	} else if ip.To16() != nil {
   439  		return uint16(sa2.(*unix.SockaddrInet6).Port), fd, nil
   440  	} else {
   441  		return 0, fd, fmt.Errorf("invalid IP address")
   442  	}
   443  }