github.com/khulnasoft-lab/defsec@v1.0.5-0.20230827010352-5e9f46893d95/internal/rules/register.go (about)

     1  package rules
     2  
     3  import (
     4  	"sync"
     5  
     6  	"github.com/khulnasoft-lab/defsec/pkg/framework"
     7  	"github.com/khulnasoft-lab/defsec/pkg/scan"
     8  	"github.com/khulnasoft-lab/defsec/pkg/state"
     9  	"github.com/khulnasoft-lab/defsec/pkg/types"
    10  	"github.com/khulnasoft-lab/defsec/rules/specs"
    11  	"gopkg.in/yaml.v3"
    12  )
    13  
    14  type RegisteredRule struct {
    15  	number    int
    16  	rule      scan.Rule
    17  	checkFunc scan.CheckFunc
    18  }
    19  
    20  func (r RegisteredRule) HasLogic() bool {
    21  	return r.checkFunc != nil
    22  }
    23  
    24  func (r RegisteredRule) Evaluate(s *state.State) scan.Results {
    25  	if r.checkFunc == nil {
    26  		return nil
    27  	}
    28  	results := r.checkFunc(s)
    29  	for i := range results {
    30  		results[i].SetRule(r.rule)
    31  	}
    32  	return results
    33  }
    34  
    35  func (r RegisteredRule) Rule() scan.Rule {
    36  	return r.rule
    37  }
    38  
    39  func (r *RegisteredRule) AddLink(link string) {
    40  	r.rule.Links = append([]string{link}, r.rule.Links...)
    41  }
    42  
    43  type registry struct {
    44  	sync.RWMutex
    45  	index      int
    46  	frameworks map[framework.Framework][]RegisteredRule
    47  }
    48  
    49  var coreRegistry = registry{
    50  	frameworks: make(map[framework.Framework][]RegisteredRule),
    51  }
    52  
    53  func Reset() {
    54  	coreRegistry.Reset()
    55  }
    56  
    57  func Register(rule scan.Rule, f scan.CheckFunc) RegisteredRule {
    58  	return coreRegistry.register(rule, f)
    59  }
    60  
    61  func Deregister(rule RegisteredRule) {
    62  	coreRegistry.deregister(rule)
    63  }
    64  
    65  func (r *registry) register(rule scan.Rule, f scan.CheckFunc) RegisteredRule {
    66  	r.Lock()
    67  	defer r.Unlock()
    68  	if len(rule.Frameworks) == 0 {
    69  		rule.Frameworks = map[framework.Framework][]string{framework.Default: nil}
    70  	}
    71  	registeredRule := RegisteredRule{
    72  		number:    r.index,
    73  		rule:      rule,
    74  		checkFunc: f,
    75  	}
    76  	r.index++
    77  	for fw := range rule.Frameworks {
    78  		r.frameworks[fw] = append(r.frameworks[fw], registeredRule)
    79  	}
    80  
    81  	r.frameworks[framework.ALL] = append(r.frameworks[framework.ALL], registeredRule)
    82  
    83  	return registeredRule
    84  }
    85  
    86  func (r *registry) deregister(rule RegisteredRule) {
    87  	r.Lock()
    88  	defer r.Unlock()
    89  	for fw := range r.frameworks {
    90  		for i, registered := range r.frameworks[fw] {
    91  			if registered.number == rule.number {
    92  				r.frameworks[fw] = append(r.frameworks[fw][:i], r.frameworks[fw][i+1:]...)
    93  				break
    94  			}
    95  		}
    96  	}
    97  }
    98  
    99  func (r *registry) getFrameworkRules(fw ...framework.Framework) []RegisteredRule {
   100  	r.RLock()
   101  	defer r.RUnlock()
   102  	var registered []RegisteredRule
   103  	if len(fw) == 0 {
   104  		fw = []framework.Framework{framework.Default}
   105  	}
   106  	unique := make(map[int]struct{})
   107  	for _, f := range fw {
   108  		for _, rule := range r.frameworks[f] {
   109  			if _, ok := unique[rule.number]; ok {
   110  				continue
   111  			}
   112  			registered = append(registered, rule)
   113  			unique[rule.number] = struct{}{}
   114  		}
   115  	}
   116  	return registered
   117  }
   118  
   119  func (r *registry) getSpecRules(spec string) []RegisteredRule {
   120  	r.RLock()
   121  	defer r.RUnlock()
   122  	var specRules []RegisteredRule
   123  
   124  	var complianceSpec types.ComplianceSpec
   125  	specContent := specs.GetSpec(spec)
   126  	if err := yaml.Unmarshal([]byte(specContent), &complianceSpec); err != nil {
   127  		return nil
   128  	}
   129  
   130  	registered := r.getFrameworkRules(framework.ALL)
   131  	for _, rule := range registered {
   132  		for _, csRule := range complianceSpec.Spec.Controls {
   133  			if len(csRule.Checks) > 0 {
   134  				for _, c := range csRule.Checks {
   135  					if rule.Rule().AVDID == c.ID {
   136  						specRules = append(specRules, rule)
   137  					}
   138  				}
   139  			}
   140  		}
   141  	}
   142  
   143  	return specRules
   144  }
   145  
   146  func (r *registry) Reset() {
   147  	r.Lock()
   148  	defer r.Unlock()
   149  	r.frameworks = make(map[framework.Framework][]RegisteredRule)
   150  }
   151  
   152  func GetFrameworkRules(fw ...framework.Framework) []RegisteredRule {
   153  	return coreRegistry.getFrameworkRules(fw...)
   154  }
   155  
   156  func GetSpecRules(spec string) []RegisteredRule {
   157  	if len(spec) > 0 {
   158  		return coreRegistry.getSpecRules(spec)
   159  	}
   160  
   161  	return GetFrameworkRules()
   162  }