gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/sentry/kernel/semaphore/semaphore_test.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package semaphore
    16  
    17  import (
    18  	"testing"
    19  
    20  	"gvisor.dev/gvisor/pkg/abi/linux"
    21  	"gvisor.dev/gvisor/pkg/context"
    22  	"gvisor.dev/gvisor/pkg/errors/linuxerr"
    23  	"gvisor.dev/gvisor/pkg/sentry/contexttest"
    24  	"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
    25  	"gvisor.dev/gvisor/pkg/sentry/kernel/ipc"
    26  )
    27  
    28  func executeOps(ctx context.Context, t *testing.T, set *Set, ops []linux.Sembuf, block bool) chan struct{} {
    29  	ch, _, err := set.executeOps(ctx, ops, 123)
    30  	if err != nil {
    31  		t.Fatalf("ExecuteOps(ops) failed, err: %v, ops: %+v", err, ops)
    32  	}
    33  	if block {
    34  		if ch == nil {
    35  			t.Fatalf("ExecuteOps(ops) got: nil, expected: !nil, ops: %+v", ops)
    36  		}
    37  		if signalled(ch) {
    38  			t.Fatalf("ExecuteOps(ops) channel should not have been signalled, ops: %+v", ops)
    39  		}
    40  	} else {
    41  		if ch != nil {
    42  			t.Fatalf("ExecuteOps(ops) got: %v, expected: nil, ops: %+v", ch, ops)
    43  		}
    44  	}
    45  	return ch
    46  }
    47  
    48  func signalled(ch chan struct{}) bool {
    49  	select {
    50  	case <-ch:
    51  		return true
    52  	default:
    53  		return false
    54  	}
    55  }
    56  
    57  func TestBasic(t *testing.T) {
    58  	ctx := contexttest.Context(t)
    59  	set := &Set{obj: &ipc.Object{ID: 123}, sems: make([]sem, 1)}
    60  	ops := []linux.Sembuf{
    61  		{SemOp: 1},
    62  	}
    63  	executeOps(ctx, t, set, ops, false)
    64  
    65  	ops[0].SemOp = -1
    66  	executeOps(ctx, t, set, ops, false)
    67  
    68  	ops[0].SemOp = -1
    69  	ch1 := executeOps(ctx, t, set, ops, true)
    70  
    71  	ops[0].SemOp = 1
    72  	executeOps(ctx, t, set, ops, false)
    73  	if !signalled(ch1) {
    74  		t.Fatalf("ExecuteOps(ops) channel should not have been signalled, ops: %+v", ops)
    75  	}
    76  }
    77  
    78  func TestWaitForZero(t *testing.T) {
    79  	ctx := contexttest.Context(t)
    80  	set := &Set{obj: &ipc.Object{ID: 123}, sems: make([]sem, 1)}
    81  	ops := []linux.Sembuf{
    82  		{SemOp: 0},
    83  	}
    84  	executeOps(ctx, t, set, ops, false)
    85  
    86  	ops[0].SemOp = -2
    87  	ch1 := executeOps(ctx, t, set, ops, true)
    88  
    89  	ops[0].SemOp = 0
    90  	executeOps(ctx, t, set, ops, false)
    91  
    92  	ops[0].SemOp = 1
    93  	executeOps(ctx, t, set, ops, false)
    94  
    95  	ops[0].SemOp = 0
    96  	chZero1 := executeOps(ctx, t, set, ops, true)
    97  
    98  	ops[0].SemOp = 0
    99  	chZero2 := executeOps(ctx, t, set, ops, true)
   100  
   101  	ops[0].SemOp = 1
   102  	executeOps(ctx, t, set, ops, false)
   103  	if !signalled(ch1) {
   104  		t.Fatalf("ExecuteOps(ops) channel should have been signalled, ops: %+v, set: %+v", ops, set)
   105  	}
   106  
   107  	ops[0].SemOp = -2
   108  	executeOps(ctx, t, set, ops, false)
   109  	if !signalled(chZero1) {
   110  		t.Fatalf("ExecuteOps(ops) channel zero 1 should have been signalled, ops: %+v, set: %+v", ops, set)
   111  	}
   112  	if !signalled(chZero2) {
   113  		t.Fatalf("ExecuteOps(ops) channel zero 2 should have been signalled, ops: %+v, set: %+v", ops, set)
   114  	}
   115  }
   116  
   117  func TestNoWait(t *testing.T) {
   118  	ctx := contexttest.Context(t)
   119  	set := &Set{obj: &ipc.Object{ID: 123}, sems: make([]sem, 1)}
   120  	ops := []linux.Sembuf{
   121  		{SemOp: 1},
   122  	}
   123  	executeOps(ctx, t, set, ops, false)
   124  
   125  	ops[0].SemOp = -2
   126  	ops[0].SemFlg = linux.IPC_NOWAIT
   127  	if _, _, err := set.executeOps(ctx, ops, 123); err != linuxerr.ErrWouldBlock {
   128  		t.Fatalf("ExecuteOps(ops) wrong result, got: %v, expected: %v", err, linuxerr.ErrWouldBlock)
   129  	}
   130  
   131  	ops[0].SemOp = 0
   132  	ops[0].SemFlg = linux.IPC_NOWAIT
   133  	if _, _, err := set.executeOps(ctx, ops, 123); err != linuxerr.ErrWouldBlock {
   134  		t.Fatalf("ExecuteOps(ops) wrong result, got: %v, expected: %v", err, linuxerr.ErrWouldBlock)
   135  	}
   136  }
   137  
   138  func TestUnregister(t *testing.T) {
   139  	ctx := contexttest.Context(t)
   140  	r := NewRegistry(auth.NewRootUserNamespace())
   141  	set, err := r.FindOrCreate(ctx, 123, 2, linux.FileMode(0x600), true, true, true)
   142  
   143  	if err != nil {
   144  		t.Fatalf("FindOrCreate() failed, err: %v", err)
   145  	}
   146  	if got := r.FindByID(set.obj.ID); got.obj.ID != set.obj.ID {
   147  		t.Fatalf("FindById(%d) failed, got: %+v, expected: %+v", set.obj.ID, got, set)
   148  	}
   149  
   150  	ops := []linux.Sembuf{
   151  		{SemOp: -1},
   152  	}
   153  	chs := make([]chan struct{}, 0, 5)
   154  	for i := 0; i < 5; i++ {
   155  		ch := executeOps(ctx, t, set, ops, true)
   156  		chs = append(chs, ch)
   157  	}
   158  
   159  	creds := auth.CredentialsFromContext(ctx)
   160  	if err := r.Remove(set.obj.ID, creds); err != nil {
   161  		t.Fatalf("Remove(%d) failed, err: %v", set.obj.ID, err)
   162  	}
   163  	if !set.dead {
   164  		t.Fatalf("set is not dead: %+v", set)
   165  	}
   166  	if got := r.FindByID(set.obj.ID); got != nil {
   167  		t.Fatalf("FindById(%d) failed, got: %+v, expected: nil", set.obj.ID, got)
   168  	}
   169  	for i, ch := range chs {
   170  		if !signalled(ch) {
   171  			t.Fatalf("channel %d should have been signalled", i)
   172  		}
   173  	}
   174  }