github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/sentry/syscalls/linux/sys_mempolicy.go (about)

     1  // Copyright 2019 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package linux
    16  
    17  import (
    18  	"fmt"
    19  
    20  	"github.com/SagerNet/gvisor/pkg/abi/linux"
    21  	"github.com/SagerNet/gvisor/pkg/errors/linuxerr"
    22  	"github.com/SagerNet/gvisor/pkg/hostarch"
    23  	"github.com/SagerNet/gvisor/pkg/sentry/arch"
    24  	"github.com/SagerNet/gvisor/pkg/sentry/kernel"
    25  	"github.com/SagerNet/gvisor/pkg/syserror"
    26  	"github.com/SagerNet/gvisor/pkg/usermem"
    27  )
    28  
// We unconditionally report a single NUMA node. This also means that our
// "nodemask_t" is a single unsigned long (uint64).
const (
	// maxNodes is the number of NUMA nodes the sentry reports to
	// applications.
	maxNodes = 1
	// allowedNodemask is a nodemask with a bit set for each of the maxNodes
	// reported nodes (i.e. only bit 0).
	allowedNodemask = (1 << maxNodes) - 1
)
    35  
    36  func copyInNodemask(t *kernel.Task, addr hostarch.Addr, maxnode uint32) (uint64, error) {
    37  	// "nodemask points to a bit mask of node IDs that contains up to maxnode
    38  	// bits. The bit mask size is rounded to the next multiple of
    39  	// sizeof(unsigned long), but the kernel will use bits only up to maxnode.
    40  	// A NULL value of nodemask or a maxnode value of zero specifies the empty
    41  	// set of nodes. If the value of maxnode is zero, the nodemask argument is
    42  	// ignored." - set_mempolicy(2). Unfortunately, most of this is inaccurate
    43  	// because of what appears to be a bug: mm/mempolicy.c:get_nodes() uses
    44  	// maxnode-1, not maxnode, as the number of bits.
    45  	bits := maxnode - 1
    46  	if bits > hostarch.PageSize*8 { // also handles overflow from maxnode == 0
    47  		return 0, linuxerr.EINVAL
    48  	}
    49  	if bits == 0 {
    50  		return 0, nil
    51  	}
    52  	// Copy in the whole nodemask.
    53  	numUint64 := (bits + 63) / 64
    54  	buf := t.CopyScratchBuffer(int(numUint64) * 8)
    55  	if _, err := t.CopyInBytes(addr, buf); err != nil {
    56  		return 0, err
    57  	}
    58  	val := hostarch.ByteOrder.Uint64(buf)
    59  	// Check that only allowed bits in the first unsigned long in the nodemask
    60  	// are set.
    61  	if val&^allowedNodemask != 0 {
    62  		return 0, linuxerr.EINVAL
    63  	}
    64  	// Check that all remaining bits in the nodemask are 0.
    65  	for i := 8; i < len(buf); i++ {
    66  		if buf[i] != 0 {
    67  			return 0, linuxerr.EINVAL
    68  		}
    69  	}
    70  	return val, nil
    71  }
    72  
    73  func copyOutNodemask(t *kernel.Task, addr hostarch.Addr, maxnode uint32, val uint64) error {
    74  	// mm/mempolicy.c:copy_nodes_to_user() also uses maxnode-1 as the number of
    75  	// bits.
    76  	bits := maxnode - 1
    77  	if bits > hostarch.PageSize*8 { // also handles overflow from maxnode == 0
    78  		return linuxerr.EINVAL
    79  	}
    80  	if bits == 0 {
    81  		return nil
    82  	}
    83  	// Copy out the first unsigned long in the nodemask.
    84  	buf := t.CopyScratchBuffer(8)
    85  	hostarch.ByteOrder.PutUint64(buf, val)
    86  	if _, err := t.CopyOutBytes(addr, buf); err != nil {
    87  		return err
    88  	}
    89  	// Zero out remaining unsigned longs in the nodemask.
    90  	if bits > 64 {
    91  		remAddr, ok := addr.AddLength(8)
    92  		if !ok {
    93  			return syserror.EFAULT
    94  		}
    95  		remUint64 := (bits - 1) / 64
    96  		if _, err := t.MemoryManager().ZeroOut(t, remAddr, int64(remUint64)*8, usermem.IOOpts{
    97  			AddressSpaceActive: true,
    98  		}); err != nil {
    99  			return err
   100  		}
   101  	}
   102  	return nil
   103  }
   104  
// GetMempolicy implements the syscall get_mempolicy(2).
//
// Since the sentry reports exactly one NUMA node (node 0 with the default
// policy), this returns fixed answers; the flag validation and error ordering
// below deliberately track Linux's mm/mempolicy.c:do_get_mempolicy().
func GetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
	mode := args[0].Pointer()
	nodemask := args[1].Pointer()
	maxnode := args[2].Uint()
	addr := args[3].Pointer()
	flags := args[4].Uint()

	// Only the three documented flags are accepted.
	if flags&^(linux.MPOL_F_NODE|linux.MPOL_F_ADDR|linux.MPOL_F_MEMS_ALLOWED) != 0 {
		return 0, nil, linuxerr.EINVAL
	}
	nodeFlag := flags&linux.MPOL_F_NODE != 0
	addrFlag := flags&linux.MPOL_F_ADDR != 0
	memsAllowed := flags&linux.MPOL_F_MEMS_ALLOWED != 0

	// "EINVAL: The value specified by maxnode is less than the number of node
	// IDs supported by the system." - get_mempolicy(2)
	if nodemask != 0 && maxnode < maxNodes {
		return 0, nil, linuxerr.EINVAL
	}

	// "If flags specifies MPOL_F_MEMS_ALLOWED [...], the mode argument is
	// ignored and the set of nodes (memories) that the thread is allowed to
	// specify in subsequent calls to mbind(2) or set_mempolicy(2) (in the
	// absence of any mode flags) is returned in nodemask."
	if memsAllowed {
		// "It is not permitted to combine MPOL_F_MEMS_ALLOWED with either
		// MPOL_F_ADDR or MPOL_F_NODE."
		if nodeFlag || addrFlag {
			return 0, nil, linuxerr.EINVAL
		}
		if err := copyOutNodemask(t, nodemask, maxnode, allowedNodemask); err != nil {
			return 0, nil, err
		}
		return 0, nil, nil
	}

	// "If flags specifies MPOL_F_ADDR, then information is returned about the
	// policy governing the memory address given in addr. ... If the mode
	// argument is not NULL, then get_mempolicy() will store the policy mode
	// and any optional mode flags of the requested NUMA policy in the location
	// pointed to by this argument. If nodemask is not NULL, then the nodemask
	// associated with the policy will be stored in the location pointed to by
	// this argument."
	if addrFlag {
		policy, nodemaskVal, err := t.MemoryManager().NumaPolicy(addr)
		if err != nil {
			return 0, nil, err
		}
		if nodeFlag {
			// "If flags specifies both MPOL_F_NODE and MPOL_F_ADDR,
			// get_mempolicy() will return the node ID of the node on which the
			// address addr is allocated into the location pointed to by mode.
			// If no page has yet been allocated for the specified address,
			// get_mempolicy() will allocate a page as if the thread had
			// performed a read (load) access to that address, and return the
			// ID of the node where that page was allocated."
			//
			// Reading a single byte here forces the page to be allocated, as
			// the quoted text requires.
			buf := t.CopyScratchBuffer(1)
			_, err := t.CopyInBytes(addr, buf)
			if err != nil {
				return 0, nil, err
			}
			policy = linux.MPOL_DEFAULT // maxNodes == 1
		}
		if mode != 0 {
			if _, err := policy.CopyOut(t, mode); err != nil {
				return 0, nil, err
			}
		}
		if nodemask != 0 {
			if err := copyOutNodemask(t, nodemask, maxnode, nodemaskVal); err != nil {
				return 0, nil, err
			}
		}
		return 0, nil, nil
	}

	// "EINVAL: ... flags specified MPOL_F_ADDR and addr is NULL, or flags did
	// not specify MPOL_F_ADDR and addr is not NULL." This is partially
	// inaccurate: if flags specifies MPOL_F_ADDR,
	// mm/mempolicy.c:do_get_mempolicy() doesn't special-case NULL; it will
	// just (usually) fail to find a VMA at address 0 and return EFAULT.
	if addr != 0 {
		return 0, nil, linuxerr.EINVAL
	}

	// "If flags is specified as 0, then information about the calling thread's
	// default policy (as set by set_mempolicy(2)) is returned, in the buffers
	// pointed to by mode and nodemask. ... If flags specifies MPOL_F_NODE, but
	// not MPOL_F_ADDR, and the thread's current policy is MPOL_INTERLEAVE,
	// then get_mempolicy() will return in the location pointed to by a
	// non-NULL mode argument, the node ID of the next node that will be used
	// for interleaving of internal kernel pages allocated on behalf of the
	// thread."
	policy, nodemaskVal := t.NumaPolicy()
	if nodeFlag {
		// MPOL_F_NODE without MPOL_F_ADDR is only meaningful for
		// MPOL_INTERLEAVE; mask off the mode flags before comparing.
		if policy&^linux.MPOL_MODE_FLAGS != linux.MPOL_INTERLEAVE {
			return 0, nil, linuxerr.EINVAL
		}
		policy = linux.MPOL_DEFAULT // maxNodes == 1
	}
	if mode != 0 {
		if _, err := policy.CopyOut(t, mode); err != nil {
			return 0, nil, err
		}
	}
	if nodemask != 0 {
		if err := copyOutNodemask(t, nodemask, maxnode, nodemaskVal); err != nil {
			return 0, nil, err
		}
	}
	return 0, nil, nil
}
   218  
   219  // SetMempolicy implements the syscall set_mempolicy(2).
   220  func SetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
   221  	modeWithFlags := linux.NumaPolicy(args[0].Int())
   222  	nodemask := args[1].Pointer()
   223  	maxnode := args[2].Uint()
   224  
   225  	modeWithFlags, nodemaskVal, err := copyInMempolicyNodemask(t, modeWithFlags, nodemask, maxnode)
   226  	if err != nil {
   227  		return 0, nil, err
   228  	}
   229  
   230  	t.SetNumaPolicy(modeWithFlags, nodemaskVal)
   231  	return 0, nil, nil
   232  }
   233  
   234  // Mbind implements the syscall mbind(2).
   235  func Mbind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
   236  	addr := args[0].Pointer()
   237  	length := args[1].Uint64()
   238  	mode := linux.NumaPolicy(args[2].Int())
   239  	nodemask := args[3].Pointer()
   240  	maxnode := args[4].Uint()
   241  	flags := args[5].Uint()
   242  
   243  	if flags&^linux.MPOL_MF_VALID != 0 {
   244  		return 0, nil, linuxerr.EINVAL
   245  	}
   246  	// "If MPOL_MF_MOVE_ALL is passed in flags ... [the] calling thread must be
   247  	// privileged (CAP_SYS_NICE) to use this flag." - mbind(2)
   248  	if flags&linux.MPOL_MF_MOVE_ALL != 0 && !t.HasCapability(linux.CAP_SYS_NICE) {
   249  		return 0, nil, linuxerr.EPERM
   250  	}
   251  
   252  	mode, nodemaskVal, err := copyInMempolicyNodemask(t, mode, nodemask, maxnode)
   253  	if err != nil {
   254  		return 0, nil, err
   255  	}
   256  
   257  	// Since we claim to have only a single node, all flags can be ignored
   258  	// (since all pages must already be on that single node).
   259  	err = t.MemoryManager().SetNumaPolicy(addr, length, mode, nodemaskVal)
   260  	return 0, nil, err
   261  }
   262  
   263  func copyInMempolicyNodemask(t *kernel.Task, modeWithFlags linux.NumaPolicy, nodemask hostarch.Addr, maxnode uint32) (linux.NumaPolicy, uint64, error) {
   264  	flags := linux.NumaPolicy(modeWithFlags & linux.MPOL_MODE_FLAGS)
   265  	mode := linux.NumaPolicy(modeWithFlags &^ linux.MPOL_MODE_FLAGS)
   266  	if flags == linux.MPOL_MODE_FLAGS {
   267  		// Can't specify both mode flags simultaneously.
   268  		return 0, 0, linuxerr.EINVAL
   269  	}
   270  	if mode < 0 || mode >= linux.MPOL_MAX {
   271  		// Must specify a valid mode.
   272  		return 0, 0, linuxerr.EINVAL
   273  	}
   274  
   275  	var nodemaskVal uint64
   276  	if nodemask != 0 {
   277  		var err error
   278  		nodemaskVal, err = copyInNodemask(t, nodemask, maxnode)
   279  		if err != nil {
   280  			return 0, 0, err
   281  		}
   282  	}
   283  
   284  	switch mode {
   285  	case linux.MPOL_DEFAULT:
   286  		// "nodemask must be specified as NULL." - set_mempolicy(2). This is inaccurate;
   287  		// Linux allows a nodemask to be specified, as long as it is empty.
   288  		if nodemaskVal != 0 {
   289  			return 0, 0, linuxerr.EINVAL
   290  		}
   291  	case linux.MPOL_BIND, linux.MPOL_INTERLEAVE:
   292  		// These require a non-empty nodemask.
   293  		if nodemaskVal == 0 {
   294  			return 0, 0, linuxerr.EINVAL
   295  		}
   296  	case linux.MPOL_PREFERRED:
   297  		// This permits an empty nodemask, as long as no flags are set.
   298  		if nodemaskVal == 0 && flags != 0 {
   299  			return 0, 0, linuxerr.EINVAL
   300  		}
   301  	case linux.MPOL_LOCAL:
   302  		// This requires an empty nodemask and no flags set ...
   303  		if nodemaskVal != 0 || flags != 0 {
   304  			return 0, 0, linuxerr.EINVAL
   305  		}
   306  		// ... and is implemented as MPOL_PREFERRED.
   307  		mode = linux.MPOL_PREFERRED
   308  	default:
   309  		// Unknown mode, which we should have rejected above.
   310  		panic(fmt.Sprintf("unknown mode: %v", mode))
   311  	}
   312  
   313  	return mode | flags, nodemaskVal, nil
   314  }