github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/sentry/syscalls/linux/sys_poll.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package linux
    16  
    17  import (
    18  	"time"
    19  
    20  	"github.com/SagerNet/gvisor/pkg/abi/linux"
    21  	"github.com/SagerNet/gvisor/pkg/errors/linuxerr"
    22  	"github.com/SagerNet/gvisor/pkg/hostarch"
    23  	"github.com/SagerNet/gvisor/pkg/sentry/arch"
    24  	"github.com/SagerNet/gvisor/pkg/sentry/fs"
    25  	"github.com/SagerNet/gvisor/pkg/sentry/kernel"
    26  	ktime "github.com/SagerNet/gvisor/pkg/sentry/kernel/time"
    27  	"github.com/SagerNet/gvisor/pkg/sentry/limits"
    28  	"github.com/SagerNet/gvisor/pkg/syserror"
    29  	"github.com/SagerNet/gvisor/pkg/waiter"
    30  )
    31  
    32  // fileCap is the maximum allowable files for poll & select.
    33  const fileCap = 1024 * 1024
    34  
// Masks for "readable", "writable", and "exceptional" events as defined by
// select(2).
//
// doSelect requests the matching mask for every FD present in the
// corresponding input set, and keeps the FD's bit set on return only if
// polling reported at least one event from that mask.
const (
	// selectReadEvents is analogous to the Linux kernel's
	// fs/select.c:POLLIN_SET. It is the set of poll events that mark an FD
	// in the read set as ready.
	selectReadEvents = linux.POLLIN | linux.POLLHUP | linux.POLLERR

	// selectWriteEvents is analogous to the Linux kernel's
	// fs/select.c:POLLOUT_SET. It is the set of poll events that mark an FD
	// in the write set as ready.
	selectWriteEvents = linux.POLLOUT | linux.POLLERR

	// selectExceptEvents is analogous to the Linux kernel's
	// fs/select.c:POLLEX_SET. It is the set of poll events that mark an FD
	// in the except set as ready.
	selectExceptEvents = linux.POLLPRI
)
    50  
// pollState tracks the associated file descriptor and waiter of a PollFD.
//
// A pollState is only populated by initReadiness when the caller may block;
// otherwise it stays at its zero value and releaseState skips it.
type pollState struct {
	// file is the file being polled. When non-nil, the pollState holds a
	// reference on it that releaseState drops.
	file *fs.File
	// waiter is the entry registered with file for event notifications; it
	// is unregistered by releaseState.
	waiter waiter.Entry
}
    56  
    57  // initReadiness gets the current ready mask for the file represented by the FD
    58  // stored in pfd.FD. If a channel is passed in, the waiter entry in "state" is
    59  // used to register with the file for event notifications, and a reference to
    60  // the file is stored in "state".
    61  func initReadiness(t *kernel.Task, pfd *linux.PollFD, state *pollState, ch chan struct{}) {
    62  	if pfd.FD < 0 {
    63  		pfd.REvents = 0
    64  		return
    65  	}
    66  
    67  	file := t.GetFile(pfd.FD)
    68  	if file == nil {
    69  		pfd.REvents = linux.POLLNVAL
    70  		return
    71  	}
    72  
    73  	if ch == nil {
    74  		defer file.DecRef(t)
    75  	} else {
    76  		state.file = file
    77  		state.waiter, _ = waiter.NewChannelEntry(ch)
    78  		file.EventRegister(&state.waiter, waiter.EventMaskFromLinux(uint32(pfd.Events)))
    79  	}
    80  
    81  	r := file.Readiness(waiter.EventMaskFromLinux(uint32(pfd.Events)))
    82  	pfd.REvents = int16(r.ToLinux()) & pfd.Events
    83  }
    84  
    85  // releaseState releases all the pollState in "state".
    86  func releaseState(t *kernel.Task, state []pollState) {
    87  	for i := range state {
    88  		if state[i].file != nil {
    89  			state[i].file.EventUnregister(&state[i].waiter)
    90  			state[i].file.DecRef(t)
    91  		}
    92  	}
    93  }
    94  
    95  // pollBlock polls the PollFDs in "pfd" with a bounded time specified in "timeout"
    96  // when "timeout" is greater than zero.
    97  //
    98  // pollBlock returns the remaining timeout, which is always 0 on a timeout; and 0 or
    99  // positive if interrupted by a signal.
   100  func pollBlock(t *kernel.Task, pfd []linux.PollFD, timeout time.Duration) (time.Duration, uintptr, error) {
   101  	var ch chan struct{}
   102  	if timeout != 0 {
   103  		ch = make(chan struct{}, 1)
   104  	}
   105  
   106  	// Register for event notification in the files involved if we may
   107  	// block (timeout not zero). Once we find a file that has a non-zero
   108  	// result, we stop registering for events but still go through all files
   109  	// to get their ready masks.
   110  	state := make([]pollState, len(pfd))
   111  	defer releaseState(t, state)
   112  	n := uintptr(0)
   113  	for i := range pfd {
   114  		initReadiness(t, &pfd[i], &state[i], ch)
   115  		if pfd[i].REvents != 0 {
   116  			n++
   117  			ch = nil
   118  		}
   119  	}
   120  
   121  	if timeout == 0 {
   122  		return timeout, n, nil
   123  	}
   124  
   125  	forever := timeout < 0
   126  
   127  	for n == 0 {
   128  		var err error
   129  		// Wait for a notification.
   130  		timeout, err = t.BlockWithTimeout(ch, !forever, timeout)
   131  		if err != nil {
   132  			if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
   133  				err = nil
   134  			}
   135  			return timeout, 0, err
   136  		}
   137  
   138  		// We got notified, count how many files are ready. If none,
   139  		// then this was a spurious notification, and we just go back
   140  		// to sleep with the remaining timeout.
   141  		for i := range state {
   142  			if state[i].file == nil {
   143  				continue
   144  			}
   145  
   146  			r := state[i].file.Readiness(waiter.EventMaskFromLinux(uint32(pfd[i].Events)))
   147  			rl := int16(r.ToLinux()) & pfd[i].Events
   148  			if rl != 0 {
   149  				pfd[i].REvents = rl
   150  				n++
   151  			}
   152  		}
   153  	}
   154  
   155  	return timeout, n, nil
   156  }
   157  
   158  // CopyInPollFDs copies an array of struct pollfd unless nfds exceeds the max.
   159  func CopyInPollFDs(t *kernel.Task, addr hostarch.Addr, nfds uint) ([]linux.PollFD, error) {
   160  	if uint64(nfds) > t.ThreadGroup().Limits().GetCapped(limits.NumberOfFiles, fileCap) {
   161  		return nil, linuxerr.EINVAL
   162  	}
   163  
   164  	pfd := make([]linux.PollFD, nfds)
   165  	if nfds > 0 {
   166  		if _, err := linux.CopyPollFDSliceIn(t, addr, pfd); err != nil {
   167  			return nil, err
   168  		}
   169  	}
   170  
   171  	return pfd, nil
   172  }
   173  
   174  func doPoll(t *kernel.Task, addr hostarch.Addr, nfds uint, timeout time.Duration) (time.Duration, uintptr, error) {
   175  	pfd, err := CopyInPollFDs(t, addr, nfds)
   176  	if err != nil {
   177  		return timeout, 0, err
   178  	}
   179  
   180  	// Compatibility warning: Linux adds POLLHUP and POLLERR just before
   181  	// polling, in fs/select.c:do_pollfd(). Since pfd is copied out after
   182  	// polling, changing event masks here is an application-visible difference.
   183  	// (Linux also doesn't copy out event masks at all, only revents.)
   184  	for i := range pfd {
   185  		pfd[i].Events |= linux.POLLHUP | linux.POLLERR
   186  	}
   187  	remainingTimeout, n, err := pollBlock(t, pfd, timeout)
   188  	err = syserror.ConvertIntr(err, syserror.EINTR)
   189  
   190  	// The poll entries are copied out regardless of whether
   191  	// any are set or not. This aligns with the Linux behavior.
   192  	if nfds > 0 && err == nil {
   193  		if _, err := linux.CopyPollFDSliceOut(t, addr, pfd); err != nil {
   194  			return remainingTimeout, 0, err
   195  		}
   196  	}
   197  
   198  	return remainingTimeout, n, err
   199  }
   200  
   201  // CopyInFDSet copies an fd set from select(2)/pselect(2).
   202  func CopyInFDSet(t *kernel.Task, addr hostarch.Addr, nBytes, nBitsInLastPartialByte int) ([]byte, error) {
   203  	set := make([]byte, nBytes)
   204  
   205  	if addr != 0 {
   206  		if _, err := t.CopyInBytes(addr, set); err != nil {
   207  			return nil, err
   208  		}
   209  		// If we only use part of the last byte, mask out the extraneous bits.
   210  		//
   211  		// N.B. This only works on little-endian architectures.
   212  		if nBitsInLastPartialByte != 0 {
   213  			set[nBytes-1] &^= byte(0xff) << nBitsInLastPartialByte
   214  		}
   215  	}
   216  	return set, nil
   217  }
   218  
   219  func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs hostarch.Addr, timeout time.Duration) (uintptr, error) {
   220  	if nfds < 0 || nfds > fileCap {
   221  		return 0, linuxerr.EINVAL
   222  	}
   223  
   224  	// Calculate the size of the fd sets (one bit per fd).
   225  	nBytes := (nfds + 7) / 8
   226  	nBitsInLastPartialByte := nfds % 8
   227  
   228  	// Capture all the provided input vectors.
   229  	r, err := CopyInFDSet(t, readFDs, nBytes, nBitsInLastPartialByte)
   230  	if err != nil {
   231  		return 0, err
   232  	}
   233  	w, err := CopyInFDSet(t, writeFDs, nBytes, nBitsInLastPartialByte)
   234  	if err != nil {
   235  		return 0, err
   236  	}
   237  	e, err := CopyInFDSet(t, exceptFDs, nBytes, nBitsInLastPartialByte)
   238  	if err != nil {
   239  		return 0, err
   240  	}
   241  
   242  	// Count how many FDs are actually being requested so that we can build
   243  	// a PollFD array.
   244  	fdCount := 0
   245  	for i := 0; i < nBytes; i++ {
   246  		v := r[i] | w[i] | e[i]
   247  		for v != 0 {
   248  			v &= (v - 1)
   249  			fdCount++
   250  		}
   251  	}
   252  
   253  	// Build the PollFD array.
   254  	pfd := make([]linux.PollFD, 0, fdCount)
   255  	var fd int32
   256  	for i := 0; i < nBytes; i++ {
   257  		rV, wV, eV := r[i], w[i], e[i]
   258  		v := rV | wV | eV
   259  		m := byte(1)
   260  		for j := 0; j < 8; j++ {
   261  			if (v & m) != 0 {
   262  				// Make sure the fd is valid and decrement the reference
   263  				// immediately to ensure we don't leak. Note, another thread
   264  				// might be about to close fd. This is racy, but that's
   265  				// OK. Linux is racy in the same way.
   266  				file := t.GetFile(fd)
   267  				if file == nil {
   268  					return 0, linuxerr.EBADF
   269  				}
   270  				file.DecRef(t)
   271  
   272  				var mask int16
   273  				if (rV & m) != 0 {
   274  					mask |= selectReadEvents
   275  				}
   276  
   277  				if (wV & m) != 0 {
   278  					mask |= selectWriteEvents
   279  				}
   280  
   281  				if (eV & m) != 0 {
   282  					mask |= selectExceptEvents
   283  				}
   284  
   285  				pfd = append(pfd, linux.PollFD{
   286  					FD:     fd,
   287  					Events: mask,
   288  				})
   289  			}
   290  
   291  			fd++
   292  			m <<= 1
   293  		}
   294  	}
   295  
   296  	// Do the syscall, then count the number of bits set.
   297  	if _, _, err = pollBlock(t, pfd, timeout); err != nil {
   298  		return 0, syserror.ConvertIntr(err, syserror.EINTR)
   299  	}
   300  
   301  	// r, w, and e are currently event mask bitsets; unset bits corresponding
   302  	// to events that *didn't* occur.
   303  	bitSetCount := uintptr(0)
   304  	for idx := range pfd {
   305  		events := pfd[idx].REvents
   306  		i, j := pfd[idx].FD/8, uint(pfd[idx].FD%8)
   307  		m := byte(1) << j
   308  		if r[i]&m != 0 {
   309  			if (events & selectReadEvents) != 0 {
   310  				bitSetCount++
   311  			} else {
   312  				r[i] &^= m
   313  			}
   314  		}
   315  		if w[i]&m != 0 {
   316  			if (events & selectWriteEvents) != 0 {
   317  				bitSetCount++
   318  			} else {
   319  				w[i] &^= m
   320  			}
   321  		}
   322  		if e[i]&m != 0 {
   323  			if (events & selectExceptEvents) != 0 {
   324  				bitSetCount++
   325  			} else {
   326  				e[i] &^= m
   327  			}
   328  		}
   329  	}
   330  
   331  	// Copy updated vectors back.
   332  	if readFDs != 0 {
   333  		if _, err := t.CopyOutBytes(readFDs, r); err != nil {
   334  			return 0, err
   335  		}
   336  	}
   337  
   338  	if writeFDs != 0 {
   339  		if _, err := t.CopyOutBytes(writeFDs, w); err != nil {
   340  			return 0, err
   341  		}
   342  	}
   343  
   344  	if exceptFDs != 0 {
   345  		if _, err := t.CopyOutBytes(exceptFDs, e); err != nil {
   346  			return 0, err
   347  		}
   348  	}
   349  
   350  	return bitSetCount, nil
   351  }
   352  
   353  // timeoutRemaining returns the amount of time remaining for the specified
   354  // timeout or 0 if it has elapsed.
   355  //
   356  // startNs must be from CLOCK_MONOTONIC.
   357  func timeoutRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration) time.Duration {
   358  	now := t.Kernel().MonotonicClock().Now()
   359  	remaining := timeout - now.Sub(startNs)
   360  	if remaining < 0 {
   361  		remaining = 0
   362  	}
   363  	return remaining
   364  }
   365  
   366  // copyOutTimespecRemaining copies the time remaining in timeout to timespecAddr.
   367  //
   368  // startNs must be from CLOCK_MONOTONIC.
   369  func copyOutTimespecRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timespecAddr hostarch.Addr) error {
   370  	if timeout <= 0 {
   371  		return nil
   372  	}
   373  	remaining := timeoutRemaining(t, startNs, timeout)
   374  	tsRemaining := linux.NsecToTimespec(remaining.Nanoseconds())
   375  	return copyTimespecOut(t, timespecAddr, &tsRemaining)
   376  }
   377  
   378  // copyOutTimevalRemaining copies the time remaining in timeout to timevalAddr.
   379  //
   380  // startNs must be from CLOCK_MONOTONIC.
   381  func copyOutTimevalRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timevalAddr hostarch.Addr) error {
   382  	if timeout <= 0 {
   383  		return nil
   384  	}
   385  	remaining := timeoutRemaining(t, startNs, timeout)
   386  	tvRemaining := linux.NsecToTimeval(remaining.Nanoseconds())
   387  	return copyTimevalOut(t, timevalAddr, &tvRemaining)
   388  }
   389  
// pollRestartBlock encapsulates the state required to restart poll(2) via
// restart_syscall(2).
//
// It is installed by poll when the call is interrupted, and its Restart
// method re-issues the call with the saved arguments.
//
// +stateify savable
type pollRestartBlock struct {
	// pfdAddr is the user address of the pollfd array.
	pfdAddr hostarch.Addr
	// nfds is the number of entries in the pollfd array.
	nfds uint
	// timeout is the time that was still remaining when the call was
	// interrupted.
	timeout time.Duration
}
   399  
   400  // Restart implements kernel.SyscallRestartBlock.Restart.
   401  func (p *pollRestartBlock) Restart(t *kernel.Task) (uintptr, error) {
   402  	return poll(t, p.pfdAddr, p.nfds, p.timeout)
   403  }
   404  
   405  func poll(t *kernel.Task, pfdAddr hostarch.Addr, nfds uint, timeout time.Duration) (uintptr, error) {
   406  	remainingTimeout, n, err := doPoll(t, pfdAddr, nfds, timeout)
   407  	// On an interrupt poll(2) is restarted with the remaining timeout.
   408  	if linuxerr.Equals(linuxerr.EINTR, err) {
   409  		t.SetSyscallRestartBlock(&pollRestartBlock{
   410  			pfdAddr: pfdAddr,
   411  			nfds:    nfds,
   412  			timeout: remainingTimeout,
   413  		})
   414  		return 0, syserror.ERESTART_RESTARTBLOCK
   415  	}
   416  	return n, err
   417  }
   418  
   419  // Poll implements linux syscall poll(2).
   420  func Poll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
   421  	pfdAddr := args[0].Pointer()
   422  	nfds := uint(args[1].Uint()) // poll(2) uses unsigned long.
   423  	timeout := time.Duration(args[2].Int()) * time.Millisecond
   424  	n, err := poll(t, pfdAddr, nfds, timeout)
   425  	return n, nil, err
   426  }
   427  
   428  // Ppoll implements linux syscall ppoll(2).
   429  func Ppoll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
   430  	pfdAddr := args[0].Pointer()
   431  	nfds := uint(args[1].Uint()) // poll(2) uses unsigned long.
   432  	timespecAddr := args[2].Pointer()
   433  	maskAddr := args[3].Pointer()
   434  	maskSize := uint(args[4].Uint())
   435  
   436  	timeout, err := copyTimespecInToDuration(t, timespecAddr)
   437  	if err != nil {
   438  		return 0, nil, err
   439  	}
   440  
   441  	var startNs ktime.Time
   442  	if timeout > 0 {
   443  		startNs = t.Kernel().MonotonicClock().Now()
   444  	}
   445  
   446  	if maskAddr != 0 {
   447  		mask, err := CopyInSigSet(t, maskAddr, maskSize)
   448  		if err != nil {
   449  			return 0, nil, err
   450  		}
   451  
   452  		oldmask := t.SignalMask()
   453  		t.SetSignalMask(mask)
   454  		t.SetSavedSignalMask(oldmask)
   455  	}
   456  
   457  	_, n, err := doPoll(t, pfdAddr, nfds, timeout)
   458  	copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr)
   459  	// doPoll returns EINTR if interrupted, but ppoll is normally restartable
   460  	// if interrupted by something other than a signal handled by the
   461  	// application (i.e. returns ERESTARTNOHAND). However, if
   462  	// copyOutTimespecRemaining failed, then the restarted ppoll would use the
   463  	// wrong timeout, so the error should be left as EINTR.
   464  	//
   465  	// Note that this means that if err is nil but copyErr is not, copyErr is
   466  	// ignored. This is consistent with Linux.
   467  	if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil {
   468  		err = syserror.ERESTARTNOHAND
   469  	}
   470  	return n, nil, err
   471  }
   472  
   473  // Select implements linux syscall select(2).
   474  func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
   475  	nfds := int(args[0].Int()) // select(2) uses an int.
   476  	readFDs := args[1].Pointer()
   477  	writeFDs := args[2].Pointer()
   478  	exceptFDs := args[3].Pointer()
   479  	timevalAddr := args[4].Pointer()
   480  
   481  	// Use a negative Duration to indicate "no timeout".
   482  	timeout := time.Duration(-1)
   483  	if timevalAddr != 0 {
   484  		timeval, err := copyTimevalIn(t, timevalAddr)
   485  		if err != nil {
   486  			return 0, nil, err
   487  		}
   488  		if timeval.Sec < 0 || timeval.Usec < 0 {
   489  			return 0, nil, linuxerr.EINVAL
   490  		}
   491  		timeout = time.Duration(timeval.ToNsecCapped())
   492  	}
   493  	startNs := t.Kernel().MonotonicClock().Now()
   494  	n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout)
   495  	copyErr := copyOutTimevalRemaining(t, startNs, timeout, timevalAddr)
   496  	// See comment in Ppoll.
   497  	if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil {
   498  		err = syserror.ERESTARTNOHAND
   499  	}
   500  	return n, nil, err
   501  }
   502  
   503  // Pselect implements linux syscall pselect(2).
   504  func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
   505  	nfds := int(args[0].Int()) // select(2) uses an int.
   506  	readFDs := args[1].Pointer()
   507  	writeFDs := args[2].Pointer()
   508  	exceptFDs := args[3].Pointer()
   509  	timespecAddr := args[4].Pointer()
   510  	maskWithSizeAddr := args[5].Pointer()
   511  
   512  	timeout, err := copyTimespecInToDuration(t, timespecAddr)
   513  	if err != nil {
   514  		return 0, nil, err
   515  	}
   516  
   517  	var startNs ktime.Time
   518  	if timeout > 0 {
   519  		startNs = t.Kernel().MonotonicClock().Now()
   520  	}
   521  
   522  	if maskWithSizeAddr != 0 {
   523  		maskAddr, size, err := copyInSigSetWithSize(t, maskWithSizeAddr)
   524  		if err != nil {
   525  			return 0, nil, err
   526  		}
   527  
   528  		if maskAddr != 0 {
   529  			mask, err := CopyInSigSet(t, maskAddr, size)
   530  			if err != nil {
   531  				return 0, nil, err
   532  			}
   533  			oldmask := t.SignalMask()
   534  			t.SetSignalMask(mask)
   535  			t.SetSavedSignalMask(oldmask)
   536  		}
   537  	}
   538  
   539  	n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout)
   540  	copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr)
   541  	// See comment in Ppoll.
   542  	if linuxerr.Equals(linuxerr.EINTR, err) && copyErr == nil {
   543  		err = syserror.ERESTARTNOHAND
   544  	}
   545  	return n, nil, err
   546  }