github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/sentry/kernel/semaphore/semaphore_test.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package semaphore
    16  
    17  import (
    18  	"testing"
    19  
    20  	"github.com/SagerNet/gvisor/pkg/abi/linux"
    21  	"github.com/SagerNet/gvisor/pkg/context"
    22  	"github.com/SagerNet/gvisor/pkg/sentry/contexttest"
    23  	"github.com/SagerNet/gvisor/pkg/sentry/kernel/auth"
    24  	"github.com/SagerNet/gvisor/pkg/syserror"
    25  )
    26  
    27  func executeOps(ctx context.Context, t *testing.T, set *Set, ops []linux.Sembuf, block bool) chan struct{} {
    28  	ch, _, err := set.executeOps(ctx, ops, 123)
    29  	if err != nil {
    30  		t.Fatalf("ExecuteOps(ops) failed, err: %v, ops: %+v", err, ops)
    31  	}
    32  	if block {
    33  		if ch == nil {
    34  			t.Fatalf("ExecuteOps(ops) got: nil, expected: !nil, ops: %+v", ops)
    35  		}
    36  		if signalled(ch) {
    37  			t.Fatalf("ExecuteOps(ops) channel should not have been signalled, ops: %+v", ops)
    38  		}
    39  	} else {
    40  		if ch != nil {
    41  			t.Fatalf("ExecuteOps(ops) got: %v, expected: nil, ops: %+v", ch, ops)
    42  		}
    43  	}
    44  	return ch
    45  }
    46  
    47  func signalled(ch chan struct{}) bool {
    48  	select {
    49  	case <-ch:
    50  		return true
    51  	default:
    52  		return false
    53  	}
    54  }
    55  
    56  func TestBasic(t *testing.T) {
    57  	ctx := contexttest.Context(t)
    58  	set := &Set{ID: 123, sems: make([]sem, 1)}
    59  	ops := []linux.Sembuf{
    60  		{SemOp: 1},
    61  	}
    62  	executeOps(ctx, t, set, ops, false)
    63  
    64  	ops[0].SemOp = -1
    65  	executeOps(ctx, t, set, ops, false)
    66  
    67  	ops[0].SemOp = -1
    68  	ch1 := executeOps(ctx, t, set, ops, true)
    69  
    70  	ops[0].SemOp = 1
    71  	executeOps(ctx, t, set, ops, false)
    72  	if !signalled(ch1) {
    73  		t.Fatalf("ExecuteOps(ops) channel should not have been signalled, ops: %+v", ops)
    74  	}
    75  }
    76  
    77  func TestWaitForZero(t *testing.T) {
    78  	ctx := contexttest.Context(t)
    79  	set := &Set{ID: 123, sems: make([]sem, 1)}
    80  	ops := []linux.Sembuf{
    81  		{SemOp: 0},
    82  	}
    83  	executeOps(ctx, t, set, ops, false)
    84  
    85  	ops[0].SemOp = -2
    86  	ch1 := executeOps(ctx, t, set, ops, true)
    87  
    88  	ops[0].SemOp = 0
    89  	executeOps(ctx, t, set, ops, false)
    90  
    91  	ops[0].SemOp = 1
    92  	executeOps(ctx, t, set, ops, false)
    93  
    94  	ops[0].SemOp = 0
    95  	chZero1 := executeOps(ctx, t, set, ops, true)
    96  
    97  	ops[0].SemOp = 0
    98  	chZero2 := executeOps(ctx, t, set, ops, true)
    99  
   100  	ops[0].SemOp = 1
   101  	executeOps(ctx, t, set, ops, false)
   102  	if !signalled(ch1) {
   103  		t.Fatalf("ExecuteOps(ops) channel should have been signalled, ops: %+v, set: %+v", ops, set)
   104  	}
   105  
   106  	ops[0].SemOp = -2
   107  	executeOps(ctx, t, set, ops, false)
   108  	if !signalled(chZero1) {
   109  		t.Fatalf("ExecuteOps(ops) channel zero 1 should have been signalled, ops: %+v, set: %+v", ops, set)
   110  	}
   111  	if !signalled(chZero2) {
   112  		t.Fatalf("ExecuteOps(ops) channel zero 2 should have been signalled, ops: %+v, set: %+v", ops, set)
   113  	}
   114  }
   115  
   116  func TestNoWait(t *testing.T) {
   117  	ctx := contexttest.Context(t)
   118  	set := &Set{ID: 123, sems: make([]sem, 1)}
   119  	ops := []linux.Sembuf{
   120  		{SemOp: 1},
   121  	}
   122  	executeOps(ctx, t, set, ops, false)
   123  
   124  	ops[0].SemOp = -2
   125  	ops[0].SemFlg = linux.IPC_NOWAIT
   126  	if _, _, err := set.executeOps(ctx, ops, 123); err != syserror.ErrWouldBlock {
   127  		t.Fatalf("ExecuteOps(ops) wrong result, got: %v, expected: %v", err, syserror.ErrWouldBlock)
   128  	}
   129  
   130  	ops[0].SemOp = 0
   131  	ops[0].SemFlg = linux.IPC_NOWAIT
   132  	if _, _, err := set.executeOps(ctx, ops, 123); err != syserror.ErrWouldBlock {
   133  		t.Fatalf("ExecuteOps(ops) wrong result, got: %v, expected: %v", err, syserror.ErrWouldBlock)
   134  	}
   135  }
   136  
   137  func TestUnregister(t *testing.T) {
   138  	ctx := contexttest.Context(t)
   139  	r := NewRegistry(auth.NewRootUserNamespace())
   140  	set, err := r.FindOrCreate(ctx, 123, 2, linux.FileMode(0x600), true, true, true)
   141  	if err != nil {
   142  		t.Fatalf("FindOrCreate() failed, err: %v", err)
   143  	}
   144  	if got := r.FindByID(set.ID); got.ID != set.ID {
   145  		t.Fatalf("FindById(%d) failed, got: %+v, expected: %+v", set.ID, got, set)
   146  	}
   147  
   148  	ops := []linux.Sembuf{
   149  		{SemOp: -1},
   150  	}
   151  	chs := make([]chan struct{}, 0, 5)
   152  	for i := 0; i < 5; i++ {
   153  		ch := executeOps(ctx, t, set, ops, true)
   154  		chs = append(chs, ch)
   155  	}
   156  
   157  	creds := auth.CredentialsFromContext(ctx)
   158  	if err := r.RemoveID(set.ID, creds); err != nil {
   159  		t.Fatalf("RemoveID(%d) failed, err: %v", set.ID, err)
   160  	}
   161  	if !set.dead {
   162  		t.Fatalf("set is not dead: %+v", set)
   163  	}
   164  	if got := r.FindByID(set.ID); got != nil {
   165  		t.Fatalf("FindById(%d) failed, got: %+v, expected: nil", set.ID, got)
   166  	}
   167  	for i, ch := range chs {
   168  		if !signalled(ch) {
   169  			t.Fatalf("channel %d should have been signalled", i)
   170  		}
   171  	}
   172  }