github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/sentry/syscalls/linux/sys_sem.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package linux
    16  
    17  import (
    18  	"math"
    19  	"time"
    20  
    21  	"github.com/SagerNet/gvisor/pkg/abi/linux"
    22  	"github.com/SagerNet/gvisor/pkg/errors/linuxerr"
    23  	"github.com/SagerNet/gvisor/pkg/hostarch"
    24  	"github.com/SagerNet/gvisor/pkg/marshal/primitive"
    25  	"github.com/SagerNet/gvisor/pkg/sentry/arch"
    26  	"github.com/SagerNet/gvisor/pkg/sentry/fs"
    27  	"github.com/SagerNet/gvisor/pkg/sentry/kernel"
    28  	"github.com/SagerNet/gvisor/pkg/sentry/kernel/auth"
    29  	"github.com/SagerNet/gvisor/pkg/syserror"
    30  )
    31  
    32  const opsMax = 500 // SEMOPM
    33  
    34  // Semget handles: semget(key_t key, int nsems, int semflg)
    35  func Semget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
    36  	key := args[0].Int()
    37  	nsems := args[1].Int()
    38  	flag := args[2].Int()
    39  
    40  	private := key == linux.IPC_PRIVATE
    41  	create := flag&linux.IPC_CREAT == linux.IPC_CREAT
    42  	exclusive := flag&linux.IPC_EXCL == linux.IPC_EXCL
    43  	mode := linux.FileMode(flag & 0777)
    44  
    45  	r := t.IPCNamespace().SemaphoreRegistry()
    46  	set, err := r.FindOrCreate(t, key, nsems, mode, private, create, exclusive)
    47  	if err != nil {
    48  		return 0, nil, err
    49  	}
    50  	return uintptr(set.ID), nil, nil
    51  }
    52  
    53  // Semtimedop handles: semop(int semid, struct sembuf *sops, size_t nsops, const struct timespec *timeout)
    54  func Semtimedop(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
    55  	// If the timeout argument is NULL, then semtimedop() behaves exactly like semop().
    56  	if args[3].Pointer() == 0 {
    57  		return Semop(t, args)
    58  	}
    59  
    60  	id := args[0].Int()
    61  	sembufAddr := args[1].Pointer()
    62  	nsops := args[2].SizeT()
    63  	timespecAddr := args[3].Pointer()
    64  	if nsops <= 0 {
    65  		return 0, nil, linuxerr.EINVAL
    66  	}
    67  	if nsops > opsMax {
    68  		return 0, nil, linuxerr.E2BIG
    69  	}
    70  
    71  	ops := make([]linux.Sembuf, nsops)
    72  	if _, err := linux.CopySembufSliceIn(t, sembufAddr, ops); err != nil {
    73  		return 0, nil, err
    74  	}
    75  
    76  	var timeout linux.Timespec
    77  	if _, err := timeout.CopyIn(t, timespecAddr); err != nil {
    78  		return 0, nil, err
    79  	}
    80  	if timeout.Sec < 0 || timeout.Nsec < 0 || timeout.Nsec >= 1e9 {
    81  		return 0, nil, linuxerr.EINVAL
    82  	}
    83  
    84  	if err := semTimedOp(t, id, ops, true, timeout.ToDuration()); err != nil {
    85  		if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
    86  			return 0, nil, linuxerr.EAGAIN
    87  		}
    88  		return 0, nil, err
    89  	}
    90  	return 0, nil, nil
    91  }
    92  
    93  // Semop handles: semop(int semid, struct sembuf *sops, size_t nsops)
    94  func Semop(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
    95  	id := args[0].Int()
    96  	sembufAddr := args[1].Pointer()
    97  	nsops := args[2].SizeT()
    98  
    99  	if nsops <= 0 {
   100  		return 0, nil, linuxerr.EINVAL
   101  	}
   102  	if nsops > opsMax {
   103  		return 0, nil, linuxerr.E2BIG
   104  	}
   105  
   106  	ops := make([]linux.Sembuf, nsops)
   107  	if _, err := linux.CopySembufSliceIn(t, sembufAddr, ops); err != nil {
   108  		return 0, nil, err
   109  	}
   110  	return 0, nil, semTimedOp(t, id, ops, false, time.Second)
   111  }
   112  
   113  func semTimedOp(t *kernel.Task, id int32, ops []linux.Sembuf, haveTimeout bool, timeout time.Duration) error {
   114  	set := t.IPCNamespace().SemaphoreRegistry().FindByID(id)
   115  
   116  	if set == nil {
   117  		return linuxerr.EINVAL
   118  	}
   119  	creds := auth.CredentialsFromContext(t)
   120  	pid := t.Kernel().GlobalInit().PIDNamespace().IDOfThreadGroup(t.ThreadGroup())
   121  	for {
   122  		ch, num, err := set.ExecuteOps(t, ops, creds, int32(pid))
   123  		if ch == nil || err != nil {
   124  			return err
   125  		}
   126  		if _, err = t.BlockWithTimeout(ch, haveTimeout, timeout); err != nil {
   127  			set.AbortWait(num, ch)
   128  			return err
   129  		}
   130  	}
   131  }
   132  
   133  // Semctl handles: semctl(int semid, int semnum, int cmd, ...)
   134  func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
   135  	id := args[0].Int()
   136  	num := args[1].Int()
   137  	cmd := args[2].Int()
   138  
   139  	switch cmd {
   140  	case linux.SETVAL:
   141  		val := args[3].Int()
   142  		if val > math.MaxInt16 {
   143  			return 0, nil, syserror.ERANGE
   144  		}
   145  		return 0, nil, setVal(t, id, num, int16(val))
   146  
   147  	case linux.SETALL:
   148  		array := args[3].Pointer()
   149  		return 0, nil, setValAll(t, id, array)
   150  
   151  	case linux.GETVAL:
   152  		v, err := getVal(t, id, num)
   153  		return uintptr(v), nil, err
   154  
   155  	case linux.GETALL:
   156  		array := args[3].Pointer()
   157  		return 0, nil, getValAll(t, id, array)
   158  
   159  	case linux.IPC_RMID:
   160  		return 0, nil, remove(t, id)
   161  
   162  	case linux.IPC_SET:
   163  		arg := args[3].Pointer()
   164  		var s linux.SemidDS
   165  		if _, err := s.CopyIn(t, arg); err != nil {
   166  			return 0, nil, err
   167  		}
   168  
   169  		perms := fs.FilePermsFromMode(linux.FileMode(s.SemPerm.Mode & 0777))
   170  		return 0, nil, ipcSet(t, id, auth.UID(s.SemPerm.UID), auth.GID(s.SemPerm.GID), perms)
   171  
   172  	case linux.GETPID:
   173  		v, err := getPID(t, id, num)
   174  		return uintptr(v), nil, err
   175  
   176  	case linux.IPC_STAT:
   177  		arg := args[3].Pointer()
   178  		ds, err := ipcStat(t, id)
   179  		if err == nil {
   180  			_, err = ds.CopyOut(t, arg)
   181  		}
   182  
   183  		return 0, nil, err
   184  
   185  	case linux.GETZCNT:
   186  		v, err := getZCnt(t, id, num)
   187  		return uintptr(v), nil, err
   188  
   189  	case linux.GETNCNT:
   190  		v, err := getNCnt(t, id, num)
   191  		return uintptr(v), nil, err
   192  
   193  	case linux.IPC_INFO:
   194  		buf := args[3].Pointer()
   195  		r := t.IPCNamespace().SemaphoreRegistry()
   196  		info := r.IPCInfo()
   197  		if _, err := info.CopyOut(t, buf); err != nil {
   198  			return 0, nil, err
   199  		}
   200  		return uintptr(r.HighestIndex()), nil, nil
   201  
   202  	case linux.SEM_INFO:
   203  		buf := args[3].Pointer()
   204  		r := t.IPCNamespace().SemaphoreRegistry()
   205  		info := r.SemInfo()
   206  		if _, err := info.CopyOut(t, buf); err != nil {
   207  			return 0, nil, err
   208  		}
   209  		return uintptr(r.HighestIndex()), nil, nil
   210  
   211  	case linux.SEM_STAT:
   212  		arg := args[3].Pointer()
   213  		// id is an index in SEM_STAT.
   214  		semid, ds, err := semStat(t, id)
   215  		if err != nil {
   216  			return 0, nil, err
   217  		}
   218  		if _, err := ds.CopyOut(t, arg); err != nil {
   219  			return 0, nil, err
   220  		}
   221  		return uintptr(semid), nil, err
   222  
   223  	case linux.SEM_STAT_ANY:
   224  		arg := args[3].Pointer()
   225  		// id is an index in SEM_STAT.
   226  		semid, ds, err := semStatAny(t, id)
   227  		if err != nil {
   228  			return 0, nil, err
   229  		}
   230  		if _, err := ds.CopyOut(t, arg); err != nil {
   231  			return 0, nil, err
   232  		}
   233  		return uintptr(semid), nil, err
   234  
   235  	default:
   236  		return 0, nil, linuxerr.EINVAL
   237  	}
   238  }
   239  
   240  func remove(t *kernel.Task, id int32) error {
   241  	r := t.IPCNamespace().SemaphoreRegistry()
   242  	creds := auth.CredentialsFromContext(t)
   243  	return r.RemoveID(id, creds)
   244  }
   245  
   246  func ipcSet(t *kernel.Task, id int32, uid auth.UID, gid auth.GID, perms fs.FilePermissions) error {
   247  	r := t.IPCNamespace().SemaphoreRegistry()
   248  	set := r.FindByID(id)
   249  	if set == nil {
   250  		return linuxerr.EINVAL
   251  	}
   252  
   253  	creds := auth.CredentialsFromContext(t)
   254  	kuid := creds.UserNamespace.MapToKUID(uid)
   255  	if !kuid.Ok() {
   256  		return linuxerr.EINVAL
   257  	}
   258  	kgid := creds.UserNamespace.MapToKGID(gid)
   259  	if !kgid.Ok() {
   260  		return linuxerr.EINVAL
   261  	}
   262  	owner := fs.FileOwner{UID: kuid, GID: kgid}
   263  	return set.Change(t, creds, owner, perms)
   264  }
   265  
   266  func ipcStat(t *kernel.Task, id int32) (*linux.SemidDS, error) {
   267  	r := t.IPCNamespace().SemaphoreRegistry()
   268  	set := r.FindByID(id)
   269  	if set == nil {
   270  		return nil, linuxerr.EINVAL
   271  	}
   272  	creds := auth.CredentialsFromContext(t)
   273  	return set.GetStat(creds)
   274  }
   275  
   276  func semStat(t *kernel.Task, index int32) (int32, *linux.SemidDS, error) {
   277  	r := t.IPCNamespace().SemaphoreRegistry()
   278  	set := r.FindByIndex(index)
   279  	if set == nil {
   280  		return 0, nil, linuxerr.EINVAL
   281  	}
   282  	creds := auth.CredentialsFromContext(t)
   283  	ds, err := set.GetStat(creds)
   284  	if err != nil {
   285  		return 0, ds, err
   286  	}
   287  	return set.ID, ds, nil
   288  }
   289  
   290  func semStatAny(t *kernel.Task, index int32) (int32, *linux.SemidDS, error) {
   291  	set := t.IPCNamespace().SemaphoreRegistry().FindByIndex(index)
   292  	if set == nil {
   293  		return 0, nil, linuxerr.EINVAL
   294  	}
   295  	creds := auth.CredentialsFromContext(t)
   296  	ds, err := set.GetStatAny(creds)
   297  	if err != nil {
   298  		return 0, ds, err
   299  	}
   300  	return set.ID, ds, nil
   301  }
   302  
   303  func setVal(t *kernel.Task, id int32, num int32, val int16) error {
   304  	r := t.IPCNamespace().SemaphoreRegistry()
   305  	set := r.FindByID(id)
   306  	if set == nil {
   307  		return linuxerr.EINVAL
   308  	}
   309  	creds := auth.CredentialsFromContext(t)
   310  	pid := t.Kernel().GlobalInit().PIDNamespace().IDOfThreadGroup(t.ThreadGroup())
   311  	return set.SetVal(t, num, val, creds, int32(pid))
   312  }
   313  
   314  func setValAll(t *kernel.Task, id int32, array hostarch.Addr) error {
   315  	r := t.IPCNamespace().SemaphoreRegistry()
   316  	set := r.FindByID(id)
   317  	if set == nil {
   318  		return linuxerr.EINVAL
   319  	}
   320  	vals := make([]uint16, set.Size())
   321  	if _, err := primitive.CopyUint16SliceIn(t, array, vals); err != nil {
   322  		return err
   323  	}
   324  	creds := auth.CredentialsFromContext(t)
   325  	pid := t.Kernel().GlobalInit().PIDNamespace().IDOfThreadGroup(t.ThreadGroup())
   326  	return set.SetValAll(t, vals, creds, int32(pid))
   327  }
   328  
   329  func getVal(t *kernel.Task, id int32, num int32) (int16, error) {
   330  	r := t.IPCNamespace().SemaphoreRegistry()
   331  	set := r.FindByID(id)
   332  	if set == nil {
   333  		return 0, linuxerr.EINVAL
   334  	}
   335  	creds := auth.CredentialsFromContext(t)
   336  	return set.GetVal(num, creds)
   337  }
   338  
   339  func getValAll(t *kernel.Task, id int32, array hostarch.Addr) error {
   340  	r := t.IPCNamespace().SemaphoreRegistry()
   341  	set := r.FindByID(id)
   342  	if set == nil {
   343  		return linuxerr.EINVAL
   344  	}
   345  	creds := auth.CredentialsFromContext(t)
   346  	vals, err := set.GetValAll(creds)
   347  	if err != nil {
   348  		return err
   349  	}
   350  	_, err = primitive.CopyUint16SliceOut(t, array, vals)
   351  	return err
   352  }
   353  
   354  func getPID(t *kernel.Task, id int32, num int32) (int32, error) {
   355  	r := t.IPCNamespace().SemaphoreRegistry()
   356  	set := r.FindByID(id)
   357  	if set == nil {
   358  		return 0, linuxerr.EINVAL
   359  	}
   360  	creds := auth.CredentialsFromContext(t)
   361  	gpid, err := set.GetPID(num, creds)
   362  	if err != nil {
   363  		return 0, err
   364  	}
   365  	// Convert pid from init namespace to the caller's namespace.
   366  	tg := t.PIDNamespace().ThreadGroupWithID(kernel.ThreadID(gpid))
   367  	if tg == nil {
   368  		return 0, nil
   369  	}
   370  	return int32(tg.ID()), nil
   371  }
   372  
   373  func getZCnt(t *kernel.Task, id int32, num int32) (uint16, error) {
   374  	r := t.IPCNamespace().SemaphoreRegistry()
   375  	set := r.FindByID(id)
   376  	if set == nil {
   377  		return 0, linuxerr.EINVAL
   378  	}
   379  	creds := auth.CredentialsFromContext(t)
   380  	return set.CountZeroWaiters(num, creds)
   381  }
   382  
   383  func getNCnt(t *kernel.Task, id int32, num int32) (uint16, error) {
   384  	r := t.IPCNamespace().SemaphoreRegistry()
   385  	set := r.FindByID(id)
   386  	if set == nil {
   387  		return 0, linuxerr.EINVAL
   388  	}
   389  	creds := auth.CredentialsFromContext(t)
   390  	return set.CountNegativeWaiters(num, creds)
   391  }