github.com/nicocha30/gvisor-ligolo@v0.0.0-20230726075806-989fa2c0a413/pkg/sentry/syscalls/linux/sys_poll.go

// Copyright 2020 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package linux

import (
	"fmt"
	"time"

	"github.com/nicocha30/gvisor-ligolo/pkg/abi/linux"
	"github.com/nicocha30/gvisor-ligolo/pkg/errors/linuxerr"
	"github.com/nicocha30/gvisor-ligolo/pkg/hostarch"
	"github.com/nicocha30/gvisor-ligolo/pkg/sentry/arch"
	"github.com/nicocha30/gvisor-ligolo/pkg/sentry/kernel"
	ktime "github.com/nicocha30/gvisor-ligolo/pkg/sentry/kernel/time"
	"github.com/nicocha30/gvisor-ligolo/pkg/sentry/limits"
	"github.com/nicocha30/gvisor-ligolo/pkg/sentry/vfs"
	"github.com/nicocha30/gvisor-ligolo/pkg/waiter"
)

// fileCap is the maximum allowable files for poll & select. This has no
// equivalent in Linux; it exists in gVisor since allocation failure in Go is
// unrecoverable.
const fileCap = 1024 * 1024

// Masks for "readable", "writable", and "exceptional" events as defined by
// select(2).
const (
	// selectReadEvents is analogous to the Linux kernel's
	// fs/select.c:POLLIN_SET.
	selectReadEvents = linux.POLLIN | linux.POLLHUP | linux.POLLERR

	// selectWriteEvents is analogous to the Linux kernel's
	// fs/select.c:POLLOUT_SET.
	selectWriteEvents = linux.POLLOUT | linux.POLLERR

	// selectExceptEvents is analogous to the Linux kernel's
	// fs/select.c:POLLEX_SET.
	selectExceptEvents = linux.POLLPRI
)

// pollState tracks the associated file description and waiter of a PollFD.
type pollState struct {
	file   *vfs.FileDescription
	waiter waiter.Entry
}

// initReadiness gets the current ready mask for the file represented by the FD
// stored in pfd.FD. If a channel is passed in, the waiter entry in "state" is
// used to register with the file for event notifications, and a reference to
// the file is stored in "state".
func initReadiness(t *kernel.Task, pfd *linux.PollFD, state *pollState, ch chan struct{}) error {
	if pfd.FD < 0 {
		pfd.REvents = 0
		return nil
	}

	file := t.GetFile(pfd.FD)
	if file == nil {
		pfd.REvents = linux.POLLNVAL
		return nil
	}

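	// Without a channel the caller only needs a readiness snapshot, so the
	// file reference can be dropped immediately. Otherwise keep the reference
	// and register the waiter; releaseState unregisters and releases it later.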
	if ch == nil {
		defer file.DecRef(t)
	} else {
		state.file = file
		state.waiter.Init(waiter.ChannelNotifier(ch), waiter.EventMaskFromLinux(uint32(pfd.Events)))
		if err := file.EventRegister(&state.waiter); err != nil {
			return err
		}
	}

	r := file.Readiness(waiter.EventMaskFromLinux(uint32(pfd.Events)))
	pfd.REvents = int16(r.ToLinux()) & pfd.Events
	return nil
}

// releaseState releases all the pollState in "state".
func releaseState(t *kernel.Task, state []pollState) {
	for i := range state {
		if state[i].file != nil {
			state[i].file.EventUnregister(&state[i].waiter)
			state[i].file.DecRef(t)
		}
	}
}

// pollBlock polls the PollFDs in "pfd", blocking for at most "timeout" when
// "timeout" is positive, not blocking at all when it is zero, and blocking
// indefinitely when it is negative.
//
// pollBlock returns the remaining timeout (always 0 if the timeout expired,
// and 0 or positive if interrupted by a signal), the number of ready PollFDs,
// and an error, if any.
func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time.Duration, uintptr, error) {
	var ch chan struct{}
	if timeout != 0 {
		ch = make(chan struct{}, 1)
	}

	// Register for event notification in the files involved if we may
	// block (timeout not zero). Once we find a file that has a non-zero
	// result, we stop registering for events but still go through all files
	// to get their ready masks.
	state := make([]pollState, len(pfd))
	defer releaseState(t, state)
	n := uintptr(0)
	for i := range pfd {
		if err := initReadiness(t, &pfd[i], &state[i], ch); err != nil {
			return timeout, 0, err
		}
		if pfd[i].REvents != 0 {
			n++
			ch = nil
		}
	}

	if timeout == 0 {
		return timeout, n, nil
	}

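	// A negative timeout means "no deadline": haveTimeout is then false and
	// BlockWithTimeout below blocks until a notification or a signal arrives.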
	haveTimeout := timeout >= 0

	for n == 0 {
		var err error
		// Wait for a notification.
		timeout, err = t.BlockWithTimeout(ch, haveTimeout, timeout)
		if err != nil {
			if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
				err = nil
			}
			return timeout, 0, err
		}

		// We got notified, count how many files are ready. If none,
		// then this was a spurious notification, and we just go back
		// to sleep with the remaining timeout.
		for i := range state {
			if state[i].file == nil {
				continue
			}

			r := state[i].file.Readiness(waiter.EventMaskFromLinux(uint32(pfd[i].Events)))
			rl := int16(r.ToLinux()) & pfd[i].Events
			if rl != 0 {
				pfd[i].REvents = rl
				n++
			}
		}
	}

	return timeout, n, nil
}

// CopyInPollFDs copies an array of struct pollfd unless nfds exceeds the max.
func CopyInPollFDs(t *kernel.Task, addr hostarch.Addr, nfds uint) ([]linux.PollFD, error) {
	if uint64(nfds) > t.ThreadGroup().Limits().GetCapped(limits.NumberOfFiles, fileCap) {
		return nil, linuxerr.EINVAL
	}

	pfd := make([]linux.PollFD, nfds)
	if nfds > 0 {
		if _, err := linux.CopyPollFDSliceIn(t, addr, pfd); err != nil {
			return nil, err
		}
	}

	return pfd, nil
}

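// doPoll carries out the common work of poll(2) and ppoll(2): it copies in the
// pollfd array, polls it via pollBlock, and copies the results back out.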
func doPoll(t *kernel.Task, addr hostarch.Addr, nfds uint, timeout time.Duration) (time.Duration, uintptr, error) {
	pfd, err := CopyInPollFDs(t, addr, nfds)
	if err != nil {
		return timeout, 0, err
	}

	// Compatibility warning: Linux adds POLLHUP and POLLERR just before
	// polling, in fs/select.c:do_pollfd(). Since pfd is copied out after
	// polling, changing event masks here is an application-visible difference.
	// (Linux also doesn't copy out event masks at all, only revents.)
	for i := range pfd {
		pfd[i].Events |= linux.POLLHUP | linux.POLLERR
	}
	remainingTimeout, n, err := pollBlock(t, pfd, timeout)
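	// ConvertIntr maps the sentry's internal interruption error to EINTR; poll
	// and ppoll then decide separately whether the call should be restarted.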
	err = linuxerr.ConvertIntr(err, linuxerr.EINTR)

	// The poll entries are copied out regardless of whether
	// any are set or not. This aligns with the Linux behavior.
	if nfds > 0 && err == nil {
		if _, err := linux.CopyPollFDSliceOut(t, addr, pfd); err != nil {
			return remainingTimeout, 0, err
		}
	}

	return remainingTimeout, n, err
}

// CopyInFDSet copies an fd set from select(2)/pselect(2).
func CopyInFDSet(t *kernel.Task, addr hostarch.Addr, nBytes, nBitsInLastPartialByte int) ([]byte, error) {
	set := make([]byte, nBytes)

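	// A null fd_set pointer is treated as an empty set; "set" stays all zeros.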
	if addr != 0 {
		if _, err := t.CopyInBytes(addr, set); err != nil {
			return nil, err
		}
		// If we only use part of the last byte, mask out the extraneous bits.
		//
		// N.B. This only works on little-endian architectures.
		if nBitsInLastPartialByte != 0 {
			set[nBytes-1] &^= byte(0xff) << nBitsInLastPartialByte
		}
	}
	return set, nil
}

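// doSelect carries out the common work of select(2) and pselect6(2): it copies
// in the three fd sets, polls the requested FDs via pollBlock, rewrites the
// sets to contain only the ready FDs, and copies them back out.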
func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs hostarch.Addr, timeout time.Duration) (uintptr, error) {
	if nfds < 0 || nfds > fileCap {
		return 0, linuxerr.EINVAL
	}

	// Calculate the size of the fd sets (one bit per fd).
	nBytes := (nfds + 7) / 8
	nBitsInLastPartialByte := nfds % 8
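	// For example, nfds = 10 gives nBytes = 2 with 2 bits used in the last byte.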

	// Capture all the provided input vectors.
	r, err := CopyInFDSet(t, readFDs, nBytes, nBitsInLastPartialByte)
	if err != nil {
		return 0, err
	}
	w, err := CopyInFDSet(t, writeFDs, nBytes, nBitsInLastPartialByte)
	if err != nil {
		return 0, err
	}
	e, err := CopyInFDSet(t, exceptFDs, nBytes, nBitsInLastPartialByte)
	if err != nil {
		return 0, err
	}

	// Count how many FDs are actually being requested so that we can build
	// a PollFD array.
	fdCount := 0
	for i := 0; i < nBytes; i++ {
		v := r[i] | w[i] | e[i]
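		// Clearing the lowest set bit on each iteration (v &= v-1) counts the
		// set bits without scanning every bit (Kernighan's method).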
		for v != 0 {
			v &= (v - 1)
			fdCount++
		}
	}

	// Build the PollFD array.
	pfd := make([]linux.PollFD, 0, fdCount)
	var fd int32
	for i := 0; i < nBytes; i++ {
		rV, wV, eV := r[i], w[i], e[i]
		v := rV | wV | eV
		m := byte(1)
		for j := 0; j < 8; j++ {
			if (v & m) != 0 {
				// Make sure the fd is valid and decrement the reference
				// immediately to ensure we don't leak. Note, another thread
				// might be about to close fd. This is racy, but that's
				// OK. Linux is racy in the same way.
				file := t.GetFile(fd)
				if file == nil {
					return 0, linuxerr.EBADF
				}
				file.DecRef(t)

				var mask int16
				if (rV & m) != 0 {
					mask |= selectReadEvents
				}

				if (wV & m) != 0 {
					mask |= selectWriteEvents
				}

				if (eV & m) != 0 {
					mask |= selectExceptEvents
				}

				pfd = append(pfd, linux.PollFD{
					FD:     fd,
					Events: mask,
				})
			}

			fd++
			m <<= 1
		}
	}

	// Do the syscall, then count the number of bits set.
	if _, _, err = pollBlock(t, pfd, timeout); err != nil {
		return 0, linuxerr.ConvertIntr(err, linuxerr.EINTR)
	}

	// r, w, and e are currently event mask bitsets; unset bits corresponding
	// to events that *didn't* occur.
	bitSetCount := uintptr(0)
	for idx := range pfd {
		events := pfd[idx].REvents
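		// Locate this FD's byte index and bit position within the sets.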
		i, j := pfd[idx].FD/8, uint(pfd[idx].FD%8)
		m := byte(1) << j
		if r[i]&m != 0 {
			if (events & selectReadEvents) != 0 {
				bitSetCount++
			} else {
				r[i] &^= m
			}
		}
		if w[i]&m != 0 {
			if (events & selectWriteEvents) != 0 {
				bitSetCount++
			} else {
				w[i] &^= m
			}
		}
		if e[i]&m != 0 {
			if (events & selectExceptEvents) != 0 {
				bitSetCount++
			} else {
				e[i] &^= m
			}
		}
	}

	// Copy updated vectors back.
	if readFDs != 0 {
		if _, err := t.CopyOutBytes(readFDs, r); err != nil {
			return 0, err
		}
	}

	if writeFDs != 0 {
		if _, err := t.CopyOutBytes(writeFDs, w); err != nil {
			return 0, err
		}
	}

	if exceptFDs != 0 {
		if _, err := t.CopyOutBytes(exceptFDs, e); err != nil {
			return 0, err
		}
	}

	return bitSetCount, nil
}

// timeoutRemaining returns the amount of time remaining for the specified
// timeout or 0 if it has elapsed.
//
// startNs must be from CLOCK_MONOTONIC.
func timeoutRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration) time.Duration {
	now := t.Kernel().MonotonicClock().Now()
	remaining := timeout - now.Sub(startNs)
	if remaining < 0 {
		remaining = 0
	}
	return remaining
}

// copyOutTimespecRemaining copies the time remaining in timeout to timespecAddr.
//
// startNs must be from CLOCK_MONOTONIC.
func copyOutTimespecRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timespecAddr hostarch.Addr) error {
	if timeout <= 0 {
		return nil
	}
	remaining := timeoutRemaining(t, startNs, timeout)
	tsRemaining := linux.NsecToTimespec(remaining.Nanoseconds())
	_, err := tsRemaining.CopyOut(t, timespecAddr)
	return err
}

// copyOutTimevalRemaining copies the time remaining in timeout to timevalAddr.
//
// startNs must be from CLOCK_MONOTONIC.
func copyOutTimevalRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timevalAddr hostarch.Addr) error {
	if timeout <= 0 {
		return nil
	}
	remaining := timeoutRemaining(t, startNs, timeout)
	tvRemaining := linux.NsecToTimeval(remaining.Nanoseconds())
	_, err := tvRemaining.CopyOut(t, timevalAddr)
	return err
}

// pollRestartBlock encapsulates the state required to restart poll(2) via
// restart_syscall(2).
//
// +stateify savable
type pollRestartBlock struct {
	pfdAddr hostarch.Addr
	nfds    uint
	timeout time.Duration
}

// Restart implements kernel.SyscallRestartBlock.Restart.
func (p *pollRestartBlock) Restart(t *kernel.Task) (uintptr, error) {
	return poll(t, p.pfdAddr, p.nfds, p.timeout)
}

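// poll implements the body of poll(2), arranging for an interrupted call to be
// restarted via restart_syscall(2) with the remaining timeout.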
func poll(t *kernel.Task, pfdAddr hostarch.Addr, nfds uint, timeout time.Duration) (uintptr, error) {
	remainingTimeout, n, err := doPoll(t, pfdAddr, nfds, timeout)
	// On an interrupt poll(2) is restarted with the remaining timeout.
	if linuxerr.Equals(linuxerr.EINTR, err) {
		t.SetSyscallRestartBlock(&pollRestartBlock{
			pfdAddr: pfdAddr,
			nfds:    nfds,
			timeout: remainingTimeout,
		})
		return 0, linuxerr.ERESTART_RESTARTBLOCK
	}
	return n, err
}

// Poll implements linux syscall poll(2).
func Poll(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
	pfdAddr := args[0].Pointer()
	nfds := uint(args[1].Uint()) // poll(2) uses unsigned long.
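	// poll(2) takes the timeout in milliseconds; a negative value means no
	// timeout, which pollBlock treats as blocking indefinitely.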
	timeout := time.Duration(args[2].Int()) * time.Millisecond
	n, err := poll(t, pfdAddr, nfds, timeout)
	return n, nil, err
}

// Ppoll implements linux syscall ppoll(2).
func Ppoll(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
	pfdAddr := args[0].Pointer()
	nfds := uint(args[1].Uint()) // poll(2) uses unsigned long.
	timespecAddr := args[2].Pointer()
	maskAddr := args[3].Pointer()
	maskSize := uint(args[4].Uint())

	timeout, err := copyTimespecInToDuration(t, timespecAddr)
	if err != nil {
		return 0, nil, err
	}

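	// The start time is only needed for a finite timeout, since it is used
	// solely to compute the remaining time to copy back to userspace.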
	var startNs ktime.Time
	if timeout > 0 {
		startNs = t.Kernel().MonotonicClock().Now()
	}

	if err := setTempSignalSet(t, maskAddr, maskSize); err != nil {
		return 0, nil, err
	}

	_, n, err := doPoll(t, pfdAddr, nfds, timeout)
	copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr)
	// doPoll returns EINTR if interrupted, but ppoll is normally restartable
	// if interrupted by something other than a signal handled by the
	// application (i.e. returns ERESTARTNOHAND). However, if
	// copyOutTimespecRemaining failed, then the restarted ppoll would use the
	// wrong timeout, so the error should be left as EINTR.
	//
	// Note that this means that if err is nil but copyErr is not, copyErr is
	// ignored. This is consistent with Linux.
	if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil {
		err = linuxerr.ERESTARTNOHAND
	}
	return n, nil, err
}

// Select implements linux syscall select(2).
func Select(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
	nfds := int(args[0].Int()) // select(2) uses an int.
	readFDs := args[1].Pointer()
	writeFDs := args[2].Pointer()
	exceptFDs := args[3].Pointer()
	timevalAddr := args[4].Pointer()

	// Use a negative Duration to indicate "no timeout".
	timeout := time.Duration(-1)
	if timevalAddr != 0 {
		var timeval linux.Timeval
		if _, err := timeval.CopyIn(t, timevalAddr); err != nil {
			return 0, nil, err
		}
		if timeval.Sec < 0 || timeval.Usec < 0 {
			return 0, nil, linuxerr.EINVAL
		}
		timeout = time.Duration(timeval.ToNsecCapped())
	}
	startNs := t.Kernel().MonotonicClock().Now()
	n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout)
	copyErr := copyOutTimevalRemaining(t, startNs, timeout, timevalAddr)
	// See comment in Ppoll.
	if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil {
		err = linuxerr.ERESTARTNOHAND
	}
	return n, nil, err
}

// +marshal
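// sigSetWithSize mirrors the (sigset pointer, sigset size) pair that
// pselect6(2) passes as its sixth argument.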
type sigSetWithSize struct {
	sigsetAddr   uint64
	sizeofSigset uint64
}

// Pselect6 implements linux syscall pselect6(2).
func Pselect6(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
	nfds := int(args[0].Int()) // select(2) uses an int.
	readFDs := args[1].Pointer()
	writeFDs := args[2].Pointer()
	exceptFDs := args[3].Pointer()
	timespecAddr := args[4].Pointer()
	maskWithSizeAddr := args[5].Pointer()

	timeout, err := copyTimespecInToDuration(t, timespecAddr)
	if err != nil {
		return 0, nil, err
	}

	var startNs ktime.Time
	if timeout > 0 {
		startNs = t.Kernel().MonotonicClock().Now()
	}

	if maskWithSizeAddr != 0 {
		if t.Arch().Width() != 8 {
			panic(fmt.Sprintf("unsupported sizeof(void*): %d", t.Arch().Width()))
		}
		var maskStruct sigSetWithSize
		if _, err := maskStruct.CopyIn(t, maskWithSizeAddr); err != nil {
			return 0, nil, err
		}
		if err := setTempSignalSet(t, hostarch.Addr(maskStruct.sigsetAddr), uint(maskStruct.sizeofSigset)); err != nil {
			return 0, nil, err
		}
	}

	n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout)
	copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr)
	// See comment in Ppoll.
	if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil {
		err = linuxerr.ERESTARTNOHAND
	}
	return n, nil, err
}

func setTempSignalSet(t *kernel.Task, maskAddr hostarch.Addr, maskSize uint) error {
	if maskAddr == 0 {
		return nil
	}
	if maskSize != linux.SignalSetSize {
		return linuxerr.EINVAL
	}
	var mask linux.SignalSet
	if _, err := mask.CopyIn(t, maskAddr); err != nil {
		return err
	}
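	// Signals that can never be blocked (SIGKILL and SIGSTOP) are cleared from
	// the temporary mask; the previous mask is saved so it can be restored
	// once the syscall completes.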
	mask &^= kernel.UnblockableSignals
	oldmask := t.SignalMask()
	t.SetSignalMask(mask)
	t.SetSavedSignalMask(oldmask)
	return nil
}