github.com/hugh712/snapd@v0.0.0-20200910133618-1a99902bd583/osutil/udev/netlink/rawsockstop.go (about)

     1  package netlink
     2  
     3  import (
     4  	"fmt"
     5  	"math/bits"
     6  	"os"
     7  	"syscall"
     8  )
     9  
    10  // RawSockStopper returns a pair of functions to manage stopping code
    11  // reading from a raw socket, readableOrStop blocks until
    12  // fd is readable or stop was called. To work properly it sets fd
    13  // to non-blocking mode.
    14  // TODO: with go 1.11+ it should be possible to just switch to setting
    15  // fd to non-blocking and then wrapping the socket via os.NewFile and
    16  // use Close to force a read to stop.
    17  // c.f. https://github.com/golang/go/commit/ea5825b0b64e1a017a76eac0ad734e11ff557c8e
    18  func RawSockStopper(fd int) (readableOrStop func() (bool, error), stop func(), err error) {
    19  	if err := syscall.SetNonblock(fd, true); err != nil {
    20  		return nil, nil, err
    21  	}
    22  
    23  	stopR, stopW, err := os.Pipe()
    24  	if err != nil {
    25  		return nil, nil, err
    26  	}
    27  
    28  	// both stopR and stopW must be kept alive otherwise the corresponding
    29  	// file descriptors will get closed
    30  	readableOrStop = func() (bool, error) {
    31  		return stopperSelectReadable(fd, int(stopR.Fd()))
    32  	}
    33  	stop = func() {
    34  		stopW.Write([]byte{0})
    35  	}
    36  	return readableOrStop, stop, nil
    37  }
    38  
    39  func stopperSelectReadable(fd, stopFd int) (bool, error) {
    40  	maxFd := fd
    41  	if maxFd < stopFd {
    42  		maxFd = stopFd
    43  	}
    44  	if maxFd >= 1024 {
    45  		return false, fmt.Errorf("fd too high for syscall.Select")
    46  	}
    47  	fdIdx := fd / bits.UintSize
    48  	fdShift := uint(fd) % bits.UintSize
    49  	stopFdIdx := stopFd / bits.UintSize
    50  	stopFdShift := uint(stopFd) % bits.UintSize
    51  	readable := false
    52  	tout := stopperSelectTimeout()
    53  	for {
    54  		var r syscall.FdSet
    55  		r.Bits[fdIdx] = 1 << fdShift
    56  		r.Bits[stopFdIdx] |= 1 << stopFdShift
    57  		_, err := syscall.Select(maxFd+1, &r, nil, nil, tout)
    58  		if errno, ok := err.(syscall.Errno); ok && errno.Temporary() {
    59  			continue
    60  		}
    61  		if err != nil {
    62  			return false, err
    63  		}
    64  		readable = (r.Bits[fdIdx] & (1 << fdShift)) != 0
    65  		break
    66  	}
    67  	return readable, nil
    68  }