github.com/metacubex/gvisor@v0.0.0-20240320004321-933faba989ec/pkg/seccomp/seccomp_rules.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package seccomp
    16  
    17  import (
    18  	"fmt"
    19  	"sort"
    20  	"strings"
    21  
    22  	"golang.org/x/sys/unix"
    23  	"github.com/metacubex/gvisor/pkg/bpf"
    24  )
    25  
    26  // The offsets are based on the following struct in include/linux/seccomp.h.
    27  //
    28  //	struct seccomp_data {
    29  //		int nr;
    30  //		__u32 arch;
    31  //		__u64 instruction_pointer;
    32  //		__u64 args[6];
    33  //	};
    34  const (
    35  	seccompDataOffsetNR     = 0
    36  	seccompDataOffsetArch   = 4
    37  	seccompDataOffsetIPLow  = 8
    38  	seccompDataOffsetIPHigh = 12
    39  	seccompDataOffsetArgs   = 16
    40  )
    41  
    42  func seccompDataOffsetArgLow(i int) uint32 {
    43  	return uint32(seccompDataOffsetArgs + i*8)
    44  }
    45  
    46  func seccompDataOffsetArgHigh(i int) uint32 {
    47  	return seccompDataOffsetArgLow(i) + 4
    48  }
    49  
    50  // ValueMatcher verifies a numerical value, typically a syscall argument
    51  // or RIP value.
    52  type ValueMatcher interface {
    53  	// String returns a human-readable representation of the match rule.
    54  	// If the returned string contains "VAL", it will be replaced with
    55  	// the symbolic name of the value being matched against.
    56  	String() string
    57  
    58  	// Repr returns a string that will be used for asserting equality between
    59  	// two `ValueMatcher` instances. It must therefore be unique to the
    60  	// `ValueMatcher` implementation and to its parameters.
    61  	// It must not contain the character ";".
    62  	Repr() string
    63  
    64  	// Render should add rules to the given program that verify the value
    65  	// loadable from `value` matches this rule or not.
    66  	// The rules should indicate this by either jumping to `labelSet.Matched()`
    67  	// or `labelSet.Mismatched()`. They may not fall through.
    68  	Render(program *syscallProgram, labelSet *labelSet, value matchedValue)
    69  
    70  	// InterestingValues returns a list of values that may be interesting to
    71  	// test this `ValueMatcher` against.
    72  	InterestingValues() []uint64
    73  }
    74  
    75  // halfValueMatcher verifies a 32-bit value.
    76  type halfValueMatcher interface {
    77  	// String returns a human-friendly representation of the check being done
    78  	// against the 32-bit value.
    79  	// The string "x.(high|low) {{halfValueMatcher.String()}}" should read well,
    80  	// e.g. "x.low == 0xffff".
    81  	String() string
    82  
    83  	// Repr returns a string that will be used for asserting equality between
    84  	// two `halfValueMatcher` instances. It must therefore be unique to the
    85  	// `halfValueMatcher` implementation and to its parameters.
    86  	// It must not contain the character ";".
    87  	Repr() string
    88  
    89  	// HalfRender should add rules to the given program that verify the value
    90  	// loaded into the "A" register matches this 32-bit value or not.
    91  	// The rules should indicate this by either jumping to `labelSet.Matched()`
    92  	// or `labelSet.Mismatched()`. They may not fall through.
    93  	HalfRender(program *syscallProgram, labelSet *labelSet)
    94  
    95  	// InterestingValues returns a list of values that may be interesting to
    96  	// test this `halfValueMatcher` against.
    97  	InterestingValues() []uint32
    98  }
    99  
   100  // halfAnyValue implements `halfValueMatcher` and matches any value.
   101  type halfAnyValue struct{}
   102  
   103  // String implements `halfValueMatcher.String`.
   104  func (halfAnyValue) String() string {
   105  	return "== *"
   106  }
   107  
   108  // Repr implements `halfValueMatcher.Repr`.
   109  func (halfAnyValue) Repr() string {
   110  	return "halfAnyValue"
   111  }
   112  
   113  // HalfRender implements `halfValueMatcher.HalfRender`.
   114  func (halfAnyValue) HalfRender(program *syscallProgram, labelSet *labelSet) {
   115  	program.JumpTo(labelSet.Matched())
   116  }
   117  
   118  // halfEqualTo implements `halfValueMatcher` and matches a specific 32-bit value.
   119  type halfEqualTo uint32
   120  
   121  // String implements `halfValueMatcher.String`.
   122  func (heq halfEqualTo) String() string {
   123  	if heq == 0 {
   124  		return "== 0"
   125  	}
   126  	return fmt.Sprintf("== %#x", uint32(heq))
   127  }
   128  
   129  // Repr implements `halfValueMatcher.Repr`.
   130  func (heq halfEqualTo) Repr() string {
   131  	return fmt.Sprintf("halfEq(%#x)", uint32(heq))
   132  }
   133  
   134  // HalfRender implements `halfValueMatcher.HalfRender`.
   135  func (heq halfEqualTo) HalfRender(program *syscallProgram, labelSet *labelSet) {
   136  	program.If(bpf.Jmp|bpf.Jeq|bpf.K, uint32(heq), labelSet.Matched())
   137  	program.JumpTo(labelSet.Mismatched())
   138  }
   139  
   140  // halfNotSet implements `halfValueMatcher` and matches using the "set"
   141  // bitwise operation.
   142  type halfNotSet uint32
   143  
   144  // String implements `halfValueMatcher.String`.
   145  func (hns halfNotSet) String() string {
   146  	return fmt.Sprintf("& %#x == 0", uint32(hns))
   147  }
   148  
   149  // Repr implements `halfValueMatcher.Repr`.
   150  func (hns halfNotSet) Repr() string {
   151  	return fmt.Sprintf("halfNotSet(%#x)", uint32(hns))
   152  }
   153  
   154  // HalfRender implements `halfValueMatcher.HalfRender`.
   155  func (hns halfNotSet) HalfRender(program *syscallProgram, labelSet *labelSet) {
   156  	program.If(bpf.Jmp|bpf.Jset|bpf.K, uint32(hns), labelSet.Mismatched())
   157  	program.JumpTo(labelSet.Matched())
   158  }
   159  
   160  // halfMaskedEqual implements `halfValueMatcher` and verifies that the value
   161  // is equal after applying a bit mask.
   162  type halfMaskedEqual struct {
   163  	mask  uint32
   164  	value uint32
   165  }
   166  
   167  // String implements `halfValueMatcher.String`.
   168  func (hmeq halfMaskedEqual) String() string {
   169  	if hmeq.value == 0 {
   170  		return fmt.Sprintf("& %#x == 0", hmeq.mask)
   171  	}
   172  	return fmt.Sprintf("& %#x == %#x", hmeq.mask, hmeq.value)
   173  }
   174  
   175  // Repr implements `halfValueMatcher.Repr`.
   176  func (hmeq halfMaskedEqual) Repr() string {
   177  	return fmt.Sprintf("halfMaskedEqual(%#x, %#x)", hmeq.mask, hmeq.value)
   178  }
   179  
   180  // HalfRender implements `halfValueMatcher.HalfRender`.
   181  func (hmeq halfMaskedEqual) HalfRender(program *syscallProgram, labelSet *labelSet) {
   182  	program.Stmt(bpf.Alu|bpf.And|bpf.K, hmeq.mask)
   183  	program.IfNot(bpf.Jmp|bpf.Jeq|bpf.K, hmeq.value, labelSet.Mismatched())
   184  	program.JumpTo(labelSet.Matched())
   185  }
   186  
   187  // splitMatcher implements `ValueMatcher` and verifies each half of the 64-bit
   188  // value independently (with AND semantics).
   189  // It implements `ValueMatcher`, but is never used directly in seccomp filter
   190  // rules. Rather, it acts as an intermediate representation for the rules that
   191  // can be expressed as an AND of two 32-bit values.
   192  type splitMatcher struct {
   193  	// repr is the `Repr()` of the original `ValueMatcher` (pre-split).
   194  	repr string
   195  	// highMatcher is the half-value matcher to verify the high 32 bits.
   196  	highMatcher halfValueMatcher
   197  	// lowMatcher is the half-value matcher to verify the low 32 bits.
   198  	lowMatcher halfValueMatcher
   199  }
   200  
   201  // String implements `ValueMatcher.String`.
   202  func (sm splitMatcher) String() string {
   203  	if sm.repr == "" {
   204  		_, highIsAnyValue := sm.highMatcher.(halfAnyValue)
   205  		_, lowIsAnyValue := sm.lowMatcher.(halfAnyValue)
   206  		if highIsAnyValue && lowIsAnyValue {
   207  			return "== *"
   208  		}
   209  		if highIsAnyValue {
   210  			return fmt.Sprintf("VAL.low %s", sm.lowMatcher.String())
   211  		}
   212  		if lowIsAnyValue {
   213  			return fmt.Sprintf("VAL.high %s", sm.highMatcher.String())
   214  		}
   215  		return fmt.Sprintf("(VAL.high %s && VAL.low %s)", sm.highMatcher.String(), sm.lowMatcher.String())
   216  	}
   217  	return sm.repr
   218  }
   219  
   220  // Repr implements `ValueMatcher.Repr`.
   221  func (sm splitMatcher) Repr() string {
   222  	if sm.repr == "" {
   223  		_, highIsAnyValue := sm.highMatcher.(halfAnyValue)
   224  		_, lowIsAnyValue := sm.lowMatcher.(halfAnyValue)
   225  		if highIsAnyValue && lowIsAnyValue {
   226  			return "split(*)"
   227  		}
   228  		if highIsAnyValue {
   229  			return fmt.Sprintf("low=%s", sm.lowMatcher.Repr())
   230  		}
   231  		if lowIsAnyValue {
   232  			return fmt.Sprintf("high=%s", sm.highMatcher.Repr())
   233  		}
   234  		return fmt.Sprintf("(high=%s && low=%s)", sm.highMatcher.Repr(), sm.lowMatcher.Repr())
   235  	}
   236  	return sm.repr
   237  }
   238  
   239  // Render implements `ValueMatcher.Render`.
   240  func (sm splitMatcher) Render(program *syscallProgram, labelSet *labelSet, value matchedValue) {
   241  	_, highIsAny := sm.highMatcher.(halfAnyValue)
   242  	_, lowIsAny := sm.lowMatcher.(halfAnyValue)
   243  	if highIsAny && lowIsAny {
   244  		program.JumpTo(labelSet.Matched())
   245  		return
   246  	}
   247  	if highIsAny {
   248  		value.LoadLow32Bits()
   249  		sm.lowMatcher.HalfRender(program, labelSet)
   250  		return
   251  	}
   252  	if lowIsAny {
   253  		value.LoadHigh32Bits()
   254  		sm.highMatcher.HalfRender(program, labelSet)
   255  		return
   256  	}
   257  	// We render the "low" bits first on the assumption that most syscall
   258  	// arguments fit within 32-bits, and those rules actually only care
   259  	// about the value of the low 32 bits. This way, we only check the
   260  	// high 32 bits if the low 32 bits have already matched.
   261  	lowLabels := labelSet.Push("low", labelSet.NewLabel(), labelSet.Mismatched())
   262  	lowFrag := program.Record()
   263  	value.LoadLow32Bits()
   264  	sm.lowMatcher.HalfRender(program, lowLabels)
   265  	lowFrag.MustHaveJumpedTo(lowLabels.Matched(), labelSet.Mismatched())
   266  
   267  	program.Label(lowLabels.Matched())
   268  	highFrag := program.Record()
   269  	value.LoadHigh32Bits()
   270  	sm.highMatcher.HalfRender(program, labelSet.Push("high", labelSet.Matched(), labelSet.Mismatched()))
   271  	highFrag.MustHaveJumpedTo(labelSet.Matched(), labelSet.Mismatched())
   272  }
   273  
   274  // high32BitsMatch returns a `splitMatcher` that only matches the high 32 bits
   275  // of a 64-bit value.
   276  func high32BitsMatch(hvm halfValueMatcher) splitMatcher {
   277  	return splitMatcher{
   278  		highMatcher: hvm,
   279  		lowMatcher:  halfAnyValue{},
   280  	}
   281  }
   282  
   283  // low32BitsMatch returns a `splitMatcher` that only matches the low 32 bits
   284  // of a 64-bit value.
   285  func low32BitsMatch(hvm halfValueMatcher) splitMatcher {
   286  	return splitMatcher{
   287  		highMatcher: halfAnyValue{},
   288  		lowMatcher:  hvm,
   289  	}
   290  }
   291  
   292  // splittableValueMatcher should be implemented by `ValueMatcher` that can
   293  // be expressed as a `splitMatcher`.
   294  type splittableValueMatcher interface {
   295  	// split converts this `ValueMatcher` into a `splitMatcher`.
   296  	split() splitMatcher
   297  }
   298  
   299  // renderSplittable is a helper function for the `ValueMatcher.Render`
   300  // implementation of `splittableValueMatcher`s.
   301  func renderSplittable(sm splittableValueMatcher, program *syscallProgram, labelSet *labelSet, value matchedValue) {
   302  	sm.split().Render(program, labelSet, value)
   303  }
   304  
   305  // high32Bits returns the higher 32-bits of the given value.
   306  func high32Bits(val uintptr) uint32 {
   307  	return uint32(val >> 32)
   308  }
   309  
   310  // low32Bits returns the lower 32-bits of the given value.
   311  func low32Bits(val uintptr) uint32 {
   312  	return uint32(val)
   313  }
   314  
   315  // AnyValue is marker to indicate any value will be accepted.
   316  // It implements ValueMatcher.
   317  type AnyValue struct{}
   318  
   319  // String implements `ValueMatcher.String`.
   320  func (AnyValue) String() string {
   321  	return "== *"
   322  }
   323  
   324  // Repr implements `ValueMatcher.Repr`.
   325  func (av AnyValue) Repr() string {
   326  	return av.String()
   327  }
   328  
   329  // Render implements `ValueMatcher.Render`.
   330  func (av AnyValue) Render(program *syscallProgram, labelSet *labelSet, value matchedValue) {
   331  	program.JumpTo(labelSet.Matched())
   332  }
   333  
   334  // EqualTo specifies a value that needs to be strictly matched.
   335  // It implements ValueMatcher.
   336  type EqualTo uintptr
   337  
   338  // String implements `ValueMatcher.String`.
   339  func (eq EqualTo) String() string {
   340  	if eq == 0 {
   341  		return "== 0"
   342  	}
   343  	return fmt.Sprintf("== %#x", uintptr(eq))
   344  }
   345  
   346  // Repr implements `ValueMatcher.Repr`.
   347  func (eq EqualTo) Repr() string {
   348  	return eq.String()
   349  }
   350  
   351  // Render implements `ValueMatcher.Render`.
   352  func (eq EqualTo) Render(program *syscallProgram, labelSet *labelSet, value matchedValue) {
   353  	renderSplittable(eq, program, labelSet, value)
   354  }
   355  
   356  // split implements `splittableValueMatcher.split`.
   357  func (eq EqualTo) split() splitMatcher {
   358  	return splitMatcher{
   359  		repr:        eq.Repr(),
   360  		highMatcher: halfEqualTo(high32Bits(uintptr(eq))),
   361  		lowMatcher:  halfEqualTo(low32Bits(uintptr(eq))),
   362  	}
   363  }
   364  
   365  // NotEqual specifies a value that is strictly not equal.
   366  type NotEqual uintptr
   367  
   368  // String implements `ValueMatcher.String`.
   369  func (ne NotEqual) String() string {
   370  	return fmt.Sprintf("!= %#x", uintptr(ne))
   371  }
   372  
   373  // Repr implements `ValueMatcher.Repr`.
   374  func (ne NotEqual) Repr() string {
   375  	return ne.String()
   376  }
   377  
   378  // Render implements `ValueMatcher.Render`.
   379  func (ne NotEqual) Render(program *syscallProgram, labelSet *labelSet, value matchedValue) {
   380  	// Note that `NotEqual` is *not* a splittable rule by itself, because it is not the
   381  	// conjunction of two `halfValueMatchers` (it is the *disjunction* of them).
   382  	// However, it is also the exact inverse of `EqualTo`.
   383  	// Therefore, we can use `EqualTo` here, and simply invert the
   384  	// matched/mismatched labels.
   385  	EqualTo(ne).Render(program, labelSet.Push("inverted", labelSet.Mismatched(), labelSet.Matched()), value)
   386  }
   387  
   388  // GreaterThan specifies a value that needs to be strictly smaller.
   389  type GreaterThan uintptr
   390  
   391  // String implements `ValueMatcher.String`.
   392  func (gt GreaterThan) String() string {
   393  	return fmt.Sprintf("> %#x", uintptr(gt))
   394  }
   395  
   396  // Repr implements `ValueMatcher.Repr`.
   397  func (gt GreaterThan) Repr() string {
   398  	return gt.String()
   399  }
   400  
   401  // Render implements `ValueMatcher.Render`.
   402  func (gt GreaterThan) Render(program *syscallProgram, labelSet *labelSet, value matchedValue) {
   403  	high := high32Bits(uintptr(gt))
   404  	// Assert the higher 32bits are greater than or equal.
   405  	// arg_high >= high ? continue : violation (arg_high < high)
   406  	value.LoadHigh32Bits()
   407  	program.IfNot(bpf.Jmp|bpf.Jge|bpf.K, high, labelSet.Mismatched())
   408  	// arg_high == high ? continue : success (arg_high > high)
   409  	program.IfNot(bpf.Jmp|bpf.Jeq|bpf.K, high, labelSet.Matched())
   410  	// Assert that the lower 32bits are greater.
   411  	// arg_low > low ? continue/success : violation (arg_high == high and arg_low <= low)
   412  	value.LoadLow32Bits()
   413  	program.IfNot(bpf.Jmp|bpf.Jgt|bpf.K, low32Bits(uintptr(gt)), labelSet.Mismatched())
   414  	program.JumpTo(labelSet.Matched())
   415  }
   416  
   417  // GreaterThanOrEqual specifies a value that needs to be smaller or equal.
   418  type GreaterThanOrEqual uintptr
   419  
   420  // String implements `ValueMatcher.String`.
   421  func (ge GreaterThanOrEqual) String() string {
   422  	return fmt.Sprintf(">= %#x", uintptr(ge))
   423  }
   424  
   425  // Repr implements `ValueMatcher.Repr`.
   426  func (ge GreaterThanOrEqual) Repr() string {
   427  	return ge.String()
   428  }
   429  
   430  // Render implements `ValueMatcher.Render`.
   431  func (ge GreaterThanOrEqual) Render(program *syscallProgram, labelSet *labelSet, value matchedValue) {
   432  	high := high32Bits(uintptr(ge))
   433  	// Assert the higher 32bits are greater than or equal.
   434  	// arg_high >= high ? continue : violation (arg_high < high)
   435  	value.LoadHigh32Bits()
   436  	program.IfNot(bpf.Jmp|bpf.Jge|bpf.K, high, labelSet.Mismatched())
   437  	// arg_high == high ? continue : success (arg_high > high)
   438  	program.IfNot(bpf.Jmp|bpf.Jeq|bpf.K, high, labelSet.Matched())
   439  	// Assert that the lower 32bits are greater or equal (assuming the
   440  	// higher bits are equal).
   441  	// arg_low >= low ? continue/success : violation (arg_high == high and arg_low < low)
   442  	value.LoadLow32Bits()
   443  	program.IfNot(bpf.Jmp|bpf.Jge|bpf.K, low32Bits(uintptr(ge)), labelSet.Mismatched())
   444  	program.JumpTo(labelSet.Matched())
   445  }
   446  
   447  // LessThan specifies a value that needs to be strictly greater.
   448  type LessThan uintptr
   449  
   450  // String implements `ValueMatcher.String`.
   451  func (lt LessThan) String() string {
   452  	return fmt.Sprintf("< %#x", uintptr(lt))
   453  }
   454  
   455  // Repr implements `ValueMatcher.Repr`.
   456  func (lt LessThan) Repr() string {
   457  	return lt.String()
   458  }
   459  
   460  // Render implements `ValueMatcher.Render`.
   461  func (lt LessThan) Render(program *syscallProgram, labelSet *labelSet, value matchedValue) {
   462  	high := high32Bits(uintptr(lt))
   463  	// Assert the higher 32bits are less than or equal.
   464  	// arg_high > high ? violation : continue
   465  	value.LoadHigh32Bits()
   466  	program.If(bpf.Jmp|bpf.Jgt|bpf.K, high, labelSet.Mismatched())
   467  	// arg_high == high ? continue : success (arg_high < high)
   468  	program.IfNot(bpf.Jmp|bpf.Jeq|bpf.K, high, labelSet.Matched())
   469  	// Assert that the lower 32bits are less (assuming the
   470  	// higher bits are equal).
   471  	// arg_low >= low ? violation : continue
   472  	value.LoadLow32Bits()
   473  	program.If(bpf.Jmp|bpf.Jge|bpf.K, low32Bits(uintptr(lt)), labelSet.Mismatched())
   474  	program.JumpTo(labelSet.Matched())
   475  }
   476  
   477  // LessThanOrEqual specifies a value that needs to be greater or equal.
   478  type LessThanOrEqual uintptr
   479  
   480  // String implements `ValueMatcher.String`.
   481  func (le LessThanOrEqual) String() string {
   482  	return fmt.Sprintf("<= %#x", uintptr(le))
   483  }
   484  
   485  // Repr implements `ValueMatcher.Repr`.
   486  func (le LessThanOrEqual) Repr() string {
   487  	return le.String()
   488  }
   489  
   490  // Render implements `ValueMatcher.Render`.
   491  func (le LessThanOrEqual) Render(program *syscallProgram, labelSet *labelSet, value matchedValue) {
   492  	high := high32Bits(uintptr(le))
   493  	// Assert the higher 32bits are less than or equal.
   494  	// assert arg_high > high ? violation : continue
   495  	value.LoadHigh32Bits()
   496  	program.If(bpf.Jmp|bpf.Jgt|bpf.K, high, labelSet.Mismatched())
   497  	// arg_high == high ? continue : success
   498  	program.IfNot(bpf.Jmp|bpf.Jeq|bpf.K, high, labelSet.Matched())
   499  	// Assert the lower bits are less than or equal (assuming
   500  	// the higher bits are equal).
   501  	// arg_low > low ? violation : success
   502  	value.LoadLow32Bits()
   503  	program.If(bpf.Jmp|bpf.Jgt|bpf.K, low32Bits(uintptr(le)), labelSet.Mismatched())
   504  	program.JumpTo(labelSet.Matched())
   505  }
   506  
   507  // NonNegativeFD ensures that an FD argument is a non-negative int32.
   508  type NonNegativeFD struct{}
   509  
   510  // String implements `ValueMatcher.String`.
   511  func (NonNegativeFD) String() string {
   512  	return "is non-negative FD"
   513  }
   514  
   515  // Repr implements `ValueMatcher.Repr`.
   516  func (NonNegativeFD) Repr() string {
   517  	return "NonNegativeFD"
   518  }
   519  
   520  // Render implements `ValueMatcher.Render`.
   521  func (nnfd NonNegativeFD) Render(program *syscallProgram, labelSet *labelSet, value matchedValue) {
   522  	renderSplittable(nnfd, program, labelSet, value)
   523  }
   524  
   525  // split implements `splittableValueMatcher.split`.
   526  func (nnfd NonNegativeFD) split() splitMatcher {
   527  	return splitMatcher{
   528  		repr: nnfd.Repr(),
   529  		// FDs are 32 bits, so the high 32 bits must all be zero.
   530  		// Negative int32 has the MSB (31st bit) set.
   531  		// So the low 32bits of the FD value must not have the 31st bit set.
   532  		highMatcher: halfEqualTo(0),
   533  		lowMatcher:  halfNotSet(1 << 31),
   534  	}
   535  }
   536  
   537  // MaskedEqual specifies a value that matches the input after the input is
   538  // masked (bitwise &) against the given mask. It implements `ValueMatcher`.
   539  type maskedEqual struct {
   540  	mask  uintptr
   541  	value uintptr
   542  }
   543  
   544  // String implements `ValueMatcher.String`.
   545  func (me maskedEqual) String() string {
   546  	return fmt.Sprintf("& %#x == %#x", me.mask, me.value)
   547  }
   548  
   549  // Repr implements `ValueMatcher.Repr`.
   550  func (me maskedEqual) Repr() string {
   551  	return me.String()
   552  }
   553  
   554  // Render implements `ValueMatcher.Render`.
   555  func (me maskedEqual) Render(program *syscallProgram, labelSet *labelSet, value matchedValue) {
   556  	renderSplittable(me, program, labelSet, value)
   557  }
   558  
   559  // split implements `splittableValueMatcher.Split`.
   560  func (me maskedEqual) split() splitMatcher {
   561  	return splitMatcher{
   562  		repr:        me.Repr(),
   563  		highMatcher: halfMaskedEqual{high32Bits(me.mask), high32Bits(me.value)},
   564  		lowMatcher:  halfMaskedEqual{low32Bits(me.mask), low32Bits(me.value)},
   565  	}
   566  }
   567  
   568  // MaskedEqual specifies a value that matches the input after the input is
   569  // masked (bitwise &) against the given mask. Can be used to verify that input
   570  // only includes certain approved flags.
   571  func MaskedEqual(mask, value uintptr) ValueMatcher {
   572  	return maskedEqual{
   573  		mask:  mask,
   574  		value: value,
   575  	}
   576  }
   577  
   578  // BitsAllowlist specifies that a value can only have non-zero bits within
   579  // the mask specified in `allowlist`. It implements `ValueMatcher`.
   580  func BitsAllowlist(allowlist uintptr) ValueMatcher {
   581  	return MaskedEqual(^allowlist, 0)
   582  }
   583  
   584  // SyscallRule expresses a set of rules to verify the arguments of a specific
   585  // syscall.
   586  type SyscallRule interface {
   587  	// Render renders the syscall rule in the given `program`.
   588  	// The emitted instructions **must** end up jumping to either
   589  	// `labelSet.Matched()` or `labelSet.Mismatched()`; they may
   590  	// not "fall through" to whatever instructions will be added
   591  	// next into the program.
   592  	Render(program *syscallProgram, labelSet *labelSet)
   593  
   594  	// Copy returns a copy of this `SyscallRule`.
   595  	Copy() SyscallRule
   596  
   597  	// Recurse should call the given function on all `SyscallRule`s that are
   598  	// part of this `SyscallRule`, and should replace them with the returned
   599  	// `SyscallRule`. For example, conjunctive rules should call the given
   600  	// function on each of the `SyscallRule`s that they are ANDing, replacing
   601  	// them with the rule returned by the function.
   602  	Recurse(func(SyscallRule) SyscallRule)
   603  
   604  	// String returns a human-readable string representing what the rule does.
   605  	String() string
   606  }
   607  
   608  // MatchAll implements `SyscallRule` and matches everything.
   609  type MatchAll struct{}
   610  
   611  // Render implements `SyscallRule.Render`.
   612  func (MatchAll) Render(program *syscallProgram, labelSet *labelSet) {
   613  	program.JumpTo(labelSet.Matched())
   614  }
   615  
   616  // Copy implements `SyscallRule.Copy`.
   617  func (MatchAll) Copy() SyscallRule {
   618  	return MatchAll{}
   619  }
   620  
   621  // Recurse implements `SyscallRule.Recurse`.
   622  func (MatchAll) Recurse(func(SyscallRule) SyscallRule) {}
   623  
   624  // String implements `SyscallRule.String`.
   625  func (MatchAll) String() string { return "true" }
   626  
   627  // Or expresses an "OR" (a disjunction) over a set of `SyscallRule`s.
   628  // An `Or` may not be empty.
   629  type Or []SyscallRule
   630  
   631  // Render implements `SyscallRule.Render`.
   632  func (or Or) Render(program *syscallProgram, labelSet *labelSet) {
   633  	if len(or) == 0 {
   634  		panic("Or expression cannot be empty")
   635  	}
   636  	// If `len(or) == 1`, this will be optimized away to be the same as
   637  	// rendering the single rule in the disjunction.
   638  	for i, rule := range or {
   639  		frag := program.Record()
   640  		nextRuleLabel := labelSet.NewLabel()
   641  		rule.Render(program, labelSet.Push(fmt.Sprintf("or[%d]", i), labelSet.Matched(), nextRuleLabel))
   642  		frag.MustHaveJumpedTo(labelSet.Matched(), nextRuleLabel)
   643  		program.Label(nextRuleLabel)
   644  	}
   645  	program.JumpTo(labelSet.Mismatched())
   646  }
   647  
   648  // Copy implements `SyscallRule.Copy`.
   649  func (or Or) Copy() SyscallRule {
   650  	orCopy := make([]SyscallRule, len(or))
   651  	for i, rule := range or {
   652  		orCopy[i] = rule.Copy()
   653  	}
   654  	return Or(orCopy)
   655  }
   656  
   657  // Recurse implements `SyscallRule.Recurse`.
   658  func (or Or) Recurse(fn func(SyscallRule) SyscallRule) {
   659  	for i, rule := range or {
   660  		or[i] = fn(rule)
   661  	}
   662  }
   663  
   664  // String implements `SyscallRule.String`.
   665  func (or Or) String() string {
   666  	switch len(or) {
   667  	case 0:
   668  		return "invalid"
   669  	case 1:
   670  		return or[0].String()
   671  	default:
   672  		var sb strings.Builder
   673  		sb.WriteRune('(')
   674  		for i, rule := range or {
   675  			if i != 0 {
   676  				sb.WriteString(" || ")
   677  			}
   678  			sb.WriteString(rule.String())
   679  		}
   680  		sb.WriteRune(')')
   681  		return sb.String()
   682  	}
   683  }
   684  
   685  // And expresses an "AND" (a conjunction) over a set of `SyscallRule`s.
   686  // An `And` may not be empty.
   687  type And []SyscallRule
   688  
   689  // Render implements `SyscallRule.Render`.
   690  func (and And) Render(program *syscallProgram, labelSet *labelSet) {
   691  	if len(and) == 0 {
   692  		panic("And expression cannot be empty")
   693  	}
   694  	// If `len(and) == 1`, this will be optimized away to be the same as
   695  	// rendering the single rule in the conjunction.
   696  	for i, rule := range and {
   697  		frag := program.Record()
   698  		nextRuleLabel := labelSet.NewLabel()
   699  		rule.Render(program, labelSet.Push(fmt.Sprintf("and[%d]", i), nextRuleLabel, labelSet.Mismatched()))
   700  		frag.MustHaveJumpedTo(nextRuleLabel, labelSet.Mismatched())
   701  		program.Label(nextRuleLabel)
   702  	}
   703  	program.JumpTo(labelSet.Matched())
   704  }
   705  
   706  // Copy implements `SyscallRule.Copy`.
   707  func (and And) Copy() SyscallRule {
   708  	andCopy := make([]SyscallRule, len(and))
   709  	for i, rule := range and {
   710  		andCopy[i] = rule.Copy()
   711  	}
   712  	return And(andCopy)
   713  }
   714  
   715  // Recurse implements `SyscallRule.Recurse`.
   716  func (and And) Recurse(fn func(SyscallRule) SyscallRule) {
   717  	for i, rule := range and {
   718  		and[i] = fn(rule)
   719  	}
   720  }
   721  
   722  // String implements `SyscallRule.String`.
   723  func (and And) String() string {
   724  	switch len(and) {
   725  	case 0:
   726  		return "invalid"
   727  	case 1:
   728  		return and[0].String()
   729  	default:
   730  		var sb strings.Builder
   731  		sb.WriteRune('(')
   732  		for i, rule := range and {
   733  			if i != 0 {
   734  				sb.WriteString(" && ")
   735  			}
   736  			sb.WriteString(rule.String())
   737  		}
   738  		sb.WriteRune(')')
   739  		return sb.String()
   740  	}
   741  }
   742  
   743  // PerArg implements SyscallRule and verifies the syscall arguments and RIP.
   744  //
   745  // For example:
   746  //
   747  //	rule := PerArg{
   748  //		EqualTo(linux.ARCH_GET_FS | linux.ARCH_SET_FS), // arg0
   749  //	}
   750  type PerArg [7]ValueMatcher // 6 arguments + RIP
   751  
   752  // RuleIP indicates what rules in the Rule array have to be applied to
   753  // instruction pointer.
   754  const RuleIP = 6
   755  
   756  // clone returns a copy of this `PerArg`.
   757  // It is more efficient than `Copy` because it returns a `PerArg`
   758  // directly, rather than a `SyscallRule` interface.
   759  func (pa PerArg) clone() PerArg {
   760  	return PerArg{
   761  		pa[0],
   762  		pa[1],
   763  		pa[2],
   764  		pa[3],
   765  		pa[4],
   766  		pa[5],
   767  		pa[6],
   768  	}
   769  }
   770  
   771  // Copy implements `SyscallRule.Copy`.
   772  func (pa PerArg) Copy() SyscallRule {
   773  	return pa.clone()
   774  }
   775  
   776  // Render implements `SyscallRule.Render`.
   777  func (pa PerArg) Render(program *syscallProgram, labelSet *labelSet) {
   778  	for i, arg := range pa {
   779  		if arg == nil {
   780  			continue
   781  		}
   782  		frag := program.Record()
   783  		nextArgLabel := labelSet.NewLabel()
   784  		labelSuffix := fmt.Sprintf("arg[%d]", i)
   785  		// Determine the data offset for low and high bits of input.
   786  		dataOffsetLow := seccompDataOffsetArgLow(i)
   787  		dataOffsetHigh := seccompDataOffsetArgHigh(i)
   788  		if i == RuleIP {
   789  			dataOffsetLow = seccompDataOffsetIPLow
   790  			dataOffsetHigh = seccompDataOffsetIPHigh
   791  			labelSuffix = "rip"
   792  		}
   793  		ls := labelSet.Push(labelSuffix, nextArgLabel, labelSet.Mismatched())
   794  		arg.Render(program, ls, matchedValue{
   795  			program:        program,
   796  			dataOffsetHigh: dataOffsetHigh,
   797  			dataOffsetLow:  dataOffsetLow,
   798  		})
   799  		frag.MustHaveJumpedTo(ls.Matched(), ls.Mismatched())
   800  		program.Label(nextArgLabel)
   801  	}
   802  	// Matched all argument-wise rules, jump to the final rule matched label.
   803  	program.JumpTo(labelSet.Matched())
   804  }
   805  
   806  // Recurse implements `SyscallRule.Recurse`.
   807  func (PerArg) Recurse(fn func(SyscallRule) SyscallRule) {}
   808  
   809  // String implements `SyscallRule.String`.
   810  func (pa PerArg) String() string {
   811  	var sb strings.Builder
   812  	writtenArgs := 0
   813  	for i, arg := range pa {
   814  		if arg == nil {
   815  			continue
   816  		}
   817  		if _, isAny := arg.(AnyValue); isAny {
   818  			continue
   819  		}
   820  		if writtenArgs != 0 {
   821  			sb.WriteString(" && ")
   822  		}
   823  		str := arg.String()
   824  		var varName string
   825  		if i == RuleIP {
   826  			varName = "rip"
   827  		} else {
   828  			varName = fmt.Sprintf("arg[%d]", i)
   829  		}
   830  		if strings.Contains(str, "VAL") {
   831  			sb.WriteString(strings.ReplaceAll(str, "VAL", varName))
   832  		} else {
   833  			sb.WriteString(varName)
   834  			sb.WriteRune(' ')
   835  			sb.WriteString(str)
   836  		}
   837  		writtenArgs++
   838  	}
   839  	if writtenArgs == 0 {
   840  		return "true"
   841  	}
   842  	if writtenArgs == 1 {
   843  		return sb.String()
   844  	}
   845  	return "(" + sb.String() + ")"
   846  }
   847  
// SyscallRules maps syscall numbers to their corresponding rules.
//
// SyscallRules is a small value wrapping a map, so copies of a SyscallRules
// value share the same underlying rules; use Copy to obtain an independent
// set. The zero value has a nil map: read-only methods (Size, Get, Has,
// String) work on it, but mutating methods such as Add and Set write to the
// map and would panic. Construct values with NewSyscallRules or
// MakeSyscallRules.
//
// For example:
//
//	rules := MakeSyscallRules(map[uintptr]SyscallRule{
//		syscall.SYS_FUTEX: Or{
//			PerArg{
//				AnyValue{},
//				EqualTo(linux.FUTEX_WAIT | linux.FUTEX_PRIVATE_FLAG),
//			},
//			PerArg{
//				AnyValue{},
//				EqualTo(linux.FUTEX_WAKE | linux.FUTEX_PRIVATE_FLAG),
//			},
//		},
//		syscall.SYS_GETPID: MatchAll{},
//	})
type SyscallRules struct {
	// rules holds at most one rule per syscall number; Add combines further
	// rules for an already-present number into an Or.
	rules map[uintptr]SyscallRule
}
   866  	rules map[uintptr]SyscallRule
   867  }
   868  
   869  // NewSyscallRules returns a new SyscallRules.
   870  func NewSyscallRules() SyscallRules {
   871  	return MakeSyscallRules(nil)
   872  }
   873  
   874  // MakeSyscallRules returns a new SyscallRules with the given set of rules.
   875  func MakeSyscallRules(rules map[uintptr]SyscallRule) SyscallRules {
   876  	if rules == nil {
   877  		rules = make(map[uintptr]SyscallRule)
   878  	}
   879  	return SyscallRules{rules: rules}
   880  }
   881  
   882  // String returns a string representation of the syscall rules, one syscall
   883  // per line.
   884  func (sr SyscallRules) String() string {
   885  	if len(sr.rules) == 0 {
   886  		return "(no rules)"
   887  	}
   888  	sysnums := make([]uintptr, 0, len(sr.rules))
   889  	for sysno := range sr.rules {
   890  		sysnums = append(sysnums, sysno)
   891  	}
   892  	sort.Slice(sysnums, func(i, j int) bool {
   893  		return sysnums[i] < sysnums[j]
   894  	})
   895  	var sb strings.Builder
   896  	for _, sysno := range sysnums {
   897  		sb.WriteString(fmt.Sprintf("syscall %d: %v\n", sysno, sr.rules[sysno]))
   898  	}
   899  	return strings.TrimSpace(sb.String())
   900  }
   901  
   902  // Size returns the number of syscall numbers for which a rule is defined.
   903  func (sr SyscallRules) Size() int {
   904  	return len(sr.rules)
   905  }
   906  
   907  // Get returns the rule defined for the given syscall number.
   908  func (sr SyscallRules) Get(sysno uintptr) SyscallRule {
   909  	return sr.rules[sysno]
   910  }
   911  
   912  // Has returns whether there is a rule defined for the given syscall number.
   913  func (sr SyscallRules) Has(sysno uintptr) bool {
   914  	_, has := sr.rules[sysno]
   915  	return has
   916  }
   917  
   918  // Add adds the given rule. It will create a new entry for a new syscall, otherwise
   919  // it will append to the existing rules.
   920  // Returns itself for chainability.
   921  func (sr SyscallRules) Add(sysno uintptr, r SyscallRule) SyscallRules {
   922  	if cur, ok := sr.rules[sysno]; ok {
   923  		sr.rules[sysno] = Or{cur, r}
   924  	} else {
   925  		sr.rules[sysno] = r
   926  	}
   927  	return sr
   928  }
   929  
   930  // Set sets the rule for the given syscall number.
   931  // Panics if there is already a rule for this syscall number.
   932  // This is useful for deterministic rules where the set of syscall rules is
   933  // added in multiple chunks but is known to never overlap by syscall number.
   934  // Returns itself for chainability.
   935  func (sr SyscallRules) Set(sysno uintptr, r SyscallRule) SyscallRules {
   936  	if cur, ok := sr.rules[sysno]; ok {
   937  		panic(fmt.Sprintf("tried to set syscall rule for sysno=%d to %v but it is already set to %v", sysno, r, cur))
   938  	}
   939  	sr.rules[sysno] = r
   940  	return sr
   941  }
   942  
   943  // Remove clears the syscall rule for the given syscall number.
   944  // It will panic if there is no syscall rule for this syscall number.
   945  func (sr SyscallRules) Remove(sysno uintptr) {
   946  	if !sr.Has(sysno) {
   947  		panic(fmt.Sprintf("tried to remove syscall rule for sysno=%d but it is not set", sysno))
   948  	}
   949  	delete(sr.rules, sysno)
   950  }
   951  
   952  // Merge merges the given SyscallRules.
   953  // Returns itself for chainability.
   954  func (sr SyscallRules) Merge(other SyscallRules) SyscallRules {
   955  	for sysno, r := range other.rules {
   956  		sr.Add(sysno, r)
   957  	}
   958  	return sr
   959  }
   960  
   961  // Copy returns a deep copy of these SyscallRules.
   962  func (sr SyscallRules) Copy() SyscallRules {
   963  	rulesCopy := make(map[uintptr]SyscallRule, len(sr.rules))
   964  	for sysno, r := range sr.rules {
   965  		rulesCopy[sysno] = r.Copy()
   966  	}
   967  	return MakeSyscallRules(rulesCopy)
   968  }
   969  
   970  // ForSingleArgument runs the given function on the `ValueMatcher` rules
   971  // for a single specific syscall argument of the given syscall number.
   972  // If the function returns an error, it will be propagated along with some
   973  // details as to which rule caused the error to be returned.
   974  // ForSingleArgument also returns an error if there are no rules defined for
   975  // the given syscall number, or if at least one rule for this syscall number
   976  // is not either a `PerArg` rule or a rule with children rules (as this would
   977  // indicate that the `PerArg` rules alone may not be a good representation of
   978  // the entire set of rules for this system call).
   979  func (sr SyscallRules) ForSingleArgument(sysno uintptr, argNum int, fn func(ValueMatcher) error) error {
   980  	if argNum < 0 || argNum >= len(PerArg{}) {
   981  		return fmt.Errorf("invalid argument number %d", argNum)
   982  	}
   983  	if !sr.Has(sysno) {
   984  		return fmt.Errorf("syscall %d has no rules defined", sysno)
   985  	}
   986  	var err error
   987  	var process func(SyscallRule) SyscallRule
   988  	var callCount int
   989  	process = func(r SyscallRule) SyscallRule {
   990  		callCount++
   991  		pa, isPerArg := r.(PerArg)
   992  		if isPerArg {
   993  			if gotErr := fn(pa[argNum]); gotErr != nil && err == nil {
   994  				err = fmt.Errorf("PerArg rule %v: arg[%d] = %v (type %T): %v", pa, argNum, pa[argNum], pa[argNum], gotErr)
   995  			}
   996  		} else {
   997  			beforeRecurse := callCount
   998  			r.Recurse(process)
   999  			if callCount == beforeRecurse {
  1000  				err = fmt.Errorf("rule %v (type: %T) is not a PerArg or a recursive rule", r, r)
  1001  			}
  1002  		}
  1003  		return r
  1004  	}
  1005  	process(sr.rules[sysno])
  1006  	return err
  1007  }
  1008  
  1009  // DenyNewExecMappings is a set of rules that denies creating new executable
  1010  // mappings and converting existing ones.
  1011  var DenyNewExecMappings = MakeSyscallRules(map[uintptr]SyscallRule{
  1012  	unix.SYS_MMAP: PerArg{
  1013  		AnyValue{},
  1014  		AnyValue{},
  1015  		MaskedEqual(unix.PROT_EXEC, unix.PROT_EXEC),
  1016  	},
  1017  	unix.SYS_MPROTECT: PerArg{
  1018  		AnyValue{},
  1019  		AnyValue{},
  1020  		MaskedEqual(unix.PROT_EXEC, unix.PROT_EXEC),
  1021  	},
  1022  })