github.com/metacubex/gvisor@v0.0.0-20240320004321-933faba989ec/pkg/sentry/syscalls/linux/sys_sem.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package linux
    16  
    17  import (
    18  	"math"
    19  	"time"
    20  
    21  	"github.com/metacubex/gvisor/pkg/abi/linux"
    22  	"github.com/metacubex/gvisor/pkg/errors/linuxerr"
    23  	"github.com/metacubex/gvisor/pkg/hostarch"
    24  	"github.com/metacubex/gvisor/pkg/marshal/primitive"
    25  	"github.com/metacubex/gvisor/pkg/sentry/arch"
    26  	"github.com/metacubex/gvisor/pkg/sentry/kernel"
    27  	"github.com/metacubex/gvisor/pkg/sentry/kernel/auth"
    28  	"github.com/metacubex/gvisor/pkg/sentry/kernel/ipc"
    29  )
    30  
// opsMax is the largest number of operations accepted in a single
// semop/semtimedop call; larger requests fail with E2BIG. It corresponds to
// Linux's SEMOPM.
const opsMax = 500 // SEMOPM
    32  
    33  // Semget handles: semget(key_t key, int nsems, int semflg)
    34  func Semget(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
    35  	key := ipc.Key(args[0].Int())
    36  	nsems := args[1].Int()
    37  	flag := args[2].Int()
    38  
    39  	private := key == linux.IPC_PRIVATE
    40  	create := flag&linux.IPC_CREAT == linux.IPC_CREAT
    41  	exclusive := flag&linux.IPC_EXCL == linux.IPC_EXCL
    42  	mode := linux.FileMode(flag & 0777)
    43  
    44  	r := t.IPCNamespace().SemaphoreRegistry()
    45  	set, err := r.FindOrCreate(t, key, nsems, mode, private, create, exclusive)
    46  	if err != nil {
    47  		return 0, nil, err
    48  	}
    49  	return uintptr(set.ID()), nil, nil
    50  }
    51  
    52  // Semtimedop handles: semop(int semid, struct sembuf *sops, size_t nsops, const struct timespec *timeout)
    53  func Semtimedop(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
    54  	// If the timeout argument is NULL, then semtimedop() behaves exactly like semop().
    55  	if args[3].Pointer() == 0 {
    56  		return Semop(t, sysno, args)
    57  	}
    58  
    59  	id := ipc.ID(args[0].Int())
    60  	sembufAddr := args[1].Pointer()
    61  	nsops := args[2].SizeT()
    62  	timespecAddr := args[3].Pointer()
    63  	if nsops <= 0 {
    64  		return 0, nil, linuxerr.EINVAL
    65  	}
    66  	if nsops > opsMax {
    67  		return 0, nil, linuxerr.E2BIG
    68  	}
    69  
    70  	ops := make([]linux.Sembuf, nsops)
    71  	if _, err := linux.CopySembufSliceIn(t, sembufAddr, ops); err != nil {
    72  		return 0, nil, err
    73  	}
    74  
    75  	var timeout linux.Timespec
    76  	if _, err := timeout.CopyIn(t, timespecAddr); err != nil {
    77  		return 0, nil, err
    78  	}
    79  	if timeout.Sec < 0 || timeout.Nsec < 0 || timeout.Nsec >= 1e9 {
    80  		return 0, nil, linuxerr.EINVAL
    81  	}
    82  
    83  	if err := semTimedOp(t, id, ops, true, timeout.ToDuration()); err != nil {
    84  		if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
    85  			return 0, nil, linuxerr.EAGAIN
    86  		}
    87  		return 0, nil, err
    88  	}
    89  	return 0, nil, nil
    90  }
    91  
    92  // Semop handles: semop(int semid, struct sembuf *sops, size_t nsops)
    93  func Semop(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
    94  	id := ipc.ID(args[0].Int())
    95  	sembufAddr := args[1].Pointer()
    96  	nsops := args[2].SizeT()
    97  
    98  	if nsops <= 0 {
    99  		return 0, nil, linuxerr.EINVAL
   100  	}
   101  	if nsops > opsMax {
   102  		return 0, nil, linuxerr.E2BIG
   103  	}
   104  
   105  	ops := make([]linux.Sembuf, nsops)
   106  	if _, err := linux.CopySembufSliceIn(t, sembufAddr, ops); err != nil {
   107  		return 0, nil, err
   108  	}
   109  	return 0, nil, semTimedOp(t, id, ops, false, time.Second)
   110  }
   111  
   112  func semTimedOp(t *kernel.Task, id ipc.ID, ops []linux.Sembuf, haveTimeout bool, timeout time.Duration) error {
   113  	set := t.IPCNamespace().SemaphoreRegistry().FindByID(id)
   114  
   115  	if set == nil {
   116  		return linuxerr.EINVAL
   117  	}
   118  	creds := auth.CredentialsFromContext(t)
   119  	pid := t.Kernel().GlobalInit().PIDNamespace().IDOfThreadGroup(t.ThreadGroup())
   120  	for {
   121  		ch, num, err := set.ExecuteOps(t, ops, creds, int32(pid))
   122  		if ch == nil || err != nil {
   123  			return err
   124  		}
   125  		if _, err = t.BlockWithTimeout(ch, haveTimeout, timeout); err != nil {
   126  			set.AbortWait(num, ch)
   127  			return err
   128  		}
   129  	}
   130  }
   131  
   132  // Semctl handles: semctl(int semid, int semnum, int cmd, ...)
   133  func Semctl(t *kernel.Task, sysno uintptr, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
   134  	id := ipc.ID(args[0].Int())
   135  	num := args[1].Int()
   136  	cmd := args[2].Int()
   137  
   138  	switch cmd {
   139  	case linux.SETVAL:
   140  		val := args[3].Int()
   141  		if val > math.MaxInt16 {
   142  			return 0, nil, linuxerr.ERANGE
   143  		}
   144  		return 0, nil, setVal(t, id, num, int16(val))
   145  
   146  	case linux.SETALL:
   147  		array := args[3].Pointer()
   148  		return 0, nil, setValAll(t, id, array)
   149  
   150  	case linux.GETVAL:
   151  		v, err := getVal(t, id, num)
   152  		return uintptr(v), nil, err
   153  
   154  	case linux.GETALL:
   155  		array := args[3].Pointer()
   156  		return 0, nil, getValAll(t, id, array)
   157  
   158  	case linux.IPC_RMID:
   159  		return 0, nil, remove(t, id)
   160  
   161  	case linux.IPC_SET:
   162  		arg := args[3].Pointer()
   163  		var s linux.SemidDS
   164  		if _, err := s.CopyIn(t, arg); err != nil {
   165  			return 0, nil, err
   166  		}
   167  
   168  		return 0, nil, ipcSet(t, id, &s)
   169  
   170  	case linux.GETPID:
   171  		v, err := getPID(t, id, num)
   172  		return uintptr(v), nil, err
   173  
   174  	case linux.IPC_STAT:
   175  		arg := args[3].Pointer()
   176  		ds, err := ipcStat(t, id)
   177  		if err == nil {
   178  			_, err = ds.CopyOut(t, arg)
   179  		}
   180  
   181  		return 0, nil, err
   182  
   183  	case linux.GETZCNT:
   184  		v, err := getZCnt(t, id, num)
   185  		return uintptr(v), nil, err
   186  
   187  	case linux.GETNCNT:
   188  		v, err := getNCnt(t, id, num)
   189  		return uintptr(v), nil, err
   190  
   191  	case linux.IPC_INFO:
   192  		buf := args[3].Pointer()
   193  		r := t.IPCNamespace().SemaphoreRegistry()
   194  		info := r.IPCInfo()
   195  		if _, err := info.CopyOut(t, buf); err != nil {
   196  			return 0, nil, err
   197  		}
   198  		return uintptr(r.HighestIndex()), nil, nil
   199  
   200  	case linux.SEM_INFO:
   201  		buf := args[3].Pointer()
   202  		r := t.IPCNamespace().SemaphoreRegistry()
   203  		info := r.SemInfo()
   204  		if _, err := info.CopyOut(t, buf); err != nil {
   205  			return 0, nil, err
   206  		}
   207  		return uintptr(r.HighestIndex()), nil, nil
   208  
   209  	case linux.SEM_STAT:
   210  		arg := args[3].Pointer()
   211  		// id is an index in SEM_STAT.
   212  		semid, ds, err := semStat(t, int32(id))
   213  		if err != nil {
   214  			return 0, nil, err
   215  		}
   216  		if _, err := ds.CopyOut(t, arg); err != nil {
   217  			return 0, nil, err
   218  		}
   219  		return uintptr(semid), nil, err
   220  
   221  	case linux.SEM_STAT_ANY:
   222  		arg := args[3].Pointer()
   223  		// id is an index in SEM_STAT.
   224  		semid, ds, err := semStatAny(t, int32(id))
   225  		if err != nil {
   226  			return 0, nil, err
   227  		}
   228  		if _, err := ds.CopyOut(t, arg); err != nil {
   229  			return 0, nil, err
   230  		}
   231  		return uintptr(semid), nil, err
   232  
   233  	default:
   234  		return 0, nil, linuxerr.EINVAL
   235  	}
   236  }
   237  
   238  func remove(t *kernel.Task, id ipc.ID) error {
   239  	r := t.IPCNamespace().SemaphoreRegistry()
   240  	creds := auth.CredentialsFromContext(t)
   241  	return r.Remove(id, creds)
   242  }
   243  
   244  func ipcSet(t *kernel.Task, id ipc.ID, ds *linux.SemidDS) error {
   245  	r := t.IPCNamespace().SemaphoreRegistry()
   246  	set := r.FindByID(id)
   247  	if set == nil {
   248  		return linuxerr.EINVAL
   249  	}
   250  	return set.Set(t, ds)
   251  }
   252  
   253  func ipcStat(t *kernel.Task, id ipc.ID) (*linux.SemidDS, error) {
   254  	r := t.IPCNamespace().SemaphoreRegistry()
   255  	set := r.FindByID(id)
   256  	if set == nil {
   257  		return nil, linuxerr.EINVAL
   258  	}
   259  	creds := auth.CredentialsFromContext(t)
   260  	return set.GetStat(creds)
   261  }
   262  
   263  func semStat(t *kernel.Task, index int32) (int32, *linux.SemidDS, error) {
   264  	r := t.IPCNamespace().SemaphoreRegistry()
   265  	set := r.FindByIndex(index)
   266  	if set == nil {
   267  		return 0, nil, linuxerr.EINVAL
   268  	}
   269  	creds := auth.CredentialsFromContext(t)
   270  	ds, err := set.GetStat(creds)
   271  	if err != nil {
   272  		return 0, ds, err
   273  	}
   274  	return int32(set.ID()), ds, nil
   275  }
   276  
   277  func semStatAny(t *kernel.Task, index int32) (int32, *linux.SemidDS, error) {
   278  	set := t.IPCNamespace().SemaphoreRegistry().FindByIndex(index)
   279  	if set == nil {
   280  		return 0, nil, linuxerr.EINVAL
   281  	}
   282  	creds := auth.CredentialsFromContext(t)
   283  	ds, err := set.GetStatAny(creds)
   284  	if err != nil {
   285  		return 0, ds, err
   286  	}
   287  	return int32(set.ID()), ds, nil
   288  }
   289  
   290  func setVal(t *kernel.Task, id ipc.ID, num int32, val int16) error {
   291  	r := t.IPCNamespace().SemaphoreRegistry()
   292  	set := r.FindByID(id)
   293  	if set == nil {
   294  		return linuxerr.EINVAL
   295  	}
   296  	creds := auth.CredentialsFromContext(t)
   297  	pid := t.Kernel().GlobalInit().PIDNamespace().IDOfThreadGroup(t.ThreadGroup())
   298  	return set.SetVal(t, num, val, creds, int32(pid))
   299  }
   300  
   301  func setValAll(t *kernel.Task, id ipc.ID, array hostarch.Addr) error {
   302  	r := t.IPCNamespace().SemaphoreRegistry()
   303  	set := r.FindByID(id)
   304  	if set == nil {
   305  		return linuxerr.EINVAL
   306  	}
   307  	vals := make([]uint16, set.Size())
   308  	if _, err := primitive.CopyUint16SliceIn(t, array, vals); err != nil {
   309  		return err
   310  	}
   311  	creds := auth.CredentialsFromContext(t)
   312  	pid := t.Kernel().GlobalInit().PIDNamespace().IDOfThreadGroup(t.ThreadGroup())
   313  	return set.SetValAll(t, vals, creds, int32(pid))
   314  }
   315  
   316  func getVal(t *kernel.Task, id ipc.ID, num int32) (int16, error) {
   317  	r := t.IPCNamespace().SemaphoreRegistry()
   318  	set := r.FindByID(id)
   319  	if set == nil {
   320  		return 0, linuxerr.EINVAL
   321  	}
   322  	creds := auth.CredentialsFromContext(t)
   323  	return set.GetVal(num, creds)
   324  }
   325  
   326  func getValAll(t *kernel.Task, id ipc.ID, array hostarch.Addr) error {
   327  	r := t.IPCNamespace().SemaphoreRegistry()
   328  	set := r.FindByID(id)
   329  	if set == nil {
   330  		return linuxerr.EINVAL
   331  	}
   332  	creds := auth.CredentialsFromContext(t)
   333  	vals, err := set.GetValAll(creds)
   334  	if err != nil {
   335  		return err
   336  	}
   337  	_, err = primitive.CopyUint16SliceOut(t, array, vals)
   338  	return err
   339  }
   340  
   341  func getPID(t *kernel.Task, id ipc.ID, num int32) (int32, error) {
   342  	r := t.IPCNamespace().SemaphoreRegistry()
   343  	set := r.FindByID(id)
   344  	if set == nil {
   345  		return 0, linuxerr.EINVAL
   346  	}
   347  	creds := auth.CredentialsFromContext(t)
   348  	gpid, err := set.GetPID(num, creds)
   349  	if err != nil {
   350  		return 0, err
   351  	}
   352  	// Convert pid from init namespace to the caller's namespace.
   353  	tg := t.PIDNamespace().ThreadGroupWithID(kernel.ThreadID(gpid))
   354  	if tg == nil {
   355  		return 0, nil
   356  	}
   357  	return int32(tg.ID()), nil
   358  }
   359  
   360  func getZCnt(t *kernel.Task, id ipc.ID, num int32) (uint16, error) {
   361  	r := t.IPCNamespace().SemaphoreRegistry()
   362  	set := r.FindByID(id)
   363  	if set == nil {
   364  		return 0, linuxerr.EINVAL
   365  	}
   366  	creds := auth.CredentialsFromContext(t)
   367  	return set.CountZeroWaiters(num, creds)
   368  }
   369  
   370  func getNCnt(t *kernel.Task, id ipc.ID, num int32) (uint16, error) {
   371  	r := t.IPCNamespace().SemaphoreRegistry()
   372  	set := r.FindByID(id)
   373  	if set == nil {
   374  		return 0, linuxerr.EINVAL
   375  	}
   376  	creds := auth.CredentialsFromContext(t)
   377  	return set.CountNegativeWaiters(num, creds)
   378  }