github.com/kubeshark/ebpf@v0.9.2/internal/epoll/poller.go (about)

     1  package epoll
     2  
     3  import (
     4  	"fmt"
     5  	"math"
     6  	"os"
     7  	"runtime"
     8  	"sync"
     9  
    10  	"github.com/kubeshark/ebpf/internal"
    11  	"github.com/kubeshark/ebpf/internal/unix"
    12  )
    13  
// Poller waits for readiness notifications from multiple file descriptors.
//
// The wait can be interrupted by calling Close.
type Poller struct {
	// mutexes protect the fields declared below them. If you need to
	// acquire both at once you must lock epollMu before eventMu.

	// epollMu is held by Add, Wait and Close while they use epollFd.
	epollMu sync.Mutex
	// epollFd is the epoll file descriptor. Close sets it to -1, which Add
	// and Wait treat as "poller closed".
	epollFd int

	// eventMu guards event.
	eventMu sync.Mutex
	// event is the eventfd registered with epollFd so that Close can
	// interrupt a blocked Wait. Close sets it to nil.
	event   *eventFd
}
    26  
    27  func New() (*Poller, error) {
    28  	epollFd, err := unix.EpollCreate1(unix.EPOLL_CLOEXEC)
    29  	if err != nil {
    30  		return nil, fmt.Errorf("create epoll fd: %v", err)
    31  	}
    32  
    33  	p := &Poller{epollFd: epollFd}
    34  	p.event, err = newEventFd()
    35  	if err != nil {
    36  		unix.Close(epollFd)
    37  		return nil, err
    38  	}
    39  
    40  	if err := p.Add(p.event.raw, 0); err != nil {
    41  		unix.Close(epollFd)
    42  		p.event.close()
    43  		return nil, fmt.Errorf("add eventfd: %w", err)
    44  	}
    45  
    46  	runtime.SetFinalizer(p, (*Poller).Close)
    47  	return p, nil
    48  }
    49  
    50  // Close the poller.
    51  //
    52  // Interrupts any calls to Wait. Multiple calls to Close are valid, but subsequent
    53  // calls will return os.ErrClosed.
    54  func (p *Poller) Close() error {
    55  	runtime.SetFinalizer(p, nil)
    56  
    57  	// Interrupt Wait() via the event fd if it's currently blocked.
    58  	if err := p.wakeWait(); err != nil {
    59  		return err
    60  	}
    61  
    62  	// Acquire the lock. This ensures that Wait isn't running.
    63  	p.epollMu.Lock()
    64  	defer p.epollMu.Unlock()
    65  
    66  	// Prevent other calls to Close().
    67  	p.eventMu.Lock()
    68  	defer p.eventMu.Unlock()
    69  
    70  	if p.epollFd != -1 {
    71  		unix.Close(p.epollFd)
    72  		p.epollFd = -1
    73  	}
    74  
    75  	if p.event != nil {
    76  		p.event.close()
    77  		p.event = nil
    78  	}
    79  
    80  	return nil
    81  }
    82  
    83  // Add an fd to the poller.
    84  //
    85  // id is returned by Wait in the unix.EpollEvent.Pad field any may be zero. It
    86  // must not exceed math.MaxInt32.
    87  //
    88  // Add is blocked by Wait.
    89  func (p *Poller) Add(fd int, id int) error {
    90  	if int64(id) > math.MaxInt32 {
    91  		return fmt.Errorf("unsupported id: %d", id)
    92  	}
    93  
    94  	p.epollMu.Lock()
    95  	defer p.epollMu.Unlock()
    96  
    97  	if p.epollFd == -1 {
    98  		return fmt.Errorf("epoll add: %w", os.ErrClosed)
    99  	}
   100  
   101  	// The representation of EpollEvent isn't entirely accurate.
   102  	// Pad is fully useable, not just padding. Hence we stuff the
   103  	// id in there, which allows us to identify the event later (e.g.,
   104  	// in case of perf events, which CPU sent it).
   105  	event := unix.EpollEvent{
   106  		Events: unix.EPOLLIN,
   107  		Fd:     int32(fd),
   108  		Pad:    int32(id),
   109  	}
   110  
   111  	if err := unix.EpollCtl(p.epollFd, unix.EPOLL_CTL_ADD, fd, &event); err != nil {
   112  		return fmt.Errorf("add fd to epoll: %v", err)
   113  	}
   114  
   115  	return nil
   116  }
   117  
   118  // Wait for events.
   119  //
   120  // Returns the number of pending events or an error wrapping os.ErrClosed if
   121  // Close is called.
   122  func (p *Poller) Wait(events []unix.EpollEvent) (int, error) {
   123  	p.epollMu.Lock()
   124  	defer p.epollMu.Unlock()
   125  
   126  	if p.epollFd == -1 {
   127  		return 0, fmt.Errorf("epoll wait: %w", os.ErrClosed)
   128  	}
   129  
   130  	for {
   131  		n, err := unix.EpollWait(p.epollFd, events, -1)
   132  		if temp, ok := err.(temporaryError); ok && temp.Temporary() {
   133  			// Retry the syscall if we were interrupted, see https://github.com/golang/go/issues/20400
   134  			continue
   135  		}
   136  
   137  		if err != nil {
   138  			return 0, err
   139  		}
   140  
   141  		for _, event := range events[:n] {
   142  			if int(event.Fd) == p.event.raw {
   143  				// Since we don't read p.event the event is never cleared and
   144  				// we'll keep getting this wakeup until Close() acquires the
   145  				// lock and sets p.epollFd = -1.
   146  				return 0, fmt.Errorf("epoll wait: %w", os.ErrClosed)
   147  			}
   148  		}
   149  
   150  		return n, nil
   151  	}
   152  }
   153  
// temporaryError is implemented by errors that can report whether they are
// transient. Wait uses it to decide whether a failed epoll_wait should be
// retried.
type temporaryError interface {
	Temporary() bool
}
   157  
   158  // waitWait unblocks Wait if it's epoll_wait.
   159  func (p *Poller) wakeWait() error {
   160  	p.eventMu.Lock()
   161  	defer p.eventMu.Unlock()
   162  
   163  	if p.event == nil {
   164  		return fmt.Errorf("epoll wake: %w", os.ErrClosed)
   165  	}
   166  
   167  	return p.event.add(1)
   168  }
   169  
// eventFd wraps a Linux eventfd.
//
// An eventfd acts like a counter: writes add to the counter, reads retrieve
// the counter and reset it to zero. Reads also block if the counter is zero.
//
// See man 2 eventfd.
type eventFd struct {
	// file wraps the same descriptor as raw and is used for reading,
	// writing and closing the eventfd.
	file *os.File
	// prefer raw over file.Fd(), since the latter puts the file into blocking
	// mode.
	raw int
}
   182  
   183  func newEventFd() (*eventFd, error) {
   184  	fd, err := unix.Eventfd(0, unix.O_CLOEXEC|unix.O_NONBLOCK)
   185  	if err != nil {
   186  		return nil, err
   187  	}
   188  	file := os.NewFile(uintptr(fd), "event")
   189  	return &eventFd{file, fd}, nil
   190  }
   191  
// close closes the file wrapping the eventfd and returns the error from
// that Close call, if any.
func (efd *eventFd) close() error {
	return efd.file.Close()
}
   195  
   196  func (efd *eventFd) add(n uint64) error {
   197  	var buf [8]byte
   198  	internal.NativeEndian.PutUint64(buf[:], 1)
   199  	_, err := efd.file.Write(buf[:])
   200  	return err
   201  }
   202  
   203  func (efd *eventFd) read() (uint64, error) {
   204  	var buf [8]byte
   205  	_, err := efd.file.Read(buf[:])
   206  	return internal.NativeEndian.Uint64(buf[:]), err
   207  }