gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/sentry/seccheck/seccheck_test.go (about)

     1  // Copyright 2021 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package seccheck
    16  
    17  import (
    18  	"errors"
    19  	"os"
    20  	"testing"
    21  
    22  	"gvisor.dev/gvisor/pkg/context"
    23  	"gvisor.dev/gvisor/pkg/fd"
    24  	pb "gvisor.dev/gvisor/pkg/sentry/seccheck/points/points_go_proto"
    25  )
    26  
    27  type testSink struct {
    28  	SinkDefaults
    29  
    30  	onClone func(ctx context.Context, fields FieldSet, info *pb.CloneInfo) error
    31  }
    32  
    33  var _ Sink = (*testSink)(nil)
    34  
    35  func newTestSink(_ map[string]any, _ *fd.FD) (Sink, error) {
    36  	return &testSink{}, nil
    37  }
    38  
    39  // Name implements Sink.Name.
    40  func (c *testSink) Name() string {
    41  	return "test-sink"
    42  }
    43  
    44  // Clone implements Sink.Clone.
    45  func (c *testSink) Clone(ctx context.Context, fields FieldSet, info *pb.CloneInfo) error {
    46  	if c.onClone == nil {
    47  		return nil
    48  	}
    49  	return c.onClone(ctx, fields, info)
    50  }
    51  
    52  func TestNoSink(t *testing.T) {
    53  	var s State
    54  	if s.Enabled(PointClone) {
    55  		t.Errorf("Enabled(PointClone): got true, wanted false")
    56  	}
    57  }
    58  
    59  func TestSinkNotRegisteredForPoint(t *testing.T) {
    60  	var s State
    61  	s.AppendSink(&testSink{}, nil)
    62  	if s.Enabled(PointClone) {
    63  		t.Errorf("Enabled(PointClone): got true, wanted false")
    64  	}
    65  }
    66  
    67  func TestSinkRegistered(t *testing.T) {
    68  	var s State
    69  	sinkCalled := false
    70  	sink := &testSink{
    71  		onClone: func(context.Context, FieldSet, *pb.CloneInfo) error {
    72  			sinkCalled = true
    73  			return nil
    74  		},
    75  	}
    76  	req := []PointReq{
    77  		{
    78  			Pt:     PointClone,
    79  			Fields: FieldSet{Context: MakeFieldMask(FieldCtxtCredentials)},
    80  		},
    81  	}
    82  	s.AppendSink(sink, req)
    83  
    84  	if !s.Enabled(PointClone) {
    85  		t.Errorf("Enabled(PointClone): got false, wanted true")
    86  	}
    87  	fields := s.GetFieldSet(PointClone)
    88  	if !fields.Context.Contains(FieldCtxtCredentials) {
    89  		t.Errorf("fields.Context.Contains(PointContextCredentials): got false, wanted true")
    90  	}
    91  	if err := s.SentToSinks(func(c Sink) error {
    92  		return c.Clone(context.Background(), fields, &pb.CloneInfo{})
    93  	}); err != nil {
    94  		t.Errorf("Clone(): got %v, wanted nil", err)
    95  	}
    96  	if !sinkCalled {
    97  		t.Errorf("Clone() did not call Sink.Clone()")
    98  	}
    99  }
   100  
   101  func TestMultipleSinksRegistered(t *testing.T) {
   102  	var s State
   103  	sinkCalled := [2]bool{}
   104  	sink := &testSink{
   105  		onClone: func(context.Context, FieldSet, *pb.CloneInfo) error {
   106  			sinkCalled[0] = true
   107  			return nil
   108  		},
   109  	}
   110  	reqs := []PointReq{
   111  		{Pt: PointClone},
   112  	}
   113  	s.AppendSink(sink, reqs)
   114  
   115  	sink = &testSink{onClone: func(context.Context, FieldSet, *pb.CloneInfo) error {
   116  		sinkCalled[1] = true
   117  		return nil
   118  	}}
   119  	reqs = []PointReq{
   120  		{Pt: PointClone},
   121  	}
   122  	s.AppendSink(sink, reqs)
   123  
   124  	if !s.Enabled(PointClone) {
   125  		t.Errorf("Enabled(PointClone): got false, wanted true")
   126  	}
   127  	// CloneReq() should return the union of requested fields from all calls to
   128  	// AppendSink.
   129  	fields := s.GetFieldSet(PointClone)
   130  	if err := s.SentToSinks(func(c Sink) error {
   131  		return c.Clone(context.Background(), fields, &pb.CloneInfo{})
   132  	}); err != nil {
   133  		t.Errorf("Clone(): got %v, wanted nil", err)
   134  	}
   135  	for i := range sinkCalled {
   136  		if !sinkCalled[i] {
   137  			t.Errorf("Clone() did not call Sink.Clone() index %d", i)
   138  		}
   139  	}
   140  }
   141  
   142  func TestCheckpointReturnsFirstSinkError(t *testing.T) {
   143  	errFirstSink := errors.New("first Sink error")
   144  	errSecondSink := errors.New("second Sink error")
   145  
   146  	var s State
   147  	sinkCalled := [2]bool{}
   148  	sink := &testSink{
   149  		onClone: func(context.Context, FieldSet, *pb.CloneInfo) error {
   150  			sinkCalled[0] = true
   151  			return errFirstSink
   152  		},
   153  	}
   154  	reqs := []PointReq{
   155  		{Pt: PointClone},
   156  	}
   157  
   158  	s.AppendSink(sink, reqs)
   159  
   160  	sink = &testSink{
   161  		onClone: func(context.Context, FieldSet, *pb.CloneInfo) error {
   162  			sinkCalled[1] = true
   163  			return errSecondSink
   164  		},
   165  	}
   166  	s.AppendSink(sink, reqs)
   167  
   168  	if !s.Enabled(PointClone) {
   169  		t.Errorf("Enabled(PointClone): got false, wanted true")
   170  	}
   171  	if err := s.SentToSinks(func(c Sink) error {
   172  		return c.Clone(context.Background(), FieldSet{}, &pb.CloneInfo{})
   173  	}); err != errFirstSink {
   174  		t.Errorf("Clone(): got %v, wanted %v", err, errFirstSink)
   175  	}
   176  	if !sinkCalled[0] {
   177  		t.Errorf("Clone() did not call first Sink")
   178  	}
   179  	if sinkCalled[1] {
   180  		t.Errorf("Clone() called second Sink")
   181  	}
   182  }
   183  
   184  func TestFieldMaskEmpty(t *testing.T) {
   185  	fd := FieldMask{}
   186  	if !fd.Empty() {
   187  		t.Errorf("new FieldMask must be empty: %+v", fd)
   188  	}
   189  }
   190  
   191  func TestFieldMaskMake(t *testing.T) {
   192  	zero := Field(0)
   193  	one := Field(1)
   194  	two := Field(2)
   195  	fd := MakeFieldMask(zero, two)
   196  	if fd.Empty() {
   197  		t.Errorf("FieldMask must not be empty: %+v", fd)
   198  	}
   199  	if want := zero; !fd.Contains(want) {
   200  		t.Errorf("FieldMask must contain %v: %+v", want, fd)
   201  	}
   202  	if want := two; !fd.Contains(want) {
   203  		t.Errorf("FieldMask must contain %v: %+v", want, fd)
   204  	}
   205  	if want := one; fd.Contains(want) {
   206  		t.Errorf("FieldMask must not contain %v: %+v", want, fd)
   207  	}
   208  }
   209  
   210  func TestFieldMask(t *testing.T) {
   211  	zero := Field(0)
   212  	one := Field(1)
   213  	two := Field(2)
   214  	fd := FieldMask{}
   215  
   216  	fd.Add(zero)
   217  	if fd.Empty() {
   218  		t.Errorf("FieldMask must not be empty: %+v", fd)
   219  	}
   220  	if want := zero; !fd.Contains(want) {
   221  		t.Errorf("FieldMask must contain %v: %+v", want, fd)
   222  	}
   223  	if want := one; fd.Contains(want) {
   224  		t.Errorf("FieldMask must not contain %v: %+v", want, fd)
   225  	}
   226  	if want := two; fd.Contains(want) {
   227  		t.Errorf("FieldMask must not contain %v: %+v", want, fd)
   228  	}
   229  
   230  	fd.Add(two)
   231  	if fd.Empty() {
   232  		t.Errorf("FieldMask must not be empty: %+v", fd)
   233  	}
   234  	if want := zero; !fd.Contains(want) {
   235  		t.Errorf("FieldMask must contain %v: %+v", want, fd)
   236  	}
   237  	if want := one; fd.Contains(want) {
   238  		t.Errorf("FieldMask must not contain %v: %+v", want, fd)
   239  	}
   240  	if want := two; !fd.Contains(want) {
   241  		t.Errorf("FieldMask must contain %v: %+v", want, fd)
   242  	}
   243  
   244  	fd.Remove(zero)
   245  	if fd.Empty() {
   246  		t.Errorf("FieldMask must not be empty: %+v", fd)
   247  	}
   248  	if want := zero; fd.Contains(want) {
   249  		t.Errorf("FieldMask must not contain %v: %+v", want, fd)
   250  	}
   251  	if want := one; fd.Contains(want) {
   252  		t.Errorf("FieldMask must not contain %v: %+v", want, fd)
   253  	}
   254  	if want := two; !fd.Contains(want) {
   255  		t.Errorf("FieldMask must contain %v: %+v", want, fd)
   256  	}
   257  
   258  	fd.Remove(two)
   259  	if !fd.Empty() {
   260  		t.Errorf("FieldMask must be empty: %+v", fd)
   261  	}
   262  	if want := zero; fd.Contains(want) {
   263  		t.Errorf("FieldMask must not contain %v: %+v", want, fd)
   264  	}
   265  	if want := one; fd.Contains(want) {
   266  		t.Errorf("FieldMask must not contain %v: %+v", want, fd)
   267  	}
   268  	if want := two; fd.Contains(want) {
   269  		t.Errorf("FieldMask must not contain %v: %+v", want, fd)
   270  	}
   271  }
   272  
   273  func TestMain(m *testing.M) {
   274  
   275  	RegisterSink(SinkDesc{
   276  		Name: "test-sink",
   277  		New:  newTestSink,
   278  	})
   279  	Initialize()
   280  	os.Exit(m.Run())
   281  }