gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/seccomp/seccomp.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package seccomp provides generation of basic seccomp filters. Currently,
    16  // only little endian systems are supported.
    17  package seccomp
    18  
    19  import (
    20  	"fmt"
    21  	"sort"
    22  	"strings"
    23  	"time"
    24  
    25  	"gvisor.dev/gvisor/pkg/abi/linux"
    26  	"gvisor.dev/gvisor/pkg/bpf"
    27  	"gvisor.dev/gvisor/pkg/log"
    28  )
    29  
const (
	// skipOneInst is the offset to take for skipping one instruction.
	// It is used as a conditional-jump offset so that the jump steps over
	// the single instruction that immediately follows it.
	skipOneInst = 1

	// defaultLabel is the label for the default action.
	defaultLabel = label("default_action")

	// vsyscallPageIPMask is the bit we expect to see in the instruction
	// pointer of a vsyscall call. It is tested against the high 32 bits of
	// the instruction pointer (see singleSyscallRuleSet.Render).
	vsyscallPageIPMask = 1 << 31
)
    41  
    42  // Install generates BPF code based on the set of syscalls provided. It only
    43  // allows syscalls that conform to the specification. Syscalls that violate the
    44  // specification will trigger RET_KILL_PROCESS. If RET_KILL_PROCESS is not
    45  // supported, violations will trigger RET_TRAP instead. RET_KILL_THREAD is not
    46  // used because it only kills the offending thread and often keeps the sentry
    47  // hanging.
    48  //
    49  // denyRules describes forbidden syscalls. rules describes allowed syscalls.
    50  // denyRules is executed before rules.
    51  //
    52  // Be aware that RET_TRAP sends SIGSYS to the process and it may be ignored,
    53  // making it possible for the process to continue running after a violation.
    54  // However, it will leave a SECCOMP audit event trail behind. In any case, the
    55  // syscall is still blocked from executing.
    56  func Install(rules SyscallRules, denyRules SyscallRules, options ProgramOptions) error {
    57  	// ***   DEBUG TIP   ***
    58  	// If you suspect the Sentry is getting killed due to a seccomp violation,
    59  	// look for the `debugFilter` boolean in `//runsc/boot/filter/filter.go`.
    60  
    61  	log.Infof("Installing seccomp filters for %d syscalls (action=%v)", rules.Size(), options.DefaultAction)
    62  
    63  	instrs, _, err := BuildProgram([]RuleSet{
    64  		{
    65  			Rules:  denyRules,
    66  			Action: options.DefaultAction,
    67  		},
    68  		{
    69  			Rules:  rules,
    70  			Action: linux.SECCOMP_RET_ALLOW,
    71  		},
    72  	}, options)
    73  	if log.IsLogging(log.Debug) {
    74  		programStr, errDecode := bpf.DecodeInstructions(instrs)
    75  		if errDecode != nil {
    76  			programStr = fmt.Sprintf("Error: %v\n%s", errDecode, programStr)
    77  		}
    78  		log.Debugf("Seccomp program dump:\n%s", programStr)
    79  	}
    80  	if err != nil {
    81  		return err
    82  	}
    83  
    84  	// Perform the actual installation.
    85  	if err := SetFilter(instrs); err != nil {
    86  		return fmt.Errorf("failed to set filter: %v", err)
    87  	}
    88  
    89  	log.Infof("Seccomp filters installed.")
    90  	return nil
    91  }
    92  
    93  // DefaultAction returns a sane default for a failure to match
    94  // a seccomp-bpf filter. Either kill the process, or trap.
    95  func DefaultAction() (linux.BPFAction, error) {
    96  	available, err := isKillProcessAvailable()
    97  	if err != nil {
    98  		return 0, err
    99  	}
   100  	if available {
   101  		return linux.SECCOMP_RET_KILL_PROCESS, nil
   102  	}
   103  	return linux.SECCOMP_RET_TRAP, nil
   104  }
   105  
// RuleSet is a set of rules and associated action.
type RuleSet struct {
	// Rules holds the syscall rules of this set; Action is the action
	// associated with them.
	Rules  SyscallRules
	Action linux.BPFAction

	// Vsyscall indicates that a check is made for a function being called
	// from kernel mappings. This is where the vsyscall page is located
	// (and typically emulated), so this RuleSet will not match any
	// functions not dispatched from the vsyscall page.
	Vsyscall bool
}
   117  
   118  // SyscallName gives names to system calls. It is used purely for debugging purposes.
   119  //
   120  // An alternate namer can be provided to the package at initialization time.
   121  var SyscallName = func(sysno uintptr) string {
   122  	return fmt.Sprintf("syscall_%d", sysno)
   123  }
   124  
// syscallProgram builds a BPF program for applying syscall rules.
// It is a stateful struct that is updated as the program is built.
// Its methods are thin helpers that append statements, jumps, returns and
// labels to the wrapped builder.
type syscallProgram struct {
	// program is the underlying BPF program being built.
	program *bpf.ProgramBuilder
}
   131  
// Stmt adds a statement to the program.
// code and k are forwarded unchanged to the underlying builder's AddStmt.
func (s *syscallProgram) Stmt(code uint16, k uint32) {
	s.program.AddStmt(code, k)
}
   136  
// label is a custom label type which is returned by `labelSet`.
// It is converted to a plain string whenever it is handed to the underlying
// BPF program builder.
type label string
   139  
// JumpTo adds a jump instruction to the program, jumping to the given label.
// The jump is unconditional; the label does not need to have been added to
// the program yet at the time of the call (callers in this file routinely
// jump to labels that are only placed later).
func (s *syscallProgram) JumpTo(label label) {
	s.program.AddDirectJumpLabel(string(label))
}
   144  
   145  // If checks a condition and jumps to a label if the condition is true.
   146  // If the condition is false, the program continues executing (no jumping).
   147  func (s *syscallProgram) If(code uint16, k uint32, jt label) {
   148  	s.program.AddJump(code, k, 0, skipOneInst)
   149  	s.JumpTo(jt)
   150  }
   151  
   152  // IfNot checks a condition and jumps to a label if the condition is false.
   153  // If the condition is true, the program continues executing (no jumping).
   154  func (s *syscallProgram) IfNot(code uint16, k uint32, jf label) {
   155  	s.program.AddJump(code, k, skipOneInst, 0)
   156  	s.JumpTo(jf)
   157  }
   158  
   159  // Ret adds a return instruction to the program.
   160  func (s *syscallProgram) Ret(action linux.BPFAction) {
   161  	s.Stmt(bpf.Ret|bpf.K, uint32(action))
   162  }
   163  
   164  // Label adds a label to the program.
   165  // It panics if this label has already been added to the program.
   166  func (s *syscallProgram) Label(label label) {
   167  	if err := s.program.AddLabel(string(label)); err != nil {
   168  		panic(fmt.Sprintf("cannot add label %q to program: %v", label, err))
   169  	}
   170  }
   171  
   172  // Record starts recording the instructions added to the program from now on.
   173  // It returns a syscallFragment which can be used to perform assertions on the
   174  // possible set of outcomes of the set of instruction that has been added
   175  // since `Record` was called.
   176  func (s *syscallProgram) Record() syscallProgramFragment {
   177  	return syscallProgramFragment{s.program.Record()}
   178  }
   179  
// syscallProgramFragment represents a fragment of the syscall program.
// It is obtained from `syscallProgram.Record` and is used to assert the
// possible outcomes of the instructions recorded since then.
type syscallProgramFragment struct {
	// getFragment returns the recorded fragment. It is called lazily, at
	// assertion time, rather than when recording starts.
	getFragment func() bpf.ProgramFragment
}
   184  
// MustHaveJumpedTo asserts that the fragment must jump to one of the
// given labels.
// The fragment may not jump to any other label, nor return, nor fall through.
// It panics if the assertion does not hold. It is shorthand for
// MustHaveJumpedToOrReturned with no permitted return values.
func (f syscallProgramFragment) MustHaveJumpedTo(labels ...label) {
	f.MustHaveJumpedToOrReturned(labels, nil)
}
   191  
   192  // MustHaveJumpedToOrReturned asserts that the fragment must jump to one of
   193  // the given labels, or have returned one of the given return values.
   194  // The fragment may not jump to any other label, nor fall through,
   195  // nor return a non-deterministic value.
   196  func (f syscallProgramFragment) MustHaveJumpedToOrReturned(possibleLabels []label, possibleReturnValues map[linux.BPFAction]struct{}) {
   197  	fragment := f.getFragment()
   198  	outcomes := fragment.Outcomes()
   199  	if outcomes.MayFallThrough {
   200  		panic(fmt.Sprintf("fragment %v may fall through", fragment))
   201  	}
   202  	if len(possibleReturnValues) == 0 && outcomes.MayReturn() {
   203  		panic(fmt.Sprintf("fragment %v may return", fragment))
   204  	}
   205  	if outcomes.MayReturnRegisterA {
   206  		panic(fmt.Sprintf("fragment %v may return register A", fragment))
   207  	}
   208  	if outcomes.MayJumpToKnownOffsetBeyondFragment {
   209  		panic(fmt.Sprintf("fragment %v may jump to an offset beyond the fragment", fragment))
   210  	}
   211  	for jumpLabel := range outcomes.MayJumpToUnresolvedLabels {
   212  		found := false
   213  		for _, wantLabel := range possibleLabels {
   214  			if jumpLabel == string(wantLabel) {
   215  				found = true
   216  				break
   217  			}
   218  		}
   219  		if !found {
   220  			panic(fmt.Sprintf("fragment %v may jump to a label %q which is not one of %v", fragment, jumpLabel, possibleLabels))
   221  		}
   222  	}
   223  	for returnValue := range outcomes.MayReturnImmediate {
   224  		if _, found := possibleReturnValues[returnValue]; !found {
   225  			panic(fmt.Sprintf("fragment %v may return a value %q which is not one of %v", fragment, returnValue, possibleReturnValues))
   226  		}
   227  	}
   228  }
   229  
// labelSet keeps track of labels that individual rules may jump to if they
// either match or mismatch.
// It can generate unique label names, and can be used recursively within
// rules (see `Push`).
type labelSet struct {
	// prefix is a label prefix used when generating label names.
	prefix string

	// labelCounter is used to generate unique label names.
	// It is incremented each time `NewLabel` is called.
	labelCounter int

	// ruleMatched is the label that a rule should jump to if it matches.
	ruleMatched label

	// ruleMismatched is the label that a rule should jump to if it doesn't
	// match.
	ruleMismatched label
}
   248  
   249  // NewLabel returns a new unique label.
   250  func (l *labelSet) NewLabel() label {
   251  	newLabel := label(fmt.Sprintf("%s#%d", l.prefix, l.labelCounter))
   252  	l.labelCounter++
   253  	return newLabel
   254  }
   255  
// Matched returns the label to jump to if the rule matches.
// This is the `ruleMatched` label the labelSet was created with.
func (l *labelSet) Matched() label {
	return l.ruleMatched
}
   260  
// Mismatched returns the label to jump to if the rule does not match.
// This is the `ruleMismatched` label the labelSet was created with.
func (l *labelSet) Mismatched() label {
	return l.ruleMismatched
}
   265  
   266  // Push creates a new labelSet meant to be used in a recursive context of the
   267  // rule currently being rendered.
   268  // Labels generated by this new labelSet will have `labelSuffix` appended to
   269  // this labelSet's current prefix, and will have its matched/mismatched labels
   270  // point to the given labels.
   271  func (l *labelSet) Push(labelSuffix string, newRuleMatch, newRuleMismatch label) *labelSet {
   272  	newPrefix := labelSuffix
   273  	if l.prefix != "" {
   274  		newPrefix = fmt.Sprintf("%s_%s", l.prefix, labelSuffix)
   275  	}
   276  	return &labelSet{
   277  		prefix:         newPrefix,
   278  		ruleMatched:    newRuleMatch,
   279  		ruleMismatched: newRuleMismatch,
   280  	}
   281  }
   282  
// matchedValue keeps track of BPF instructions needed to load a 64-bit value
// being matched against. Since BPF can only do operations on 32-bit
// values, value-matching code needs to selectively load one or the
// other half of the 64-bit value.
//
// program is the program that load instructions are appended to;
// dataOffsetHigh and dataOffsetLow are the offsets used by the absolute-load
// instructions for the high and low 32 bits respectively.
type matchedValue struct {
	program        *syscallProgram
	dataOffsetHigh uint32
	dataOffsetLow  uint32
}
   292  
// LoadHigh32Bits loads the high 32-bit of the 64-bit value into register A.
// It does so by appending a word-sized absolute load from `dataOffsetHigh`
// to the program.
func (m matchedValue) LoadHigh32Bits() {
	m.program.Stmt(bpf.Ld|bpf.Abs|bpf.W, m.dataOffsetHigh)
}
   297  
// LoadLow32Bits loads the low 32-bit of the 64-bit value into register A.
// It does so by appending a word-sized absolute load from `dataOffsetLow`
// to the program.
func (m matchedValue) LoadLow32Bits() {
	m.program.Stmt(bpf.Ld|bpf.Abs|bpf.W, m.dataOffsetLow)
}
   302  
// ProgramOptions configure a seccomp program.
type ProgramOptions struct {
	// DefaultAction is the action returned when none of the rules match.
	DefaultAction linux.BPFAction

	// BadArchAction is the action returned when the architecture of the
	// syscall structure input doesn't match the one the program expects.
	BadArchAction linux.BPFAction

	// Optimize specifies whether optimizations should be applied to the
	// syscall rules and generated BPF bytecode.
	Optimize bool

	// HotSyscalls is the set of syscall numbers that are the hottest,
	// where "hotness" refers to frequency (regardless of the amount of
	// computation that the kernel will do handling them, and regardless of
	// the complexity of the syscall rule for this).
	// It should only contain very hot syscalls (i.e. any syscall that is
	// called >10% of the time out of all syscalls made).
	// It is ordered from most frequent to least frequent.
	// Entries that have no rule, whose rules are trivial, or that repeat an
	// earlier entry are ignored when the program is built.
	HotSyscalls []uintptr
}
   325  
   326  // DefaultProgramOptions returns the default program options.
   327  func DefaultProgramOptions() ProgramOptions {
   328  	action, err := DefaultAction()
   329  	if err != nil {
   330  		panic(fmt.Sprintf("cannot determine default seccomp action: %v", err))
   331  	}
   332  	return ProgramOptions{
   333  		DefaultAction: action,
   334  		BadArchAction: action,
   335  		Optimize:      true,
   336  	}
   337  }
   338  
// BuildStats contains information about seccomp program generation.
type BuildStats struct {
	// SizeBeforeOptimizations and SizeAfterOptimizations correspond to the
	// number of instructions in the program before vs after optimization.
	// They are equal when BPF bytecode optimizations are disabled.
	SizeBeforeOptimizations, SizeAfterOptimizations int

	// BuildDuration is the amount of time it took to build the program (before
	// BPF bytecode optimizations). It excludes RuleOptimizeDuration.
	BuildDuration time.Duration

	// RuleOptimizeDuration is the amount of time it took to run SyscallRule
	// optimizations.
	RuleOptimizeDuration time.Duration

	// BPFOptimizeDuration is the amount of time it took to run BPF bytecode
	// optimizations. It is zero when those optimizations are disabled.
	BPFOptimizeDuration time.Duration
}
   357  
   358  // BuildProgram builds a BPF program from the given map of actions to matching
   359  // SyscallRules. The single generated program covers all provided RuleSets.
   360  func BuildProgram(rules []RuleSet, options ProgramOptions) ([]bpf.Instruction, BuildStats, error) {
   361  	start := time.Now()
   362  	// Make a copy of the syscall rules and maybe optimize them.
   363  	ors, ruleOptimizeDuration, err := orderRuleSets(rules, options)
   364  	if err != nil {
   365  		return nil, BuildStats{}, err
   366  	}
   367  
   368  	possibleActions := make(map[linux.BPFAction]struct{})
   369  	for _, ruleSet := range rules {
   370  		possibleActions[ruleSet.Action] = struct{}{}
   371  	}
   372  
   373  	program := &syscallProgram{
   374  		program: bpf.NewProgramBuilder(),
   375  	}
   376  
   377  	// Be paranoid and check that syscall is done in the expected architecture.
   378  	//
   379  	// A = seccomp_data.arch
   380  	// if (A != AUDIT_ARCH) goto badArchLabel.
   381  	badArchLabel := label("badarch")
   382  	program.Stmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetArch)
   383  	program.IfNot(bpf.Jmp|bpf.Jeq|bpf.K, LINUX_AUDIT_ARCH, badArchLabel)
   384  	orsFrag := program.Record()
   385  	if err := ors.render(program); err != nil {
   386  		return nil, BuildStats{}, err
   387  	}
   388  	orsFrag.MustHaveJumpedToOrReturned([]label{defaultLabel}, possibleActions)
   389  
   390  	// Default label if none of the rules matched:
   391  	program.Label(defaultLabel)
   392  	program.Ret(options.DefaultAction)
   393  
   394  	// Label if the architecture didn't match:
   395  	program.Label(badArchLabel)
   396  	program.Ret(options.BadArchAction)
   397  
   398  	insns, err := program.program.Instructions()
   399  	if err != nil {
   400  		return nil, BuildStats{}, err
   401  	}
   402  	beforeOpt := len(insns)
   403  	buildDuration := time.Since(start) - ruleOptimizeDuration
   404  	var bpfOptimizeDuration time.Duration
   405  	afterOpt := beforeOpt
   406  	if options.Optimize {
   407  		insns = bpf.Optimize(insns)
   408  		bpfOptimizeDuration = time.Since(start) - buildDuration - ruleOptimizeDuration
   409  		afterOpt = len(insns)
   410  		log.Debugf("Seccomp program optimized from %d to %d instructions; took %v to build and %v to optimize", beforeOpt, afterOpt, buildDuration, bpfOptimizeDuration)
   411  	}
   412  	return insns, BuildStats{
   413  		SizeBeforeOptimizations: beforeOpt,
   414  		SizeAfterOptimizations:  afterOpt,
   415  		BuildDuration:           buildDuration,
   416  		RuleOptimizeDuration:    ruleOptimizeDuration,
   417  		BPFOptimizeDuration:     bpfOptimizeDuration,
   418  	}, nil
   419  }
   420  
// singleSyscallRuleSet represents what to do for a single syscall.
// It is used inside `orderedRuleSets`.
//
// sysno is the syscall number, rules is the ordered list of rule/action
// pairs evaluated for it, and vsyscall indicates whether a vsyscall page
// check is emitted before the rules.
type singleSyscallRuleSet struct {
	sysno    uintptr
	rules    []syscallRuleAction
	vsyscall bool
}
   428  
   429  // Render renders the ruleset for this syscall.
   430  func (ssrs singleSyscallRuleSet) Render(program *syscallProgram, ls *labelSet, noMatch label) {
   431  	frag := program.Record()
   432  	if ssrs.vsyscall {
   433  		// Emit a vsyscall check.
   434  		// This rule ensures that the top bit is set in the
   435  		// instruction pointer, which is where the vsyscall page
   436  		// will be mapped.
   437  		program.Stmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetIPHigh)
   438  		program.IfNot(bpf.Jmp|bpf.Jset|bpf.K, vsyscallPageIPMask, noMatch)
   439  	}
   440  	var nextRule label
   441  	actions := make(map[linux.BPFAction]struct{})
   442  	for i, ra := range ssrs.rules {
   443  		actions[ra.action] = struct{}{}
   444  
   445  		// Render the rule.
   446  		nextRule = ls.NewLabel()
   447  		ruleLabels := ls.Push(fmt.Sprintf("sysno%d_rule%d", ssrs.sysno, i), ls.NewLabel(), nextRule)
   448  		ruleFrag := program.Record()
   449  		ra.rule.Render(program, ruleLabels)
   450  		program.Label(ruleLabels.Matched())
   451  		program.Ret(ra.action)
   452  		ruleFrag.MustHaveJumpedToOrReturned(
   453  			[]label{nextRule},
   454  			map[linux.BPFAction]struct{}{
   455  				ra.action: struct{}{},
   456  			})
   457  		program.Label(nextRule)
   458  	}
   459  	program.JumpTo(noMatch)
   460  	frag.MustHaveJumpedToOrReturned([]label{noMatch}, actions)
   461  }
   462  
   463  // String returns a human-friendly representation of the
   464  // `singleSyscallRuleSet`.
   465  func (ssrs singleSyscallRuleSet) String() string {
   466  	var sb strings.Builder
   467  	if ssrs.vsyscall {
   468  		sb.WriteString("Vsyscall ")
   469  	} else {
   470  		sb.WriteString("Syscall  ")
   471  	}
   472  	sb.WriteString(fmt.Sprintf("%3d: ", ssrs.sysno))
   473  	switch len(ssrs.rules) {
   474  	case 0:
   475  		sb.WriteString("(no rules)")
   476  	case 1:
   477  		sb.WriteString(ssrs.rules[0].String())
   478  	default:
   479  		sb.WriteRune('{')
   480  		for i, r := range ssrs.rules {
   481  			if i != 0 {
   482  				sb.WriteString("; ")
   483  			}
   484  			sb.WriteString(r.String())
   485  		}
   486  		sb.WriteRune('}')
   487  	}
   488  	return sb.String()
   489  }
   490  
// syscallRuleAction groups a `SyscallRule` and an action that should be
// returned if the rule matches.
//
// rule is the condition to evaluate; action is what the program returns
// when rule matches.
type syscallRuleAction struct {
	rule   SyscallRule
	action linux.BPFAction
}
   497  
   498  // String returns a human-friendly representation of the `syscallRuleAction`.
   499  func (sra syscallRuleAction) String() string {
   500  	if _, isMatchAll := sra.rule.(MatchAll); isMatchAll {
   501  		return sra.action.String()
   502  	}
   503  	return fmt.Sprintf("(%v) => %v", sra.rule.String(), sra.action)
   504  }
   505  
// orderedRuleSets contains an ordering of syscall rules used to render a
// program. It is derived from a list of `RuleSet`s and `ProgramOptions`.
// Its fields represent the order in which rulesets are rendered.
// There are three categorization criteria:
//   - "Hot" vs "cold": hot syscalls go first and are checked linearly, cold
//     syscalls go later.
//   - "Trivial" vs "non-trivial": A "trivial" syscall rule means one that
//     does not require checking any argument or RIP data. This basically
//     means a syscall mapped to `MatchAll{}`.
//     If a syscall shows up in multiple RuleSets where any of them is
//     non-trivial, the whole syscall is considered non-trivial.
//   - "vsyscall" vs "non-vsyscall": A syscall that needs vsyscall checking
//     checks that the function is dispatched from the vsyscall page by
//     checking RIP. This inherently makes it non-trivial. All trivial
//     rules are non-vsyscall, but not all non-vsyscall rules are trivial.
type orderedRuleSets struct {
	// hotNonTrivial is the set of hot syscalls that are non-trivial
	// and may or may not require vsyscall checking.
	// They come first and are checked linearly using `hotNonTrivialOrder`.
	hotNonTrivial map[uintptr]singleSyscallRuleSet

	// hotNonTrivialOrder lists the syscall numbers present in
	// `hotNonTrivial`, in the order in which they are checked.
	// It follows the order of `ProgramOptions.HotSyscalls`.
	hotNonTrivialOrder []uintptr

	// coldNonTrivial is the set of non-hot syscalls that are non-trivial.
	// They may or may not require vsyscall checking.
	// They come second.
	coldNonTrivial map[uintptr]singleSyscallRuleSet

	// trivial is the set of syscalls that are trivial. They may or may not be
	// hot, but they may not require vsyscall checking (otherwise they would
	// be non-trivial).
	// They come last. This is because the host kernel will cache the results
	// of these system calls, and will never execute them on the hot path.
	trivial map[uintptr]singleSyscallRuleSet
}
   544  
// orderRuleSets converts a set of `RuleSet`s into an `orderedRuleSets`.
// It returns the ordered rulesets, along with the time spent optimizing the
// rules (zero if `options.Optimize` is not set).
// It returns an error if the same syscall number appears in RuleSets that
// disagree on whether vsyscall checking is required.
func orderRuleSets(rules []RuleSet, options ProgramOptions) (orderedRuleSets, time.Duration, error) {
	// Do a pass to determine if vsyscall is consistent across syscall numbers.
	vsyscallBySysno := make(map[uintptr]bool)
	for _, rs := range rules {
		for sysno := range rs.Rules.rules {
			if prevVsyscall, ok := vsyscallBySysno[sysno]; ok {
				if prevVsyscall != rs.Vsyscall {
					return orderedRuleSets{}, 0, fmt.Errorf("syscall %d has conflicting vsyscall checking rules", sysno)
				}
			} else {
				vsyscallBySysno[sysno] = rs.Vsyscall
			}
		}
	}

	// Build a single map of per-syscall syscallRuleActions.
	// We will split this map up later.
	// For each syscall, entries are kept in the order of the RuleSets they
	// came from.
	allSyscallRuleActions := make(map[uintptr][]syscallRuleAction)
	for _, rs := range rules {
		for sysno, rule := range rs.Rules.rules {
			existing, found := allSyscallRuleActions[sysno]
			if !found {
				allSyscallRuleActions[sysno] = []syscallRuleAction{{
					rule:   rule,
					action: rs.Action,
				}}
				continue
			}
			if existing[len(existing)-1].action == rs.Action {
				// If the last action for this syscall is the same, union the rules.
				existing[len(existing)-1].rule = Or{existing[len(existing)-1].rule, rule}
			} else {
				// Otherwise, add it as a new ruleset.
				existing = append(existing, syscallRuleAction{
					rule:   rule,
					action: rs.Action,
				})
			}
			allSyscallRuleActions[sysno] = existing
		}
	}

	// Optimize all rules.
	var optimizeDuration time.Duration
	if options.Optimize {
		optimizeStart := time.Now()
		for _, ruleActions := range allSyscallRuleActions {
			for i, ra := range ruleActions {
				// `ra` is a copy; write the optimized rule back in place.
				ra.rule = optimizeSyscallRule(ra.rule)
				ruleActions[i] = ra
			}
		}
		optimizeDuration = time.Since(optimizeStart)
	}

	// Do a pass that checks which syscall numbers are trivial.
	// A syscall is trivial only if every one of its rules is `MatchAll` and
	// it does not require vsyscall checking.
	isTrivial := make(map[uintptr]bool)
	for sysno, ruleActions := range allSyscallRuleActions {
		for _, ra := range ruleActions {
			_, isMatchAll := ra.rule.(MatchAll)
			isVsyscall := vsyscallBySysno[sysno]
			trivial := isMatchAll && !isVsyscall
			if prevTrivial, ok := isTrivial[sysno]; ok {
				isTrivial[sysno] = prevTrivial && trivial
			} else {
				isTrivial[sysno] = trivial
			}
		}
	}

	// Compute the set of non-trivial hot syscalls and their order.
	// Hot syscalls without rules, with trivial rules, or already seen
	// earlier in `options.HotSyscalls` are skipped.
	hotNonTrivialSyscallsIndex := make(map[uintptr]int, len(options.HotSyscalls))
	for i, sysno := range options.HotSyscalls {
		if _, hasRule := allSyscallRuleActions[sysno]; !hasRule {
			continue
		}
		if isTrivial[sysno] {
			continue
		}
		if _, ok := hotNonTrivialSyscallsIndex[sysno]; ok {
			continue
		}
		hotNonTrivialSyscallsIndex[sysno] = i
	}
	hotNonTrivialOrder := make([]uintptr, 0, len(hotNonTrivialSyscallsIndex))
	for sysno := range hotNonTrivialSyscallsIndex {
		hotNonTrivialOrder = append(hotNonTrivialOrder, sysno)
	}
	// Map iteration order is random; restore the order of
	// `options.HotSyscalls` using the recorded indices.
	sort.Slice(hotNonTrivialOrder, func(i, j int) bool {
		return hotNonTrivialSyscallsIndex[hotNonTrivialOrder[i]] < hotNonTrivialSyscallsIndex[hotNonTrivialOrder[j]]
	})

	// Now split up the map and build the `orderedRuleSets`.
	ors := orderedRuleSets{
		hotNonTrivial:      make(map[uintptr]singleSyscallRuleSet),
		hotNonTrivialOrder: hotNonTrivialOrder,
		coldNonTrivial:     make(map[uintptr]singleSyscallRuleSet),
		trivial:            make(map[uintptr]singleSyscallRuleSet),
	}
	for sysno, ruleActions := range allSyscallRuleActions {
		_, hot := hotNonTrivialSyscallsIndex[sysno]
		trivial := isTrivial[sysno]
		var subMap map[uintptr]singleSyscallRuleSet
		switch {
		case trivial:
			subMap = ors.trivial
		case hot:
			subMap = ors.hotNonTrivial
		default:
			subMap = ors.coldNonTrivial
		}
		subMap[sysno] = singleSyscallRuleSet{
			sysno:    sysno,
			vsyscall: vsyscallBySysno[sysno],
			rules:    ruleActions,
		}
	}

	// Log our findings.
	if log.IsLogging(log.Debug) {
		ors.log(log.Debugf)
	}

	return ors, optimizeDuration, nil
}
   673  
   674  // log logs the set of seccomp rules to the given logger.
   675  func (ors orderedRuleSets) log(logFn func(string, ...any)) {
   676  	logFn("Ordered seccomp rules:")
   677  	for _, sm := range []struct {
   678  		name string
   679  		m    map[uintptr]singleSyscallRuleSet
   680  	}{
   681  		{"Hot non-trivial", ors.hotNonTrivial},
   682  		{"Cold non-trivial", ors.coldNonTrivial},
   683  		{"Trivial", ors.trivial},
   684  	} {
   685  		if len(sm.m) == 0 {
   686  			logFn("  %s syscalls: None.", sm.name)
   687  			continue
   688  		}
   689  		logFn("  %s syscalls:", sm.name)
   690  		orderedSysnos := make([]int, 0, len(sm.m))
   691  		for sysno := range sm.m {
   692  			orderedSysnos = append(orderedSysnos, int(sysno))
   693  		}
   694  		sort.Ints(orderedSysnos)
   695  		for _, sysno := range orderedSysnos {
   696  			logFn("    - %s", sm.m[uintptr(sysno)].String())
   697  		}
   698  	}
   699  	logFn("End of ordered seccomp rules.")
   700  }
   701  
   702  // render renders all rulesets in the given program.
   703  func (ors orderedRuleSets) render(program *syscallProgram) error {
   704  	ls := &labelSet{prefix: string("ors")}
   705  
   706  	// totalFrag wraps the entire output of the `render` function.
   707  	totalFrag := program.Record()
   708  
   709  	// Load syscall number into register A.
   710  	program.Stmt(bpf.Ld|bpf.Abs|bpf.W, seccompDataOffsetNR)
   711  
   712  	// Keep track of which syscalls we've already looked for.
   713  	sysnosChecked := make(map[uintptr]struct{})
   714  
   715  	// First render hot syscalls linearly.
   716  	if len(ors.hotNonTrivialOrder) > 0 {
   717  		notHotLabel := ls.NewLabel()
   718  		// hotFrag wraps the "hot syscalls" part of the program.
   719  		// It must either return one of `hotActions`, or jump to `defaultLabel` if
   720  		// the syscall number matched but the vsyscall match failed, or
   721  		// `notHotLabel` if none of the hot syscall numbers matched.
   722  		hotFrag := program.Record()
   723  		possibleActions := ors.renderLinear(program, ls, sysnosChecked, ors.hotNonTrivial, ors.hotNonTrivialOrder, notHotLabel)
   724  		hotFrag.MustHaveJumpedToOrReturned([]label{notHotLabel, defaultLabel}, possibleActions)
   725  		program.Label(notHotLabel)
   726  	}
   727  
   728  	// Now render the cold non-trivial rules as a binary search tree:
   729  	if len(ors.coldNonTrivial) > 0 {
   730  		frag := program.Record()
   731  		noSycallNumberMatch := ls.NewLabel()
   732  		possibleActions, err := ors.renderBST(program, ls, sysnosChecked, ors.coldNonTrivial, noSycallNumberMatch)
   733  		if err != nil {
   734  			return err
   735  		}
   736  		frag.MustHaveJumpedToOrReturned([]label{noSycallNumberMatch, defaultLabel}, possibleActions)
   737  		program.Label(noSycallNumberMatch)
   738  	}
   739  
   740  	// Finally render the trivial rules as a binary search tree:
   741  	if len(ors.trivial) > 0 {
   742  		frag := program.Record()
   743  		noSycallNumberMatch := ls.NewLabel()
   744  		possibleActions, err := ors.renderBST(program, ls, sysnosChecked, ors.trivial, noSycallNumberMatch)
   745  		if err != nil {
   746  			return err
   747  		}
   748  		frag.MustHaveJumpedToOrReturned([]label{noSycallNumberMatch, defaultLabel}, possibleActions)
   749  		program.Label(noSycallNumberMatch)
   750  	}
   751  	program.JumpTo(defaultLabel)
   752  
   753  	// Reached the end of the program.
   754  	// Independently verify the set of all possible actions.
   755  	allPossibleActions := make(map[linux.BPFAction]struct{})
   756  	for _, mapping := range []map[uintptr]singleSyscallRuleSet{
   757  		ors.hotNonTrivial,
   758  		ors.coldNonTrivial,
   759  		ors.trivial,
   760  	} {
   761  		for _, ssrs := range mapping {
   762  			for _, ra := range ssrs.rules {
   763  				allPossibleActions[ra.action] = struct{}{}
   764  			}
   765  		}
   766  	}
   767  	totalFrag.MustHaveJumpedToOrReturned([]label{defaultLabel}, allPossibleActions)
   768  	return nil
   769  }
   770  
   771  // renderLinear renders linear search code that searches for syscall matches
   772  // in the given order. It assumes the syscall number is loaded into register
   773  // A. Rulesets for all syscall numbers in `order` must exist in `syscallMap`.
   774  // It returns the list of possible actions the generated code may return.
   775  // `alreadyChecked` will be updated with the syscalls that have been checked.
   776  func (ors orderedRuleSets) renderLinear(program *syscallProgram, ls *labelSet, alreadyChecked map[uintptr]struct{}, syscallMap map[uintptr]singleSyscallRuleSet, order []uintptr, noSycallNumberMatch label) map[linux.BPFAction]struct{} {
   777  	allActions := make(map[linux.BPFAction]struct{})
   778  	for _, sysno := range order {
   779  		ssrs, found := syscallMap[sysno]
   780  		if !found {
   781  			panic(fmt.Sprintf("syscall %d found in linear order but not map", sysno))
   782  		}
   783  		nextSyscall := ls.NewLabel()
   784  		// sysnoFrag wraps the "statements about this syscall number" part of
   785  		// the program. It must either return one of the actions specified in
   786  		// that syscall number's rules (`sysnoActions`), or jump to
   787  		// `nextSyscall`.
   788  		sysnoFrag := program.Record()
   789  		sysnoActions := make(map[linux.BPFAction]struct{})
   790  		for _, ra := range ssrs.rules {
   791  			sysnoActions[ra.action] = struct{}{}
   792  			allActions[ra.action] = struct{}{}
   793  		}
   794  		program.IfNot(bpf.Jmp|bpf.Jeq|bpf.K, uint32(ssrs.sysno), nextSyscall)
   795  		ssrs.Render(program, ls, defaultLabel)
   796  		sysnoFrag.MustHaveJumpedToOrReturned([]label{nextSyscall, defaultLabel}, sysnoActions)
   797  		program.Label(nextSyscall)
   798  	}
   799  	program.JumpTo(noSycallNumberMatch)
   800  	for _, sysno := range order {
   801  		alreadyChecked[sysno] = struct{}{}
   802  	}
   803  	return allActions
   804  }
   805  
   806  // renderBST renders a binary search tree that searches the given map of
   807  // syscalls. It assumes the syscall number is loaded into register A.
   808  // It returns the list of possible actions the generated code may return.
   809  // `alreadyChecked` will be updated with the syscalls that the BST has
   810  // searched.
   811  func (ors orderedRuleSets) renderBST(program *syscallProgram, ls *labelSet, alreadyChecked map[uintptr]struct{}, syscallMap map[uintptr]singleSyscallRuleSet, noSycallNumberMatch label) (map[linux.BPFAction]struct{}, error) {
   812  	possibleActions := make(map[linux.BPFAction]struct{})
   813  	orderedSysnos := make([]uintptr, 0, len(syscallMap))
   814  	for sysno, ruleActions := range syscallMap {
   815  		orderedSysnos = append(orderedSysnos, sysno)
   816  		for _, ra := range ruleActions.rules {
   817  			possibleActions[ra.action] = struct{}{}
   818  		}
   819  	}
   820  	sort.Slice(orderedSysnos, func(i, j int) bool {
   821  		return orderedSysnos[i] < orderedSysnos[j]
   822  	})
   823  	frag := program.Record()
   824  	root := createBST(orderedSysnos)
   825  	root.root = true
   826  	knownRng := knownRange{
   827  		lowerBoundExclusive: -1,
   828  		// sysno fits in 32 bits, so this is definitely out of bounds:
   829  		upperBoundExclusive: 1 << 32,
   830  		previouslyChecked:   alreadyChecked,
   831  	}
   832  	if err := root.traverse(renderBSTTraversal, knownRng, syscallMap, program, noSycallNumberMatch); err != nil {
   833  		return nil, err
   834  	}
   835  	if err := root.traverse(renderBSTRules, knownRng, syscallMap, program, noSycallNumberMatch); err != nil {
   836  		return nil, err
   837  	}
   838  	frag.MustHaveJumpedToOrReturned([]label{noSycallNumberMatch, defaultLabel}, possibleActions)
   839  	for sysno := range syscallMap {
   840  		alreadyChecked[sysno] = struct{}{}
   841  	}
   842  	return possibleActions, nil
   843  }
   844  
   845  // createBST converts sorted syscall slice into a balanced BST.
   846  // Panics if syscalls is empty.
   847  func createBST(syscalls []uintptr) *node {
   848  	i := len(syscalls) / 2
   849  	parent := node{value: syscalls[i]}
   850  	if i > 0 {
   851  		parent.left = createBST(syscalls[:i])
   852  	}
   853  	if i+1 < len(syscalls) {
   854  		parent.right = createBST(syscalls[i+1:])
   855  	}
   856  	return &parent
   857  }
   858  
   859  // renderBSTTraversal renders the traversal bytecode for a binary search tree.
   860  // The outline of the code is as follows, given a BST with:
   861  //
   862  //		     22
   863  //		    /  \
   864  //		   9    24
   865  //		  /    /  \
   866  //	   8   23    50
   867  //
   868  //		index_22: // SYS_PIPE(22), root
   869  //		(A < 22) ? goto index_9  : continue
   870  //		(A > 22) ? goto index_24 : continue
   871  //		goto checkArgs_22
   872  //
   873  //		index_9: // SYS_MMAP(9), single child
   874  //		(A < 9)  ? goto index_8  : continue
   875  //		(A == 9) ? continue : goto defaultLabel
   876  //		goto checkArgs_9
   877  //
   878  //		index_8: // SYS_LSEEK(8), leaf
   879  //		(A == 8) ? continue : goto defaultLabel
   880  //		goto checkArgs_8
   881  //
   882  //		index_24: // SYS_SCHED_YIELD(24)
   883  //		(A < 24) ? goto index_23 : continue
   884  //		(A > 22) ? goto index_50 : continue
   885  //		goto checkArgs_24
   886  //
   887  //		index_23: // SYS_SELECT(23), leaf with parent nodes adjacent in value
   888  //		# Notice that we do not check for equality at all here, since we've
   889  //		# already established that the syscall number is 23 from the
   890  //		# two parent nodes that we've already traversed.
   891  //		# This is tracked in the `rng knownRange` argument during traversal.
   892  //		goto rules_23
   893  //
   894  //		index_50: // SYS_LISTEN(50), leaf
   895  //		(A == 50) ? continue : goto defaultLabel
   896  //		goto checkArgs_50
   897  //
   898  // All of the "checkArgs_XYZ" labels are not defined in this function; they
   899  // are created using the `renderBSTRules` function, which is expected to be
   900  // called after this one on the entire BST.
   901  func renderBSTTraversal(n *node, rng knownRange, syscallMap map[uintptr]singleSyscallRuleSet, program *syscallProgram, searchFailed label) error {
   902  	// Root node is never referenced by label, skip it.
   903  	if !n.root {
   904  		program.Label(n.label())
   905  	}
   906  	sysno := n.value
   907  	nodeFrag := program.Record()
   908  	checkArgsLabel := label(fmt.Sprintf("checkArgs_%d", sysno))
   909  	if n.left != nil {
   910  		program.IfNot(bpf.Jmp|bpf.Jge|bpf.K, uint32(sysno), n.left.label())
   911  		rng.lowerBoundExclusive = int(sysno - 1)
   912  	}
   913  	if n.right != nil {
   914  		program.If(bpf.Jmp|bpf.Jgt|bpf.K, uint32(sysno), n.right.label())
   915  		rng.upperBoundExclusive = int(sysno + 1)
   916  	}
   917  	if rng.lowerBoundExclusive != int(sysno-1) || rng.upperBoundExclusive != int(sysno+1) {
   918  		// If the previous BST nodes we traversed haven't fully established
   919  		// that the current node's syscall value is exactly `sysno`, we still
   920  		// need to verify it.
   921  		program.IfNot(bpf.Jmp|bpf.Jeq|bpf.K, uint32(sysno), searchFailed)
   922  	}
   923  	program.JumpTo(checkArgsLabel)
   924  	nodeFrag.MustHaveJumpedTo(n.left.label(), n.right.label(), checkArgsLabel, searchFailed)
   925  	return nil
   926  }
   927  
   928  // renderBSTRules renders the `checkArgs_XYZ` labels that `renderBSTTraversal`
   929  // jumps to as part of the BST traversal code. It contains all the
   930  // argument-specific syscall rules for each syscall number.
   931  func renderBSTRules(n *node, rng knownRange, syscallMap map[uintptr]singleSyscallRuleSet, program *syscallProgram, searchFailed label) error {
   932  	sysno := n.value
   933  	checkArgsLabel := label(fmt.Sprintf("checkArgs_%d", sysno))
   934  	program.Label(checkArgsLabel)
   935  	ruleSetsFrag := program.Record()
   936  	possibleActions := make(map[linux.BPFAction]struct{})
   937  	for _, ra := range syscallMap[sysno].rules {
   938  		possibleActions[ra.action] = struct{}{}
   939  	}
   940  	nodeLabelSet := &labelSet{prefix: string(n.label())}
   941  	syscallMap[sysno].Render(program, nodeLabelSet, defaultLabel)
   942  	ruleSetsFrag.MustHaveJumpedToOrReturned(
   943  		[]label{
   944  			defaultLabel, // Either we jumped to the default label (if the rules didn't match)...
   945  		},
   946  		possibleActions, // ... or we returned one of the actions of the rulesets.
   947  	)
   948  	return nil
   949  }
   950  
   951  // node represents a tree node.
   952  type node struct {
   953  	value uintptr
   954  	left  *node
   955  	right *node
   956  	root  bool
   957  }
   958  
   959  // label returns the label corresponding to this node.
   960  //
   961  // If n is nil, then the defaultLabel is returned.
   962  func (n *node) label() label {
   963  	if n == nil {
   964  		return defaultLabel
   965  	}
   966  	return label(fmt.Sprintf("node_%d", n.value))
   967  }
   968  
   969  // knownRange represents the known set of node numbers that we've
   970  // already checked. This is used as part of BST traversal.
   971  type knownRange struct {
   972  	lowerBoundExclusive int
   973  	upperBoundExclusive int
   974  
   975  	// alreadyChecked is a set of node values that were already checked
   976  	// earlier in the program (prior to the BST being built).
   977  	// It is *not* updated during BST traversal.
   978  	previouslyChecked map[uintptr]struct{}
   979  }
   980  
   981  // withLowerBoundExclusive returns an updated `knownRange` with the given
   982  // new exclusive lower bound. The actual exclusive lower bound of the
   983  // returned `knownRange` may be higher, in case `previouslyChecked` covers
   984  // more numbers.
   985  func (kr knownRange) withLowerBoundExclusive(newLowerBoundExclusive int) knownRange {
   986  	nkr := knownRange{
   987  		lowerBoundExclusive: newLowerBoundExclusive,
   988  		upperBoundExclusive: kr.upperBoundExclusive,
   989  		previouslyChecked:   kr.previouslyChecked,
   990  	}
   991  	for ; nkr.lowerBoundExclusive < nkr.upperBoundExclusive; nkr.lowerBoundExclusive++ {
   992  		if _, ok := nkr.previouslyChecked[uintptr(nkr.lowerBoundExclusive+1)]; !ok {
   993  			break
   994  		}
   995  	}
   996  	return nkr
   997  }
   998  
   999  // withUpperBoundExclusive returns an updated `knownRange` with the given
  1000  // new exclusive upper bound. The actual exclusive upper bound of the
  1001  // returned `knownRange` may be lower, in case `previouslyChecked` covers
  1002  // more numbers.
  1003  func (kr knownRange) withUpperBoundExclusive(newUpperBoundExclusive int) knownRange {
  1004  	nkr := knownRange{
  1005  		lowerBoundExclusive: kr.lowerBoundExclusive,
  1006  		upperBoundExclusive: newUpperBoundExclusive,
  1007  		previouslyChecked:   kr.previouslyChecked,
  1008  	}
  1009  	for ; nkr.lowerBoundExclusive < nkr.upperBoundExclusive; nkr.upperBoundExclusive-- {
  1010  		if _, ok := nkr.previouslyChecked[uintptr(nkr.upperBoundExclusive-1)]; !ok {
  1011  			break
  1012  		}
  1013  	}
  1014  	return nkr
  1015  }
  1016  
  1017  // traverseFunc is called as the BST is traversed.
  1018  type traverseFunc func(*node, knownRange, map[uintptr]singleSyscallRuleSet, *syscallProgram, label) error
  1019  
  1020  func (n *node) traverse(fn traverseFunc, kr knownRange, syscallMap map[uintptr]singleSyscallRuleSet, program *syscallProgram, searchFailed label) error {
  1021  	if n == nil {
  1022  		return nil
  1023  	}
  1024  	if err := fn(n, kr, syscallMap, program, searchFailed); err != nil {
  1025  		return err
  1026  	}
  1027  	if err := n.left.traverse(
  1028  		fn,
  1029  		kr.withUpperBoundExclusive(int(n.value)),
  1030  		syscallMap,
  1031  		program,
  1032  		searchFailed,
  1033  	); err != nil {
  1034  		return err
  1035  	}
  1036  	return n.right.traverse(
  1037  		fn,
  1038  		kr.withLowerBoundExclusive(int(n.value)),
  1039  		syscallMap,
  1040  		program,
  1041  		searchFailed,
  1042  	)
  1043  }
  1044  
  1045  // DataAsBPFInput converts a linux.SeccompData to a bpf.Input.
  1046  // It uses `buf` as scratch buffer; this buffer must be wide enough
  1047  // to accommodate a mashaled version of `d`.
  1048  func DataAsBPFInput(d *linux.SeccompData, buf []byte) bpf.Input {
  1049  	if len(buf) < d.SizeBytes() {
  1050  		panic(fmt.Sprintf("buffer must be at least %d bytes long", d.SizeBytes()))
  1051  	}
  1052  	d.MarshalUnsafe(buf)
  1053  	return buf[:d.SizeBytes()]
  1054  }