github.com/MerlinKodo/gvisor@v0.0.0-20231110090155-957f62ecf90e/pkg/sentry/syscalls/linux/sys_mempolicy.go

// Copyright 2019 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package linux

import (
	"fmt"

	"github.com/MerlinKodo/gvisor/pkg/abi/linux"
	"github.com/MerlinKodo/gvisor/pkg/errors/linuxerr"
	"github.com/MerlinKodo/gvisor/pkg/hostarch"
	"github.com/MerlinKodo/gvisor/pkg/sentry/arch"
	"github.com/MerlinKodo/gvisor/pkg/sentry/kernel"
	"github.com/MerlinKodo/gvisor/pkg/usermem"
)

// We unconditionally report a single NUMA node. This also means that our
// "nodemask_t" is a single unsigned long (uint64).
const (
	maxNodes        = 1
	allowedNodemask = (1 << maxNodes) - 1
)

func copyInNodemask(t *kernel.Task, addr hostarch.Addr, maxnode uint32) (uint64, error) {
	// "nodemask points to a bit mask of node IDs that contains up to maxnode
	// bits. The bit mask size is rounded to the next multiple of
	// sizeof(unsigned long), but the kernel will use bits only up to maxnode.
	// A NULL value of nodemask or a maxnode value of zero specifies the empty
	// set of nodes. If the value of maxnode is zero, the nodemask argument is
	// ignored." - set_mempolicy(2). Unfortunately, most of this is inaccurate
	// because of what appears to be a bug: mm/mempolicy.c:get_nodes() uses
	// maxnode-1, not maxnode, as the number of bits.
	bits := maxnode - 1
	if bits > hostarch.PageSize*8 { // also handles overflow from maxnode == 0
		return 0, linuxerr.EINVAL
	}
	if bits == 0 {
		return 0, nil
	}
	// Copy in the whole nodemask.
	numUint64 := (bits + 63) / 64
	buf := t.CopyScratchBuffer(int(numUint64) * 8)
	if _, err := t.CopyInBytes(addr, buf); err != nil {
		return 0, err
	}
	val := hostarch.ByteOrder.Uint64(buf)
	// Check that only allowed bits in the first unsigned long in the nodemask
	// are set.
	if val&^allowedNodemask != 0 {
		return 0, linuxerr.EINVAL
	}
	// Check that all remaining bits in the nodemask are 0.
	for i := 8; i < len(buf); i++ {
		if buf[i] != 0 {
			return 0, linuxerr.EINVAL
		}
	}
	return val, nil
}
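
// The maxnode-1 quirk above is easy to misread, so the following helper is a
// minimal illustrative sketch (it is not used by the syscall paths and not
// part of the upstream file): a caller who wants the kernel to consider N
// mask bits must pass maxnode == N+1.
func nodemaskSize(maxnode uint32) (bits, numUint64 uint32) {
	bits = maxnode - 1 // wraps for maxnode == 0, which callers reject with EINVAL
	if bits == 0 {
		return 0, 0 // empty set: nothing is copied in
	}
	// E.g. maxnode == 65 -> 64 bits -> 1 uint64; maxnode == 66 -> 65 bits -> 2.
	numUint64 = (bits + 63) / 64
	return bits, numUint64
}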

func copyOutNodemask(t *kernel.Task, addr hostarch.Addr, maxnode uint32, val uint64) error {
	// mm/mempolicy.c:copy_nodes_to_user() also uses maxnode-1 as the number of
	// bits.
	bits := maxnode - 1
	if bits > hostarch.PageSize*8 { // also handles overflow from maxnode == 0
		return linuxerr.EINVAL
	}
	if bits == 0 {
		return nil
	}
	// Copy out the first unsigned long in the nodemask.
	buf := t.CopyScratchBuffer(8)
	hostarch.ByteOrder.PutUint64(buf, val)
	if _, err := t.CopyOutBytes(addr, buf); err != nil {
		return err
	}
	// Zero out remaining unsigned longs in the nodemask.
	if bits > 64 {
		remAddr, ok := addr.AddLength(8)
		if !ok {
			return linuxerr.EFAULT
		}
		remUint64 := (bits - 1) / 64
		if _, err := t.MemoryManager().ZeroOut(t, remAddr, int64(remUint64)*8, usermem.IOOpts{
			AddressSpaceActive: true,
		}); err != nil {
			return err
		}
	}
	return nil
}
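
// Worked example for the copy-out layout (illustrative only; it mirrors the
// code above rather than adding behavior). With maxnode == 129, bits == 128,
// so the mask spans two unsigned longs:
//
//	addr     <- val (8 bytes; at most bit 0 is set, since maxNodes == 1)
//	addr + 8 <- 8 zero bytes (remUint64 == (128-1)/64 == 1)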

// GetMempolicy implements the syscall get_mempolicy(2).
func GetMempolicy(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
	mode := args[0].Pointer()
	nodemask := args[1].Pointer()
	maxnode := args[2].Uint()
	addr := args[3].Pointer()
	flags := args[4].Uint()

	if flags&^(linux.MPOL_F_NODE|linux.MPOL_F_ADDR|linux.MPOL_F_MEMS_ALLOWED) != 0 {
		return 0, nil, linuxerr.EINVAL
	}
	nodeFlag := flags&linux.MPOL_F_NODE != 0
	addrFlag := flags&linux.MPOL_F_ADDR != 0
	memsAllowed := flags&linux.MPOL_F_MEMS_ALLOWED != 0

	// "EINVAL: The value specified by maxnode is less than the number of node
	// IDs supported by the system." - get_mempolicy(2)
	if nodemask != 0 && maxnode < maxNodes {
		return 0, nil, linuxerr.EINVAL
	}

	// "If flags specifies MPOL_F_MEMS_ALLOWED [...], the mode argument is
	// ignored and the set of nodes (memories) that the thread is allowed to
	// specify in subsequent calls to mbind(2) or set_mempolicy(2) (in the
	// absence of any mode flags) is returned in nodemask."
	if memsAllowed {
		// "It is not permitted to combine MPOL_F_MEMS_ALLOWED with either
		// MPOL_F_ADDR or MPOL_F_NODE."
		if nodeFlag || addrFlag {
			return 0, nil, linuxerr.EINVAL
		}
		if err := copyOutNodemask(t, nodemask, maxnode, allowedNodemask); err != nil {
			return 0, nil, err
		}
		return 0, nil, nil
	}

	// "If flags specifies MPOL_F_ADDR, then information is returned about the
	// policy governing the memory address given in addr. ... If the mode
	// argument is not NULL, then get_mempolicy() will store the policy mode
	// and any optional mode flags of the requested NUMA policy in the location
	// pointed to by this argument. If nodemask is not NULL, then the nodemask
	// associated with the policy will be stored in the location pointed to by
	// this argument."
	if addrFlag {
		policy, nodemaskVal, err := t.MemoryManager().NumaPolicy(addr)
		if err != nil {
			return 0, nil, err
		}
		if nodeFlag {
			// "If flags specifies both MPOL_F_NODE and MPOL_F_ADDR,
			// get_mempolicy() will return the node ID of the node on which the
			// address addr is allocated into the location pointed to by mode.
			// If no page has yet been allocated for the specified address,
			// get_mempolicy() will allocate a page as if the thread had
			// performed a read (load) access to that address, and return the
			// ID of the node where that page was allocated."
			buf := t.CopyScratchBuffer(1)
			_, err := t.CopyInBytes(addr, buf)
			if err != nil {
				return 0, nil, err
			}
			policy = linux.MPOL_DEFAULT // maxNodes == 1
		}
		if mode != 0 {
			if _, err := policy.CopyOut(t, mode); err != nil {
				return 0, nil, err
			}
		}
		if nodemask != 0 {
			if err := copyOutNodemask(t, nodemask, maxnode, nodemaskVal); err != nil {
				return 0, nil, err
			}
		}
		return 0, nil, nil
	}

	// "EINVAL: ... flags specified MPOL_F_ADDR and addr is NULL, or flags did
	// not specify MPOL_F_ADDR and addr is not NULL." This is partially
	// inaccurate: if flags specifies MPOL_F_ADDR,
	// mm/mempolicy.c:do_get_mempolicy() doesn't special-case NULL; it will
	// just (usually) fail to find a VMA at address 0 and return EFAULT.
	if addr != 0 {
		return 0, nil, linuxerr.EINVAL
	}

	// "If flags is specified as 0, then information about the calling thread's
	// default policy (as set by set_mempolicy(2)) is returned, in the buffers
	// pointed to by mode and nodemask. ... If flags specifies MPOL_F_NODE, but
	// not MPOL_F_ADDR, and the thread's current policy is MPOL_INTERLEAVE,
	// then get_mempolicy() will return in the location pointed to by a
	// non-NULL mode argument, the node ID of the next node that will be used
	// for interleaving of internal kernel pages allocated on behalf of the
	// thread."
	policy, nodemaskVal := t.NumaPolicy()
	if nodeFlag {
		if policy&^linux.MPOL_MODE_FLAGS != linux.MPOL_INTERLEAVE {
			return 0, nil, linuxerr.EINVAL
		}
		policy = linux.MPOL_DEFAULT // maxNodes == 1
	}
	if mode != 0 {
		if _, err := policy.CopyOut(t, mode); err != nil {
			return 0, nil, err
		}
	}
	if nodemask != 0 {
		if err := copyOutNodemask(t, nodemask, maxnode, nodemaskVal); err != nil {
			return 0, nil, err
		}
	}
	return 0, nil, nil
}
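
// A hedged sketch of how a sandboxed program reaches GetMempolicy; the
// queryDefaultPolicy name and the golang.org/x/sys/unix and unsafe usage are
// assumptions for illustration, not part of this file. maxnode == 65 asks the
// kernel to consider 64 mask bits (one unsigned long), per the maxnode-1
// quirk documented in copyInNodemask.
//
//	func queryDefaultPolicy() (int32, uint64, error) {
//		var mode int32
//		var mask uint64
//		_, _, errno := unix.Syscall6(unix.SYS_GET_MEMPOLICY,
//			uintptr(unsafe.Pointer(&mode)),
//			uintptr(unsafe.Pointer(&mask)),
//			65, 0, 0, 0)
//		if errno != 0 {
//			return 0, 0, errno
//		}
//		// Under this sentry, a task that never called set_mempolicy(2)
//		// reports MPOL_DEFAULT with an empty nodemask.
//		return mode, mask, nil
//	}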

// SetMempolicy implements the syscall set_mempolicy(2).
func SetMempolicy(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
	modeWithFlags := linux.NumaPolicy(args[0].Int())
	nodemask := args[1].Pointer()
	maxnode := args[2].Uint()

	modeWithFlags, nodemaskVal, err := copyInMempolicyNodemask(t, modeWithFlags, nodemask, maxnode)
	if err != nil {
		return 0, nil, err
	}

	t.SetNumaPolicy(modeWithFlags, nodemaskVal)
	return 0, nil, nil
}
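
// Companion sketch for the set side (same assumptions as the GetMempolicy
// sketch above; MPOL_INTERLEAVE stands for the guest's copy of that
// constant). Since maxNodes == 1, the only non-empty nodemask this sentry
// accepts is {0}:
//
//	mask := uint64(1) // node 0 only; any other set bit yields EINVAL
//	_, _, errno := unix.Syscall6(unix.SYS_SET_MEMPOLICY,
//		uintptr(MPOL_INTERLEAVE), // this mode requires a non-empty nodemask
//		uintptr(unsafe.Pointer(&mask)),
//		65, 0, 0, 0)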

// Mbind implements the syscall mbind(2).
func Mbind(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
	addr := args[0].Pointer()
	length := args[1].Uint64()
	mode := linux.NumaPolicy(args[2].Int())
	nodemask := args[3].Pointer()
	maxnode := args[4].Uint()
	flags := args[5].Uint()

	if flags&^linux.MPOL_MF_VALID != 0 {
		return 0, nil, linuxerr.EINVAL
	}
	// "If MPOL_MF_MOVE_ALL is passed in flags ... [the] calling thread must be
	// privileged (CAP_SYS_NICE) to use this flag." - mbind(2)
	if flags&linux.MPOL_MF_MOVE_ALL != 0 && !t.HasCapability(linux.CAP_SYS_NICE) {
		return 0, nil, linuxerr.EPERM
	}

	mode, nodemaskVal, err := copyInMempolicyNodemask(t, mode, nodemask, maxnode)
	if err != nil {
		return 0, nil, err
	}

	// Since we claim to have only a single node, all flags can be ignored
	// (all pages must already be on that single node).
	err = t.MemoryManager().SetNumaPolicy(addr, length, mode, nodemaskVal)
	return 0, nil, err
}
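
// Illustrative trace of the checks above (not an additional code path): a
// caller passing MPOL_MF_MOVE_ALL without CAP_SYS_NICE fails with EPERM
// before any nodemask parsing; with the capability, the flag then has no
// effect here, because every page is already on the single reported node.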

func copyInMempolicyNodemask(t *kernel.Task, modeWithFlags linux.NumaPolicy, nodemask hostarch.Addr, maxnode uint32) (linux.NumaPolicy, uint64, error) {
	flags := linux.NumaPolicy(modeWithFlags & linux.MPOL_MODE_FLAGS)
	mode := linux.NumaPolicy(modeWithFlags &^ linux.MPOL_MODE_FLAGS)
	if flags == linux.MPOL_MODE_FLAGS {
		// Can't specify both mode flags simultaneously.
		return 0, 0, linuxerr.EINVAL
	}
	if mode < 0 || mode >= linux.MPOL_MAX {
		// Must specify a valid mode.
		return 0, 0, linuxerr.EINVAL
	}

	var nodemaskVal uint64
	if nodemask != 0 {
		var err error
		nodemaskVal, err = copyInNodemask(t, nodemask, maxnode)
		if err != nil {
			return 0, 0, err
		}
	}

	switch mode {
	case linux.MPOL_DEFAULT:
		// "nodemask must be specified as NULL." - set_mempolicy(2). This is
		// inaccurate; Linux allows a nodemask to be specified, as long as it
		// is empty.
		if nodemaskVal != 0 {
			return 0, 0, linuxerr.EINVAL
		}
	case linux.MPOL_BIND, linux.MPOL_INTERLEAVE:
		// These require a non-empty nodemask.
		if nodemaskVal == 0 {
			return 0, 0, linuxerr.EINVAL
		}
	case linux.MPOL_PREFERRED:
		// This permits an empty nodemask, as long as no flags are set.
		if nodemaskVal == 0 {
			if flags != 0 {
				return 0, 0, linuxerr.EINVAL
			}
			// On newer Linux versions, MPOL_PREFERRED is implemented as
			// MPOL_LOCAL when the node set is empty. See 7858d7bca7fb
			// ("mm/mempolicy: don't handle MPOL_LOCAL like a fake
			// MPOL_PREFERRED policy").
			mode = linux.MPOL_LOCAL
		}
	case linux.MPOL_LOCAL:
		// This requires an empty nodemask and no flags set.
		if nodemaskVal != 0 || flags != 0 {
			return 0, 0, linuxerr.EINVAL
		}
	default:
		// Unknown mode, which we should have rejected above.
		panic(fmt.Sprintf("unknown mode: %v", mode))
	}

	return mode | flags, nodemaskVal, nil
}
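
// Decomposition example (illustrative only): MPOL_BIND|MPOL_F_STATIC_NODES
// splits into flags == MPOL_F_STATIC_NODES and mode == MPOL_BIND, and the
// two are recombined on return. Setting both MPOL_F_STATIC_NODES and
// MPOL_F_RELATIVE_NODES trips the flags == MPOL_MODE_FLAGS check and fails
// with EINVAL, matching Linux.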