github.com/khulnasoft-lab/defsec@v1.0.5-0.20230827010352-5e9f46893d95/internal/rules/register.go (about) 1 package rules 2 3 import ( 4 "sync" 5 6 "github.com/khulnasoft-lab/defsec/pkg/framework" 7 "github.com/khulnasoft-lab/defsec/pkg/scan" 8 "github.com/khulnasoft-lab/defsec/pkg/state" 9 "github.com/khulnasoft-lab/defsec/pkg/types" 10 "github.com/khulnasoft-lab/defsec/rules/specs" 11 "gopkg.in/yaml.v3" 12 ) 13 14 type RegisteredRule struct { 15 number int 16 rule scan.Rule 17 checkFunc scan.CheckFunc 18 } 19 20 func (r RegisteredRule) HasLogic() bool { 21 return r.checkFunc != nil 22 } 23 24 func (r RegisteredRule) Evaluate(s *state.State) scan.Results { 25 if r.checkFunc == nil { 26 return nil 27 } 28 results := r.checkFunc(s) 29 for i := range results { 30 results[i].SetRule(r.rule) 31 } 32 return results 33 } 34 35 func (r RegisteredRule) Rule() scan.Rule { 36 return r.rule 37 } 38 39 func (r *RegisteredRule) AddLink(link string) { 40 r.rule.Links = append([]string{link}, r.rule.Links...) 41 } 42 43 type registry struct { 44 sync.RWMutex 45 index int 46 frameworks map[framework.Framework][]RegisteredRule 47 } 48 49 var coreRegistry = registry{ 50 frameworks: make(map[framework.Framework][]RegisteredRule), 51 } 52 53 func Reset() { 54 coreRegistry.Reset() 55 } 56 57 func Register(rule scan.Rule, f scan.CheckFunc) RegisteredRule { 58 return coreRegistry.register(rule, f) 59 } 60 61 func Deregister(rule RegisteredRule) { 62 coreRegistry.deregister(rule) 63 } 64 65 func (r *registry) register(rule scan.Rule, f scan.CheckFunc) RegisteredRule { 66 r.Lock() 67 defer r.Unlock() 68 if len(rule.Frameworks) == 0 { 69 rule.Frameworks = map[framework.Framework][]string{framework.Default: nil} 70 } 71 registeredRule := RegisteredRule{ 72 number: r.index, 73 rule: rule, 74 checkFunc: f, 75 } 76 r.index++ 77 for fw := range rule.Frameworks { 78 r.frameworks[fw] = append(r.frameworks[fw], registeredRule) 79 } 80 81 r.frameworks[framework.ALL] = append(r.frameworks[framework.ALL], registeredRule) 82 83 return registeredRule 84 } 85 86 func (r *registry) deregister(rule RegisteredRule) { 87 r.Lock() 88 defer r.Unlock() 89 for fw := range r.frameworks { 90 for i, registered := range r.frameworks[fw] { 91 if registered.number == rule.number { 92 r.frameworks[fw] = append(r.frameworks[fw][:i], r.frameworks[fw][i+1:]...) 93 break 94 } 95 } 96 } 97 } 98 99 func (r *registry) getFrameworkRules(fw ...framework.Framework) []RegisteredRule { 100 r.RLock() 101 defer r.RUnlock() 102 var registered []RegisteredRule 103 if len(fw) == 0 { 104 fw = []framework.Framework{framework.Default} 105 } 106 unique := make(map[int]struct{}) 107 for _, f := range fw { 108 for _, rule := range r.frameworks[f] { 109 if _, ok := unique[rule.number]; ok { 110 continue 111 } 112 registered = append(registered, rule) 113 unique[rule.number] = struct{}{} 114 } 115 } 116 return registered 117 } 118 119 func (r *registry) getSpecRules(spec string) []RegisteredRule { 120 r.RLock() 121 defer r.RUnlock() 122 var specRules []RegisteredRule 123 124 var complianceSpec types.ComplianceSpec 125 specContent := specs.GetSpec(spec) 126 if err := yaml.Unmarshal([]byte(specContent), &complianceSpec); err != nil { 127 return nil 128 } 129 130 registered := r.getFrameworkRules(framework.ALL) 131 for _, rule := range registered { 132 for _, csRule := range complianceSpec.Spec.Controls { 133 if len(csRule.Checks) > 0 { 134 for _, c := range csRule.Checks { 135 if rule.Rule().AVDID == c.ID { 136 specRules = append(specRules, rule) 137 } 138 } 139 } 140 } 141 } 142 143 return specRules 144 } 145 146 func (r *registry) Reset() { 147 r.Lock() 148 defer r.Unlock() 149 r.frameworks = make(map[framework.Framework][]RegisteredRule) 150 } 151 152 func GetFrameworkRules(fw ...framework.Framework) []RegisteredRule { 153 return coreRegistry.getFrameworkRules(fw...) 154 } 155 156 func GetSpecRules(spec string) []RegisteredRule { 157 if len(spec) > 0 { 158 return coreRegistry.getSpecRules(spec) 159 } 160 161 return GetFrameworkRules() 162 }