gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/sentry/seccheck/sinks/remote/remote_test.go (about)

     1  // Copyright 2022 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package remote
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"os"
    21  	"os/exec"
    22  	"path/filepath"
    23  	"strings"
    24  	"sync"
    25  	"testing"
    26  	"time"
    27  
    28  	"github.com/cenkalti/backoff"
    29  	"google.golang.org/protobuf/proto"
    30  	"google.golang.org/protobuf/types/known/anypb"
    31  	"gvisor.dev/gvisor/pkg/fd"
    32  	"gvisor.dev/gvisor/pkg/sentry/seccheck"
    33  	pb "gvisor.dev/gvisor/pkg/sentry/seccheck/points/points_go_proto"
    34  	"gvisor.dev/gvisor/pkg/sentry/seccheck/sinks/remote/test"
    35  	"gvisor.dev/gvisor/pkg/sentry/seccheck/sinks/remote/wire"
    36  	"gvisor.dev/gvisor/pkg/test/testutil"
    37  )
    38  
    39  func waitForFile(path string) error {
    40  	return testutil.Poll(func() error {
    41  		if _, err := os.Stat(path); err != nil {
    42  			if os.IsNotExist(err) {
    43  				return err
    44  			}
    45  			return &backoff.PermanentError{Err: err}
    46  		}
    47  		return nil
    48  	}, 5*time.Second)
    49  }
    50  
    51  type syncBuffer struct {
    52  	mu sync.Mutex
    53  	// +checklocks:mu
    54  	buf bytes.Buffer
    55  }
    56  
    57  func (s *syncBuffer) Write(p []byte) (n int, err error) {
    58  	s.mu.Lock()
    59  	defer s.mu.Unlock()
    60  	return s.buf.Write(p)
    61  }
    62  
    63  func (s *syncBuffer) String() string {
    64  	s.mu.Lock()
    65  	defer s.mu.Unlock()
    66  	return s.buf.String()
    67  }
    68  
    69  type exampleServer struct {
    70  	path string
    71  	cmd  *exec.Cmd
    72  	out  syncBuffer
    73  }
    74  
    75  func newExampleServer(quiet bool) (*exampleServer, error) {
    76  	exe, err := testutil.FindFile("examples/seccheck/server_cc")
    77  	if err != nil {
    78  		return nil, fmt.Errorf("error finding server_cc: %v", err)
    79  	}
    80  
    81  	dir, err := os.MkdirTemp(os.TempDir(), "remote")
    82  	if err != nil {
    83  		return nil, fmt.Errorf("Setup(%q): %v", dir, err)
    84  	}
    85  
    86  	server := &exampleServer{path: filepath.Join(dir, "remote.sock")}
    87  	server.cmd = exec.Command(exe, server.path)
    88  	if quiet {
    89  		server.cmd.Args = append(server.cmd.Args, "-q")
    90  	}
    91  	server.cmd.Stdout = &server.out
    92  	server.cmd.Stderr = &server.out
    93  	if err := server.cmd.Start(); err != nil {
    94  		os.RemoveAll(dir)
    95  		return nil, fmt.Errorf("error running %q: %v", exe, err)
    96  	}
    97  
    98  	if err := waitForFile(server.path); err != nil {
    99  		server.stop()
   100  		return nil, fmt.Errorf("error waiting for server file %q: %w", server.path, err)
   101  	}
   102  	return server, nil
   103  }
   104  
   105  func (s *exampleServer) stop() {
   106  	_ = s.cmd.Process.Kill()
   107  	_ = s.cmd.Wait()
   108  	_ = os.Remove(s.path)
   109  }
   110  
   111  func TestBasic(t *testing.T) {
   112  	server, err := test.NewServer()
   113  	if err != nil {
   114  		t.Fatalf("newServer(): %v", err)
   115  	}
   116  	defer server.Close()
   117  
   118  	endpoint, err := setup(server.Endpoint)
   119  	if err != nil {
   120  		t.Fatalf("setup(): %v", err)
   121  	}
   122  	endpointFD, err := fd.NewFromFile(endpoint)
   123  	if err != nil {
   124  		_ = endpoint.Close()
   125  		t.Fatalf("NewFromFile(): %v", err)
   126  	}
   127  	_ = endpoint.Close()
   128  
   129  	r, err := new(nil, endpointFD)
   130  	if err != nil {
   131  		t.Fatalf("New(): %v", err)
   132  	}
   133  
   134  	info := &pb.ExitNotifyParentInfo{ExitStatus: 123}
   135  	if err := r.ExitNotifyParent(nil, seccheck.FieldSet{}, info); err != nil {
   136  		t.Fatalf("ExitNotifyParent: %v", err)
   137  	}
   138  
   139  	server.WaitForCount(1)
   140  	pt := server.GetPoints()[0]
   141  	if want := pb.MessageType_MESSAGE_SENTRY_EXIT_NOTIFY_PARENT; pt.MsgType != want {
   142  		t.Errorf("wrong message type, want: %v, got: %v", want, pt.MsgType)
   143  	}
   144  	got := &pb.ExitNotifyParentInfo{}
   145  	if err := proto.Unmarshal(pt.Msg, got); err != nil {
   146  		t.Errorf("proto.Unmarshal(ExitNotifyParentInfo): %v", err)
   147  	}
   148  	if !proto.Equal(info, got) {
   149  		t.Errorf("Received point is different, want: %+v, got: %+v", info, got)
   150  	}
   151  	// Check that no more points were received.
   152  	if want, got := 1, server.Count(); want != got {
   153  		t.Errorf("wrong number of points, want: %d, got: %d", want, got)
   154  	}
   155  }
   156  
   157  func TestVersionUnsupported(t *testing.T) {
   158  	server, err := test.NewServer()
   159  	if err != nil {
   160  		t.Fatalf("newServer(): %v", err)
   161  	}
   162  	defer server.Close()
   163  
   164  	server.SetVersion(0)
   165  
   166  	_, err = setup(server.Endpoint)
   167  	if err == nil || !strings.Contains(err.Error(), "remote version") {
   168  		t.Fatalf("Wrong error: %v", err)
   169  	}
   170  }
   171  
   172  func TestVersionNewer(t *testing.T) {
   173  	server, err := test.NewServer()
   174  	if err != nil {
   175  		t.Fatalf("newServer(): %v", err)
   176  	}
   177  	defer server.Close()
   178  
   179  	server.SetVersion(wire.CurrentVersion + 10)
   180  
   181  	endpoint, err := setup(server.Endpoint)
   182  	if err != nil {
   183  		t.Fatalf("setup(): %v", err)
   184  	}
   185  	_ = endpoint.Close()
   186  }
   187  
   188  // Test that the example C++ server works. It's easier to test from here and
   189  // also changes that can break it will likely originate here.
   190  func TestExample(t *testing.T) {
   191  	server, err := newExampleServer(false)
   192  	if err != nil {
   193  		t.Fatalf("newExampleServer(): %v", err)
   194  	}
   195  	defer server.stop()
   196  
   197  	endpoint, err := setup(server.path)
   198  	if err != nil {
   199  		t.Fatalf("setup(): %v", err)
   200  	}
   201  	endpointFD, err := fd.NewFromFile(endpoint)
   202  	if err != nil {
   203  		_ = endpoint.Close()
   204  		t.Fatalf("NewFromFile(): %v", err)
   205  	}
   206  	_ = endpoint.Close()
   207  
   208  	r, err := new(nil, endpointFD)
   209  	if err != nil {
   210  		t.Fatalf("New(): %v", err)
   211  	}
   212  
   213  	info := pb.ExitNotifyParentInfo{ExitStatus: 123}
   214  	if err := r.ExitNotifyParent(nil, seccheck.FieldSet{}, &info); err != nil {
   215  		t.Fatalf("ExitNotifyParent: %v", err)
   216  	}
   217  	check := func() error {
   218  		gotRaw := server.out.String()
   219  		// Collapse whitespace.
   220  		got := strings.Join(strings.Fields(gotRaw), " ")
   221  		if !strings.Contains(got, "ExitNotifyParentInfo => exit_status: 123") {
   222  			return fmt.Errorf("ExitNotifyParentInfo point didn't get to the server, out: %q, raw: %q", got, gotRaw)
   223  		}
   224  		return nil
   225  	}
   226  	if err := testutil.Poll(check, time.Second); err != nil {
   227  		t.Errorf(err.Error())
   228  	}
   229  }
   230  
   231  func TestConfig(t *testing.T) {
   232  	for _, tc := range []struct {
   233  		name   string
   234  		config map[string]any
   235  		want   *remote
   236  		err    string
   237  	}{
   238  		{
   239  			name:   "default",
   240  			config: map[string]any{},
   241  			want: &remote{
   242  				retries:        0,
   243  				initialBackoff: 25 * time.Microsecond,
   244  				maxBackoff:     10 * time.Millisecond,
   245  			},
   246  		},
   247  		{
   248  			name: "all",
   249  			config: map[string]any{
   250  				"retries":     float64(10),
   251  				"backoff":     "1s",
   252  				"backoff_max": "10s",
   253  			},
   254  			want: &remote{
   255  				retries:        10,
   256  				initialBackoff: time.Second,
   257  				maxBackoff:     10 * time.Second,
   258  			},
   259  		},
   260  		{
   261  			name: "bad-retries",
   262  			config: map[string]any{
   263  				"retries": "10",
   264  			},
   265  			err: "retries",
   266  		},
   267  		{
   268  			name: "bad-backoff",
   269  			config: map[string]any{
   270  				"backoff": "wrong",
   271  			},
   272  			err: "invalid duration",
   273  		},
   274  		{
   275  			name: "bad-backoff-max",
   276  			config: map[string]any{
   277  				"backoff_max": 10,
   278  			},
   279  			err: "is not an string",
   280  		},
   281  		{
   282  			name: "bad-invalid-backoffs",
   283  			config: map[string]any{
   284  				"retries":     float64(10),
   285  				"backoff":     "10s",
   286  				"backoff_max": "1s",
   287  			},
   288  			err: "cannot be larger than max",
   289  		},
   290  	} {
   291  		t.Run(tc.name, func(t *testing.T) {
   292  			var endpoint fd.FD
   293  			sink, err := new(tc.config, &endpoint)
   294  			if len(tc.err) == 0 {
   295  				if err != nil {
   296  					t.Fatalf("new(%q): %v", tc.config, err)
   297  				}
   298  				got := sink.(*remote)
   299  				got.endpoint = nil
   300  				if *got != *tc.want {
   301  					t.Errorf("wrong remote: want: %+v, got: %+v", tc.want, got)
   302  				}
   303  			} else if err == nil || !strings.Contains(err.Error(), tc.err) {
   304  				t.Errorf("wrong error: want: %v, got: %v", tc.err, err)
   305  			}
   306  		})
   307  	}
   308  }
   309  
   310  func BenchmarkSmall(t *testing.B) {
   311  	// Run server in a separate process just to isolate it as much as possible.
   312  	server, err := newExampleServer(false)
   313  	if err != nil {
   314  		t.Fatalf("newExampleServer(): %v", err)
   315  	}
   316  	defer server.stop()
   317  
   318  	endpoint, err := setup(server.path)
   319  	if err != nil {
   320  		t.Fatalf("setup(): %v", err)
   321  	}
   322  	endpointFD, err := fd.NewFromFile(endpoint)
   323  	if err != nil {
   324  		_ = endpoint.Close()
   325  		t.Fatalf("NewFromFile(): %v", err)
   326  	}
   327  	_ = endpoint.Close()
   328  
   329  	r, err := new(nil, endpointFD)
   330  	if err != nil {
   331  		t.Fatalf("New(): %v", err)
   332  	}
   333  
   334  	t.ResetTimer()
   335  	t.RunParallel(func(sub *testing.PB) {
   336  		for sub.Next() {
   337  			info := pb.ExitNotifyParentInfo{ExitStatus: 123}
   338  			if err := r.ExitNotifyParent(nil, seccheck.FieldSet{}, &info); err != nil {
   339  				t.Fatalf("ExitNotifyParent: %v", err)
   340  			}
   341  		}
   342  	})
   343  }
   344  
   345  func BenchmarkProtoAny(t *testing.B) {
   346  	info := &pb.ExitNotifyParentInfo{ExitStatus: 123}
   347  
   348  	t.ResetTimer()
   349  	t.RunParallel(func(sub *testing.PB) {
   350  		for sub.Next() {
   351  			any, err := anypb.New(info)
   352  			if err != nil {
   353  				t.Fatal(err)
   354  			}
   355  			if _, err := proto.Marshal(any); err != nil {
   356  				t.Fatal(err)
   357  			}
   358  		}
   359  	})
   360  }
   361  
   362  func BenchmarkProtoEnum(t *testing.B) {
   363  	info := &pb.ExitNotifyParentInfo{ExitStatus: 123}
   364  
   365  	t.ResetTimer()
   366  	t.RunParallel(func(sub *testing.PB) {
   367  		for sub.Next() {
   368  			if _, err := proto.Marshal(info); err != nil {
   369  				t.Fatal(err)
   370  			}
   371  		}
   372  	})
   373  }