github.com/nicocha30/gvisor-ligolo@v0.0.0-20230726075806-989fa2c0a413/pkg/sentry/seccheck/seccheck.go (about)

     1  // Copyright 2021 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package seccheck defines a structure for dynamically-configured security
    16  // checks in the sentry.
    17  package seccheck
    18  
    19  import (
    20  	"google.golang.org/protobuf/proto"
    21  	"github.com/nicocha30/gvisor-ligolo/pkg/atomicbitops"
    22  	"github.com/nicocha30/gvisor-ligolo/pkg/context"
    23  	pb "github.com/nicocha30/gvisor-ligolo/pkg/sentry/seccheck/points/points_go_proto"
    24  	"github.com/nicocha30/gvisor-ligolo/pkg/sync"
    25  )
    26  
    27  // A Point represents a checkpoint, a point at which a security check occurs.
    28  type Point uint
    29  
    30  // PointX represents the checkpoint X.
    31  const (
    32  	totalPoints            = int(pointLengthBeforeSyscalls) + syscallPoints
    33  	numPointsPerUint32     = 32
    34  	numPointBitmaskUint32s = (totalPoints-1)/numPointsPerUint32 + 1
    35  )
    36  
    37  // FieldSet contains all optional fields to be collected by a given Point.
    38  type FieldSet struct {
    39  	// Local indicates which optional fields from the Point that needs to be
    40  	// collected, e.g. resolving path from an FD, or collecting a large field.
    41  	Local FieldMask
    42  
    43  	// Context indicates which optional fields from the Context that needs to be
    44  	// collected, e.g. PID, credentials, current time.
    45  	Context FieldMask
    46  }
    47  
    48  // Field represents the index of a single optional field to be collect for a
    49  // Point.
    50  type Field uint
    51  
    52  // FieldMask is a bitmask with a single bit representing an optional field to be
    53  // collected. The meaning of each bit varies per point. The mask is currently
    54  // limited to 64 fields. If more are needed, FieldMask can be expanded to
    55  // support additional fields.
    56  type FieldMask struct {
    57  	mask uint64
    58  }
    59  
    60  // MakeFieldMask creates a FieldMask from a set of Fields.
    61  func MakeFieldMask(fields ...Field) FieldMask {
    62  	var m FieldMask
    63  	for _, field := range fields {
    64  		m.Add(field)
    65  	}
    66  	return m
    67  }
    68  
    69  // Contains returns true if the mask contains the Field.
    70  func (fm *FieldMask) Contains(field Field) bool {
    71  	return fm.mask&(1<<field) != 0
    72  }
    73  
    74  // Add adds a Field to the mask.
    75  func (fm *FieldMask) Add(field Field) {
    76  	fm.mask |= 1 << field
    77  }
    78  
    79  // Remove removes a Field from the mask.
    80  func (fm *FieldMask) Remove(field Field) {
    81  	fm.mask &^= 1 << field
    82  }
    83  
    84  // Empty returns true if no bits are set.
    85  func (fm *FieldMask) Empty() bool {
    86  	return fm.mask == 0
    87  }
    88  
    89  // A Sink performs security checks at checkpoints.
    90  //
    91  // Each Sink method X is called at checkpoint X; if the method may return a
    92  // non-nil error and does so, it causes the checked operation to fail
    93  // immediately (without calling subsequent Sinks) and return the error. The
    94  // info argument contains information relevant to the check. The mask argument
    95  // indicates what fields in info are valid; the mask should usually be a
    96  // superset of fields requested by the Sink's corresponding PointReq, but
    97  // may be missing requested fields in some cases (e.g. if the Sink is
    98  // registered concurrently with invocations of checkpoints).
    99  type Sink interface {
   100  	// Name return the sink name.
   101  	Name() string
   102  	// Status returns the sink runtime status.
   103  	Status() SinkStatus
   104  	// Stop requests the sink to stop.
   105  	Stop()
   106  
   107  	Clone(ctx context.Context, fields FieldSet, info *pb.CloneInfo) error
   108  	Execve(ctx context.Context, fields FieldSet, info *pb.ExecveInfo) error
   109  	ExitNotifyParent(ctx context.Context, fields FieldSet, info *pb.ExitNotifyParentInfo) error
   110  	TaskExit(context.Context, FieldSet, *pb.TaskExit) error
   111  
   112  	ContainerStart(context.Context, FieldSet, *pb.Start) error
   113  
   114  	Syscall(context.Context, FieldSet, *pb.ContextData, pb.MessageType, proto.Message) error
   115  	RawSyscall(context.Context, FieldSet, *pb.Syscall) error
   116  }
   117  
   118  // SinkStatus represents stats about each Sink instance.
   119  type SinkStatus struct {
   120  	// DroppedCount is the number of trace points dropped.
   121  	DroppedCount uint64
   122  }
   123  
   124  // SinkDefaults may be embedded by implementations of Sink to obtain
   125  // no-op implementations of Sink methods that may be explicitly overridden.
   126  type SinkDefaults struct{}
   127  
   128  // Add functions missing in SinkDefaults to make it possible to check for the
   129  // implementation below to catch missing functions more easily.
   130  type sinkDefaultsImpl struct {
   131  	SinkDefaults
   132  }
   133  
   134  // Name implements Sink.Name.
   135  func (sinkDefaultsImpl) Name() string { return "" }
   136  
   137  var _ Sink = (*sinkDefaultsImpl)(nil)
   138  
   139  // Status implements Sink.Status.
   140  func (SinkDefaults) Status() SinkStatus {
   141  	return SinkStatus{}
   142  }
   143  
   144  // Stop implements Sink.Stop.
   145  func (SinkDefaults) Stop() {}
   146  
   147  // Clone implements Sink.Clone.
   148  func (SinkDefaults) Clone(context.Context, FieldSet, *pb.CloneInfo) error {
   149  	return nil
   150  }
   151  
   152  // Execve implements Sink.Execve.
   153  func (SinkDefaults) Execve(context.Context, FieldSet, *pb.ExecveInfo) error {
   154  	return nil
   155  }
   156  
   157  // ExitNotifyParent implements Sink.ExitNotifyParent.
   158  func (SinkDefaults) ExitNotifyParent(context.Context, FieldSet, *pb.ExitNotifyParentInfo) error {
   159  	return nil
   160  }
   161  
   162  // ContainerStart implements Sink.ContainerStart.
   163  func (SinkDefaults) ContainerStart(context.Context, FieldSet, *pb.Start) error {
   164  	return nil
   165  }
   166  
   167  // TaskExit implements Sink.TaskExit.
   168  func (SinkDefaults) TaskExit(context.Context, FieldSet, *pb.TaskExit) error {
   169  	return nil
   170  }
   171  
   172  // RawSyscall implements Sink.RawSyscall.
   173  func (SinkDefaults) RawSyscall(context.Context, FieldSet, *pb.Syscall) error {
   174  	return nil
   175  }
   176  
   177  // Syscall implements Sink.Syscall.
   178  func (SinkDefaults) Syscall(context.Context, FieldSet, *pb.ContextData, pb.MessageType, proto.Message) error {
   179  	return nil
   180  }
   181  
   182  // PointReq indicates what Point a corresponding Sink runs at, and what
   183  // information it requires at those Points.
   184  type PointReq struct {
   185  	Pt     Point
   186  	Fields FieldSet
   187  }
   188  
   189  // Global is the method receiver of all seccheck functions.
   190  var Global State
   191  
   192  // State is the type of global, and is separated out for testing.
   193  type State struct {
   194  	// registrationMu serializes all changes to the set of registered Sinks
   195  	// for all checkpoints.
   196  	registrationMu sync.RWMutex
   197  
   198  	// enabledPoints is a bitmask of checkpoints for which at least one Sink
   199  	// is registered.
   200  	//
   201  	// Mutation of enabledPoints is serialized by registrationMu.
   202  	enabledPoints [numPointBitmaskUint32s]atomicbitops.Uint32
   203  
   204  	// registrationSeq supports store-free atomic reads of registeredSinks.
   205  	registrationSeq sync.SeqCount
   206  
   207  	// sinks is the set of all registered Sinks in order of execution.
   208  	//
   209  	// sinks is accessed using instantiations of SeqAtomic functions.
   210  	// Mutation of sinks is serialized by registrationMu.
   211  	sinks []Sink
   212  
   213  	// syscallFlagListeners is the set of registered SyscallFlagListeners.
   214  	//
   215  	// They are notified when the enablement of a syscall point changes.
   216  	// Mutation of syscallFlagListeners is serialized by registrationMu.
   217  	syscallFlagListeners []SyscallFlagListener
   218  
   219  	pointFields map[Point]FieldSet
   220  }
   221  
   222  // AppendSink registers the given Sink to execute at checkpoints. The
   223  // Sink will execute after all previously-registered sinks, and only if
   224  // those Sinks return a nil error.
   225  func (s *State) AppendSink(c Sink, reqs []PointReq) {
   226  	s.registrationMu.Lock()
   227  	defer s.registrationMu.Unlock()
   228  
   229  	s.appendSinkLocked(c)
   230  	if s.pointFields == nil {
   231  		s.pointFields = make(map[Point]FieldSet)
   232  	}
   233  	updateSyscalls := false
   234  	for _, req := range reqs {
   235  		word, bit := req.Pt/numPointsPerUint32, req.Pt%numPointsPerUint32
   236  		s.enabledPoints[word].Store(s.enabledPoints[word].RacyLoad() | (uint32(1) << bit))
   237  		if req.Pt >= pointLengthBeforeSyscalls {
   238  			updateSyscalls = true
   239  		}
   240  		s.pointFields[req.Pt] = req.Fields
   241  	}
   242  	if updateSyscalls {
   243  		for _, listener := range s.syscallFlagListeners {
   244  			listener.UpdateSecCheck(s)
   245  		}
   246  	}
   247  }
   248  
   249  func (s *State) clearSink() {
   250  	s.registrationMu.Lock()
   251  	defer s.registrationMu.Unlock()
   252  
   253  	updateSyscalls := false
   254  	for i := range s.enabledPoints {
   255  		s.enabledPoints[i].Store(0)
   256  		// We use i+1 here because we want to check the last bit that may have been changed within i.
   257  		if Point((i+1)*numPointsPerUint32) >= pointLengthBeforeSyscalls {
   258  			updateSyscalls = true
   259  		}
   260  	}
   261  	if updateSyscalls {
   262  		for _, listener := range s.syscallFlagListeners {
   263  			listener.UpdateSecCheck(s)
   264  		}
   265  	}
   266  	s.pointFields = nil
   267  
   268  	oldSinks := s.getSinks()
   269  	s.registrationSeq.BeginWrite()
   270  	s.sinks = nil
   271  	s.registrationSeq.EndWrite()
   272  	for _, sink := range oldSinks {
   273  		sink.Stop()
   274  	}
   275  }
   276  
   277  // AddSyscallFlagListener adds a listener to the State.
   278  //
   279  // The listener will be notified whenever syscall point enablement changes.
   280  func (s *State) AddSyscallFlagListener(listener SyscallFlagListener) {
   281  	s.registrationMu.Lock()
   282  	defer s.registrationMu.Unlock()
   283  	s.syscallFlagListeners = append(s.syscallFlagListeners, listener)
   284  }
   285  
   286  // Enabled returns true if any Sink is registered for the given checkpoint.
   287  func (s *State) Enabled(p Point) bool {
   288  	word, bit := p/numPointsPerUint32, p%numPointsPerUint32
   289  	if int(word) >= len(s.enabledPoints) {
   290  		return false
   291  	}
   292  	return s.enabledPoints[word].Load()&(uint32(1)<<bit) != 0
   293  }
   294  
   295  func (s *State) getSinks() []Sink {
   296  	return SeqAtomicLoadSinkSlice(&s.registrationSeq, &s.sinks)
   297  }
   298  
   299  // Preconditions: s.registrationMu must be locked.
   300  func (s *State) appendSinkLocked(c Sink) {
   301  	s.registrationSeq.BeginWrite()
   302  	s.sinks = append(s.sinks, c)
   303  	s.registrationSeq.EndWrite()
   304  }
   305  
   306  // SentToSinks iterates over all sinks and calls fn for each one of them.
   307  func (s *State) SentToSinks(fn func(c Sink) error) error {
   308  	for _, c := range s.getSinks() {
   309  		if err := fn(c); err != nil {
   310  			return err
   311  		}
   312  	}
   313  	return nil
   314  }
   315  
   316  // GetFieldSet returns the FieldSet that has been configured for a given Point.
   317  func (s *State) GetFieldSet(p Point) FieldSet {
   318  	s.registrationMu.RLock()
   319  	defer s.registrationMu.RUnlock()
   320  	return s.pointFields[p]
   321  }