gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/sentry/seccheck/sinks/remote/test/server.go (about)

     1  // Copyright 2022 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package test provides functionality used to test the remote sink.
    16  package test
    17  
    18  import (
    19  	"io/ioutil"
    20  	"os"
    21  	"path/filepath"
    22  
    23  	pb "gvisor.dev/gvisor/pkg/sentry/seccheck/points/points_go_proto"
    24  	"gvisor.dev/gvisor/pkg/sentry/seccheck/sinks/remote/server"
    25  	"gvisor.dev/gvisor/pkg/sentry/seccheck/sinks/remote/wire"
    26  	"gvisor.dev/gvisor/pkg/sync"
    27  )
    28  
// Server is the counterpart to the sinks.Remote. It receives connections from
// the remote sink and stores all points that it receives.
type Server struct {
	server.CommonServer

	// cond is broadcast whenever a new point is appended to points. Its
	// locker, cond.L, is also the lock that protects points.
	cond sync.Cond

	// points holds the messages received since creation or the last call to
	// Reset, in the order they were appended.
	//
	// +checklocks:cond.L
	points []Message

	// mu protects version.
	mu sync.Mutex

	// version is the wire version reported by message handlers. It starts as
	// wire.CurrentVersion and can be overridden with SetVersion.
	//
	// +checklocks:mu
	version uint32
}
    44  
// Message corresponds to a single message sent from sinks.Remote.
type Message struct {
	// MsgType identifies the type of the payload held in Msg.
	MsgType pb.MessageType
	// Msg is the raw message payload; MsgType determines how to decode it.
	Msg []byte
}
    52  
    53  // NewServer creates a new server that listens to a UDS that it creates under
    54  // os.TempDir.
    55  func NewServer() (*Server, error) {
    56  	dir, err := ioutil.TempDir(os.TempDir(), "remote")
    57  	if err != nil {
    58  		return nil, err
    59  	}
    60  	s := &Server{
    61  		version: wire.CurrentVersion,
    62  		cond:    sync.Cond{L: &sync.Mutex{}},
    63  	}
    64  	s.CommonServer.Init(filepath.Join(dir, "remote.sock"), s)
    65  	if err := s.CommonServer.Start(); err != nil {
    66  		_ = os.RemoveAll(dir)
    67  		return nil, err
    68  	}
    69  	return s, nil
    70  }
    71  
    72  // NewClient returns a new MessageHandler to process messages.
    73  func (s *Server) NewClient() (server.MessageHandler, error) {
    74  	return &msgHandler{owner: s}, nil
    75  }
    76  
    77  // Count return the number of points it has received.
    78  func (s *Server) Count() int {
    79  	s.cond.L.Lock()
    80  	defer s.cond.L.Unlock()
    81  	return len(s.points)
    82  }
    83  
    84  // Reset throws aways all points received so far and returns the number of
    85  // points discarded.
    86  func (s *Server) Reset() int {
    87  	s.cond.L.Lock()
    88  	defer s.cond.L.Unlock()
    89  	count := len(s.points)
    90  	s.points = nil
    91  	return count
    92  }
    93  
    94  // GetPoints returns all points that it has received.
    95  func (s *Server) GetPoints() []Message {
    96  	s.cond.L.Lock()
    97  	defer s.cond.L.Unlock()
    98  	cpy := make([]Message, len(s.points))
    99  	copy(cpy, s.points)
   100  	return cpy
   101  }
   102  
   103  // WaitForCount waits for the number of points to reach the desired number.
   104  func (s *Server) WaitForCount(count int) {
   105  	s.cond.L.Lock()
   106  	defer s.cond.L.Unlock()
   107  	for len(s.points) < count {
   108  		s.cond.Wait()
   109  	}
   110  	return
   111  }
   112  
   113  // SetVersion sets the version to be used in handshake.
   114  func (s *Server) SetVersion(newVersion uint32) {
   115  	s.mu.Lock()
   116  	defer s.mu.Unlock()
   117  	s.version = newVersion
   118  }
   119  
// msgHandler is the server.MessageHandler returned by Server.NewClient. It
// records incoming messages on the Server that created it.
type msgHandler struct {
	// owner is the Server that stores this handler's messages and supplies
	// the wire version it reports.
	owner *Server
}
   123  
   124  // Message stores the message type and payload.
   125  func (m *msgHandler) Message(_ []byte, hdr wire.Header, payload []byte) error {
   126  	msg := Message{
   127  		MsgType: pb.MessageType(hdr.MessageType),
   128  		Msg:     make([]byte, len(payload)),
   129  	}
   130  	copy(msg.Msg, payload)
   131  
   132  	m.owner.cond.L.Lock()
   133  	defer m.owner.cond.L.Unlock()
   134  	m.owner.points = append(m.owner.points, msg)
   135  	m.owner.cond.Broadcast()
   136  	return nil
   137  }
   138  
   139  // Version returns the wire version supported or overridden by SetVersion.
   140  func (m *msgHandler) Version() uint32 {
   141  	m.owner.mu.Lock()
   142  	defer m.owner.mu.Unlock()
   143  	return m.owner.version
   144  }
   145  
   146  // Close implements server.MessageHandler.
   147  func (m *msgHandler) Close() {}