github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/sentry/kernel/semaphore/semaphore_test.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package semaphore 16 17 import ( 18 "testing" 19 20 "github.com/SagerNet/gvisor/pkg/abi/linux" 21 "github.com/SagerNet/gvisor/pkg/context" 22 "github.com/SagerNet/gvisor/pkg/sentry/contexttest" 23 "github.com/SagerNet/gvisor/pkg/sentry/kernel/auth" 24 "github.com/SagerNet/gvisor/pkg/syserror" 25 ) 26 27 func executeOps(ctx context.Context, t *testing.T, set *Set, ops []linux.Sembuf, block bool) chan struct{} { 28 ch, _, err := set.executeOps(ctx, ops, 123) 29 if err != nil { 30 t.Fatalf("ExecuteOps(ops) failed, err: %v, ops: %+v", err, ops) 31 } 32 if block { 33 if ch == nil { 34 t.Fatalf("ExecuteOps(ops) got: nil, expected: !nil, ops: %+v", ops) 35 } 36 if signalled(ch) { 37 t.Fatalf("ExecuteOps(ops) channel should not have been signalled, ops: %+v", ops) 38 } 39 } else { 40 if ch != nil { 41 t.Fatalf("ExecuteOps(ops) got: %v, expected: nil, ops: %+v", ch, ops) 42 } 43 } 44 return ch 45 } 46 47 func signalled(ch chan struct{}) bool { 48 select { 49 case <-ch: 50 return true 51 default: 52 return false 53 } 54 } 55 56 func TestBasic(t *testing.T) { 57 ctx := contexttest.Context(t) 58 set := &Set{ID: 123, sems: make([]sem, 1)} 59 ops := []linux.Sembuf{ 60 {SemOp: 1}, 61 } 62 executeOps(ctx, t, set, ops, false) 63 64 ops[0].SemOp = -1 65 executeOps(ctx, t, set, ops, false) 66 67 ops[0].SemOp = -1 68 ch1 := executeOps(ctx, t, set, ops, true) 69 70 ops[0].SemOp = 1 71 executeOps(ctx, t, set, ops, false) 72 if !signalled(ch1) { 73 t.Fatalf("ExecuteOps(ops) channel should not have been signalled, ops: %+v", ops) 74 } 75 } 76 77 func TestWaitForZero(t *testing.T) { 78 ctx := contexttest.Context(t) 79 set := &Set{ID: 123, sems: make([]sem, 1)} 80 ops := []linux.Sembuf{ 81 {SemOp: 0}, 82 } 83 executeOps(ctx, t, set, ops, false) 84 85 ops[0].SemOp = -2 86 ch1 := executeOps(ctx, t, set, ops, true) 87 88 ops[0].SemOp = 0 89 executeOps(ctx, t, set, ops, false) 90 91 ops[0].SemOp = 1 92 executeOps(ctx, t, set, ops, false) 93 94 ops[0].SemOp = 0 95 chZero1 := executeOps(ctx, t, set, ops, true) 96 97 ops[0].SemOp = 0 98 chZero2 := executeOps(ctx, t, set, ops, true) 99 100 ops[0].SemOp = 1 101 executeOps(ctx, t, set, ops, false) 102 if !signalled(ch1) { 103 t.Fatalf("ExecuteOps(ops) channel should have been signalled, ops: %+v, set: %+v", ops, set) 104 } 105 106 ops[0].SemOp = -2 107 executeOps(ctx, t, set, ops, false) 108 if !signalled(chZero1) { 109 t.Fatalf("ExecuteOps(ops) channel zero 1 should have been signalled, ops: %+v, set: %+v", ops, set) 110 } 111 if !signalled(chZero2) { 112 t.Fatalf("ExecuteOps(ops) channel zero 2 should have been signalled, ops: %+v, set: %+v", ops, set) 113 } 114 } 115 116 func TestNoWait(t *testing.T) { 117 ctx := contexttest.Context(t) 118 set := &Set{ID: 123, sems: make([]sem, 1)} 119 ops := []linux.Sembuf{ 120 {SemOp: 1}, 121 } 122 executeOps(ctx, t, set, ops, false) 123 124 ops[0].SemOp = -2 125 ops[0].SemFlg = linux.IPC_NOWAIT 126 if _, _, err := set.executeOps(ctx, ops, 123); err != syserror.ErrWouldBlock { 127 t.Fatalf("ExecuteOps(ops) wrong result, got: %v, expected: %v", err, syserror.ErrWouldBlock) 128 } 129 130 ops[0].SemOp = 0 131 ops[0].SemFlg = linux.IPC_NOWAIT 132 if _, _, err := set.executeOps(ctx, ops, 123); err != syserror.ErrWouldBlock { 133 t.Fatalf("ExecuteOps(ops) wrong result, got: %v, expected: %v", err, syserror.ErrWouldBlock) 134 } 135 } 136 137 func TestUnregister(t *testing.T) { 138 ctx := contexttest.Context(t) 139 r := NewRegistry(auth.NewRootUserNamespace()) 140 set, err := r.FindOrCreate(ctx, 123, 2, linux.FileMode(0x600), true, true, true) 141 if err != nil { 142 t.Fatalf("FindOrCreate() failed, err: %v", err) 143 } 144 if got := r.FindByID(set.ID); got.ID != set.ID { 145 t.Fatalf("FindById(%d) failed, got: %+v, expected: %+v", set.ID, got, set) 146 } 147 148 ops := []linux.Sembuf{ 149 {SemOp: -1}, 150 } 151 chs := make([]chan struct{}, 0, 5) 152 for i := 0; i < 5; i++ { 153 ch := executeOps(ctx, t, set, ops, true) 154 chs = append(chs, ch) 155 } 156 157 creds := auth.CredentialsFromContext(ctx) 158 if err := r.RemoveID(set.ID, creds); err != nil { 159 t.Fatalf("RemoveID(%d) failed, err: %v", set.ID, err) 160 } 161 if !set.dead { 162 t.Fatalf("set is not dead: %+v", set) 163 } 164 if got := r.FindByID(set.ID); got != nil { 165 t.Fatalf("FindById(%d) failed, got: %+v, expected: nil", set.ID, got) 166 } 167 for i, ch := range chs { 168 if !signalled(ch) { 169 t.Fatalf("channel %d should have been signalled", i) 170 } 171 } 172 }