github.com/khulnasoft-lab/defsec@v1.0.5-0.20230827010352-5e9f46893d95/internal/rules/register_test.go (about) 1 package rules 2 3 import ( 4 "fmt" 5 "testing" 6 7 "github.com/stretchr/testify/require" 8 9 "github.com/khulnasoft-lab/defsec/pkg/framework" 10 "github.com/khulnasoft-lab/defsec/pkg/scan" 11 "github.com/stretchr/testify/assert" 12 ) 13 14 func Test_Reset(t *testing.T) { 15 rule := scan.Rule{} 16 _ = Register(rule, nil) 17 assert.Equal(t, 1, len(GetFrameworkRules())) 18 Reset() 19 assert.Equal(t, 0, len(GetFrameworkRules())) 20 } 21 22 func Test_Registration(t *testing.T) { 23 var tests = []struct { 24 name string 25 registeredFrameworks map[framework.Framework][]string 26 inputFrameworks []framework.Framework 27 expected bool 28 }{ 29 { 30 name: "rule without framework specified should be returned when no frameworks are requested", 31 expected: true, 32 }, 33 { 34 name: "rule without framework specified should not be returned when a specific framework is requested", 35 inputFrameworks: []framework.Framework{framework.CIS_AWS_1_2}, 36 expected: false, 37 }, 38 { 39 name: "rule without framework specified should be returned when the default framework is requested", 40 inputFrameworks: []framework.Framework{framework.Default}, 41 expected: true, 42 }, 43 { 44 name: "rule with default framework specified should be returned when the default framework is requested", 45 registeredFrameworks: map[framework.Framework][]string{framework.Default: {"1.1"}}, 46 inputFrameworks: []framework.Framework{framework.Default}, 47 expected: true, 48 }, 49 { 50 name: "rule with default framework specified should not be returned when a specific framework is requested", 51 registeredFrameworks: map[framework.Framework][]string{framework.Default: {"1.1"}}, 52 inputFrameworks: []framework.Framework{framework.CIS_AWS_1_2}, 53 expected: false, 54 }, 55 { 56 name: "rule with specific framework specified should not be returned when a default framework is requested", 57 registeredFrameworks: map[framework.Framework][]string{framework.CIS_AWS_1_2: {"1.1"}}, 58 inputFrameworks: []framework.Framework{framework.Default}, 59 expected: false, 60 }, 61 { 62 name: "rule with specific framework specified should be returned when the specific framework is requested", 63 registeredFrameworks: map[framework.Framework][]string{framework.CIS_AWS_1_2: {"1.1"}}, 64 inputFrameworks: []framework.Framework{framework.CIS_AWS_1_2}, 65 expected: true, 66 }, 67 { 68 name: "rule with multiple frameworks specified should be returned when the specific framework is requested", 69 registeredFrameworks: map[framework.Framework][]string{framework.CIS_AWS_1_2: {"1.1"}, "blah": {"1.2"}}, 70 inputFrameworks: []framework.Framework{framework.CIS_AWS_1_2}, 71 expected: true, 72 }, 73 { 74 name: "rule with multiple frameworks specified should be returned only once when multiple matching frameworks are requested", 75 registeredFrameworks: map[framework.Framework][]string{framework.CIS_AWS_1_2: {"1.1"}, "blah": {"1.2"}, "something": {"1.3"}}, 76 inputFrameworks: []framework.Framework{framework.CIS_AWS_1_2, "blah", "other"}, 77 expected: true, 78 }, 79 } 80 81 for i, test := range tests { 82 t.Run(test.name, func(t *testing.T) { 83 Reset() 84 rule := scan.Rule{ 85 AVDID: fmt.Sprintf("%d-%s", i, test.name), 86 Frameworks: test.registeredFrameworks, 87 } 88 _ = Register(rule, nil) 89 var found bool 90 for _, matchedRule := range GetFrameworkRules(test.inputFrameworks...) { 91 if matchedRule.Rule().AVDID == rule.AVDID { 92 assert.False(t, found, "rule should not be returned more than once") 93 found = true 94 } 95 } 96 assert.Equal(t, test.expected, found, "rule should be returned if it matches any of the input frameworks") 97 }) 98 } 99 } 100 101 func Test_Deregistration(t *testing.T) { 102 Reset() 103 registrationA := Register(scan.Rule{ 104 AVDID: "A", 105 }, nil) 106 registrationB := Register(scan.Rule{ 107 AVDID: "B", 108 }, nil) 109 assert.Equal(t, 2, len(GetFrameworkRules())) 110 Deregister(registrationA) 111 actual := GetFrameworkRules() 112 require.Equal(t, 1, len(actual)) 113 assert.Equal(t, "B", actual[0].Rule().AVDID) 114 Deregister(registrationB) 115 assert.Equal(t, 0, len(GetFrameworkRules())) 116 } 117 118 func Test_DeregistrationMultipleFrameworks(t *testing.T) { 119 Reset() 120 registrationA := Register(scan.Rule{ 121 AVDID: "A", 122 }, nil) 123 registrationB := Register(scan.Rule{ 124 AVDID: "B", 125 Frameworks: map[framework.Framework][]string{ 126 "a": nil, 127 "b": nil, 128 "c": nil, 129 framework.Default: nil, 130 }, 131 }, nil) 132 assert.Equal(t, 2, len(GetFrameworkRules())) 133 Deregister(registrationA) 134 actual := GetFrameworkRules() 135 require.Equal(t, 1, len(actual)) 136 assert.Equal(t, "B", actual[0].Rule().AVDID) 137 Deregister(registrationB) 138 assert.Equal(t, 0, len(GetFrameworkRules())) 139 }