github.com/inspektor-gadget/inspektor-gadget@v0.28.1/pkg/gadgets/top/tcp/tracer/tracer.go (about)

     1  // Copyright 2023 The Inspektor Gadget authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  //go:build !withoutebpf
    16  
    17  package tracer
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"fmt"
    23  	"time"
    24  	"unsafe"
    25  
    26  	"github.com/cilium/ebpf"
    27  	"github.com/cilium/ebpf/link"
    28  
    29  	"github.com/inspektor-gadget/inspektor-gadget/pkg/columns"
    30  	"github.com/inspektor-gadget/inspektor-gadget/pkg/gadgets"
    31  	"github.com/inspektor-gadget/inspektor-gadget/pkg/gadgets/top"
    32  	"github.com/inspektor-gadget/inspektor-gadget/pkg/gadgets/top/tcp/types"
    33  	eventtypes "github.com/inspektor-gadget/inspektor-gadget/pkg/types"
    34  )
    35  
    36  //go:generate go run github.com/cilium/ebpf/cmd/bpf2go -no-global-types -target $TARGET -type ip_key_t -type traffic_t -cc clang -cflags ${CFLAGS} tcptop ./bpf/tcptop.bpf.c -- -I./bpf/
    37  
    38  type Config struct {
    39  	MountnsMap   *ebpf.Map
    40  	TargetPid    int32
    41  	TargetFamily int32
    42  	MaxRows      int
    43  	Interval     time.Duration
    44  	Iterations   int
    45  	SortBy       []string
    46  }
    47  
    48  type Tracer struct {
    49  	config             *Config
    50  	objs               tcptopObjects
    51  	tcpSendmsgLink     link.Link
    52  	tcpCleanupRbufLink link.Link
    53  	enricher           gadgets.DataEnricherByMntNs
    54  	eventCallback      func(*top.Event[types.Stats])
    55  	done               chan bool
    56  	colMap             columns.ColumnMap[types.Stats]
    57  }
    58  
    59  func NewTracer(config *Config, enricher gadgets.DataEnricherByMntNs,
    60  	eventCallback func(*top.Event[types.Stats]),
    61  ) (*Tracer, error) {
    62  	t := &Tracer{
    63  		config:        config,
    64  		enricher:      enricher,
    65  		eventCallback: eventCallback,
    66  		done:          make(chan bool),
    67  	}
    68  
    69  	if err := t.install(); err != nil {
    70  		t.close()
    71  		return nil, err
    72  	}
    73  
    74  	statCols, err := columns.NewColumns[types.Stats]()
    75  	if err != nil {
    76  		t.close()
    77  		return nil, err
    78  	}
    79  	t.colMap = statCols.GetColumnMap()
    80  
    81  	go t.run(context.TODO())
    82  
    83  	return t, nil
    84  }
    85  
    86  // Stop stops the tracer
    87  // TODO: Remove after refactoring
    88  func (t *Tracer) Stop() {
    89  	t.close()
    90  }
    91  
    92  func (t *Tracer) close() {
    93  	close(t.done)
    94  
    95  	t.tcpSendmsgLink = gadgets.CloseLink(t.tcpSendmsgLink)
    96  	t.tcpCleanupRbufLink = gadgets.CloseLink(t.tcpCleanupRbufLink)
    97  
    98  	t.objs.Close()
    99  }
   100  
   101  func (t *Tracer) install() error {
   102  	spec, err := loadTcptop()
   103  	if err != nil {
   104  		return fmt.Errorf("loading ebpf program: %w", err)
   105  	}
   106  
   107  	consts := map[string]interface{}{
   108  		"target_pid":    t.config.TargetPid,
   109  		"target_family": t.config.TargetFamily,
   110  	}
   111  
   112  	if err := gadgets.LoadeBPFSpec(t.config.MountnsMap, spec, consts, &t.objs); err != nil {
   113  		return fmt.Errorf("loading ebpf spec: %w", err)
   114  	}
   115  
   116  	t.tcpSendmsgLink, err = link.Kprobe("tcp_sendmsg", t.objs.IgToptcpSdmsg, nil)
   117  	if err != nil {
   118  		return fmt.Errorf("attaching kprobe: %w", err)
   119  	}
   120  
   121  	t.tcpCleanupRbufLink, err = link.Kprobe("tcp_cleanup_rbuf", t.objs.IgToptcpClean, nil)
   122  	if err != nil {
   123  		return fmt.Errorf("attaching kprobe: %w", err)
   124  	}
   125  
   126  	return nil
   127  }
   128  
   129  func (t *Tracer) nextStats() ([]*types.Stats, error) {
   130  	stats := []*types.Stats{}
   131  
   132  	var prev *tcptopIpKeyT = nil
   133  	key := tcptopIpKeyT{}
   134  	ips := t.objs.IpMap
   135  
   136  	defer func() {
   137  		// delete elements
   138  		err := ips.NextKey(nil, unsafe.Pointer(&key))
   139  		if err != nil {
   140  			return
   141  		}
   142  
   143  		for {
   144  			if err := ips.Delete(key); err != nil {
   145  				return
   146  			}
   147  
   148  			prev = &key
   149  			if err := ips.NextKey(unsafe.Pointer(prev), unsafe.Pointer(&key)); err != nil {
   150  				return
   151  			}
   152  		}
   153  	}()
   154  
   155  	// gather elements
   156  	err := ips.NextKey(nil, unsafe.Pointer(&key))
   157  	if err != nil {
   158  		if errors.Is(err, ebpf.ErrKeyNotExist) {
   159  			return stats, nil
   160  		}
   161  		return nil, fmt.Errorf("getting next key: %w", err)
   162  	}
   163  
   164  	for {
   165  		val := tcptopTrafficT{}
   166  		if err := ips.Lookup(key, unsafe.Pointer(&val)); err != nil {
   167  			return nil, err
   168  		}
   169  
   170  		ipversion := gadgets.IPVerFromAF(key.Family)
   171  
   172  		stat := types.Stats{
   173  			WithMountNsID: eventtypes.WithMountNsID{MountNsID: key.Mntnsid},
   174  			Pid:           int32(key.Pid),
   175  			Comm:          gadgets.FromCString(key.Name[:]),
   176  			SrcEndpoint: eventtypes.L4Endpoint{
   177  				L3Endpoint: eventtypes.L3Endpoint{
   178  					Addr:    gadgets.IPStringFromBytes(key.Saddr, ipversion),
   179  					Version: uint8(ipversion),
   180  				},
   181  				Port: key.Lport,
   182  			},
   183  			DstEndpoint: eventtypes.L4Endpoint{
   184  				L3Endpoint: eventtypes.L3Endpoint{
   185  					Addr:    gadgets.IPStringFromBytes(key.Daddr, ipversion),
   186  					Version: uint8(ipversion),
   187  				},
   188  				Port: key.Dport,
   189  			},
   190  			IPVersion: ipversion,
   191  			Sent:      val.Sent,
   192  			Received:  val.Received,
   193  		}
   194  
   195  		if t.enricher != nil {
   196  			t.enricher.EnrichByMntNs(&stat.CommonData, stat.MountNsID)
   197  		}
   198  
   199  		stats = append(stats, &stat)
   200  
   201  		prev = &key
   202  		if err := ips.NextKey(unsafe.Pointer(prev), unsafe.Pointer(&key)); err != nil {
   203  			if errors.Is(err, ebpf.ErrKeyNotExist) {
   204  				break
   205  			}
   206  			return nil, fmt.Errorf("getting next key: %w", err)
   207  		}
   208  	}
   209  
   210  	top.SortStats(stats, t.config.SortBy, &t.colMap)
   211  
   212  	return stats, nil
   213  }
   214  
   215  func (t *Tracer) run(ctx context.Context) error {
   216  	// Don't use a context with a timeout but a counter to avoid having to deal
   217  	// with two timers: one for the timeout and another for the ticker.
   218  	count := t.config.Iterations
   219  	ticker := time.NewTicker(t.config.Interval)
   220  	defer ticker.Stop()
   221  
   222  	for {
   223  		select {
   224  		case <-t.done:
   225  			// TODO: Once we completely move to use Run instead of NewTracer,
   226  			// we can remove this as nobody will directly call Stop (cleanup).
   227  			return nil
   228  		case <-ctx.Done():
   229  			return nil
   230  		case <-ticker.C:
   231  			stats, err := t.nextStats()
   232  			if err != nil {
   233  				return fmt.Errorf("getting next stats: %w", err)
   234  			}
   235  
   236  			n := len(stats)
   237  			if n > t.config.MaxRows {
   238  				n = t.config.MaxRows
   239  			}
   240  			t.eventCallback(&top.Event[types.Stats]{Stats: stats[:n]})
   241  
   242  			// Count down only if user requested a finite number of iterations
   243  			// through a timeout.
   244  			if t.config.Iterations > 0 {
   245  				count--
   246  				if count == 0 {
   247  					return nil
   248  				}
   249  			}
   250  		}
   251  	}
   252  }
   253  
   254  func (t *Tracer) Run(gadgetCtx gadgets.GadgetContext) error {
   255  	if err := t.init(gadgetCtx); err != nil {
   256  		return fmt.Errorf("initializing tracer: %w", err)
   257  	}
   258  
   259  	defer t.close()
   260  	if err := t.install(); err != nil {
   261  		return fmt.Errorf("installing tracer: %w", err)
   262  	}
   263  
   264  	return t.run(gadgetCtx.Context())
   265  }
   266  
   267  func (t *Tracer) SetEventHandlerArray(handler any) {
   268  	nh, ok := handler.(func(ev []*types.Stats))
   269  	if !ok {
   270  		panic("event handler invalid")
   271  	}
   272  
   273  	// TODO: add errorHandler
   274  	t.eventCallback = func(ev *top.Event[types.Stats]) {
   275  		if ev.Error != "" {
   276  			return
   277  		}
   278  		nh(ev.Stats)
   279  	}
   280  }
   281  
   282  func (t *Tracer) SetMountNsMap(mntnsMap *ebpf.Map) {
   283  	t.config.MountnsMap = mntnsMap
   284  }
   285  
   286  func (g *GadgetDesc) NewInstance() (gadgets.Gadget, error) {
   287  	tracer := &Tracer{
   288  		config: &Config{
   289  			TargetFamily: -1,
   290  			TargetPid:    -1,
   291  		},
   292  		done: make(chan bool),
   293  	}
   294  	return tracer, nil
   295  }
   296  
   297  func (t *Tracer) init(gadgetCtx gadgets.GadgetContext) error {
   298  	params := gadgetCtx.GadgetParams()
   299  	t.config.MaxRows = params.Get(gadgets.ParamMaxRows).AsInt()
   300  	t.config.SortBy = params.Get(gadgets.ParamSortBy).AsStringSlice()
   301  	t.config.Interval = time.Second * time.Duration(params.Get(gadgets.ParamInterval).AsInt())
   302  	t.config.TargetFamily, _ = types.ParseFilterByFamily(params.Get(types.FamilyParam).AsString())
   303  	t.config.TargetPid = params.Get(types.PidParam).AsInt32()
   304  
   305  	var err error
   306  	if t.config.Iterations, err = top.ComputeIterations(t.config.Interval, gadgetCtx.Timeout()); err != nil {
   307  		return err
   308  	}
   309  
   310  	statCols, err := columns.NewColumns[types.Stats]()
   311  	if err != nil {
   312  		return err
   313  	}
   314  	t.colMap = statCols.GetColumnMap()
   315  
   316  	return nil
   317  }