github.com/aquasecurity/trivy-iac@v0.8.1-0.20240127024015-3d8e412cf0ab/pkg/scanners/helm/test/scanner_test.go (about) 1 package test 2 3 import ( 4 "context" 5 "io" 6 "os" 7 "path/filepath" 8 "sort" 9 "strings" 10 "testing" 11 12 "github.com/aquasecurity/defsec/pkg/scanners/options" 13 "github.com/aquasecurity/trivy-iac/pkg/scanners/helm" 14 "github.com/stretchr/testify/assert" 15 "github.com/stretchr/testify/require" 16 ) 17 18 func Test_helm_scanner_with_archive(t *testing.T) { 19 20 tests := []struct { 21 testName string 22 chartName string 23 path string 24 archiveName string 25 }{ 26 { 27 testName: "Parsing tarball 'mysql-8.8.26.tar'", 28 chartName: "mysql", 29 path: filepath.Join("testdata", "mysql-8.8.26.tar"), 30 archiveName: "mysql-8.8.26.tar", 31 }, 32 } 33 34 for _, test := range tests { 35 t.Logf("Running test: %s", test.testName) 36 37 helmScanner := helm.New(options.ScannerWithEmbeddedPolicies(true), options.ScannerWithEmbeddedLibraries(true)) 38 39 testTemp := t.TempDir() 40 testFileName := filepath.Join(testTemp, test.archiveName) 41 require.NoError(t, copyArchive(test.path, testFileName)) 42 43 testFs := os.DirFS(testTemp) 44 results, err := helmScanner.ScanFS(context.TODO(), testFs, ".") 45 require.NoError(t, err) 46 require.NotNil(t, results) 47 48 failed := results.GetFailed() 49 assert.Equal(t, 13, len(failed)) 50 51 visited := make(map[string]bool) 52 var errorCodes []string 53 for _, result := range failed { 54 id := result.Flatten().RuleID 55 if _, exists := visited[id]; !exists { 56 visited[id] = true 57 errorCodes = append(errorCodes, id) 58 } 59 } 60 assert.Len(t, errorCodes, 13) 61 62 sort.Strings(errorCodes) 63 64 assert.Equal(t, []string{ 65 "AVD-KSV-0001", "AVD-KSV-0003", 66 "AVD-KSV-0011", "AVD-KSV-0012", "AVD-KSV-0014", 67 "AVD-KSV-0015", "AVD-KSV-0016", "AVD-KSV-0018", 68 "AVD-KSV-0020", "AVD-KSV-0021", "AVD-KSV-0030", 69 "AVD-KSV-0104", "AVD-KSV-0106", 70 }, errorCodes) 71 } 72 } 73 74 func Test_helm_scanner_with_missing_name_can_recover(t *testing.T) { 75 76 tests := []struct { 77 testName string 78 chartName string 79 path string 80 archiveName string 81 }{ 82 { 83 testName: "Parsing tarball 'aws-cluster-autoscaler-bad.tar.gz'", 84 chartName: "aws-cluster-autoscaler", 85 path: filepath.Join("testdata", "aws-cluster-autoscaler-bad.tar.gz"), 86 archiveName: "aws-cluster-autoscaler-bad.tar.gz", 87 }, 88 } 89 90 for _, test := range tests { 91 t.Logf("Running test: %s", test.testName) 92 93 helmScanner := helm.New(options.ScannerWithEmbeddedPolicies(true), options.ScannerWithEmbeddedLibraries(true)) 94 95 testTemp := t.TempDir() 96 testFileName := filepath.Join(testTemp, test.archiveName) 97 require.NoError(t, copyArchive(test.path, testFileName)) 98 99 testFs := os.DirFS(testTemp) 100 _, err := helmScanner.ScanFS(context.TODO(), testFs, ".") 101 require.NoError(t, err) 102 } 103 } 104 105 func Test_helm_scanner_with_dir(t *testing.T) { 106 107 tests := []struct { 108 testName string 109 chartName string 110 }{ 111 { 112 testName: "Parsing directory testchart'", 113 chartName: "testchart", 114 }, 115 } 116 117 for _, test := range tests { 118 119 t.Logf("Running test: %s", test.testName) 120 121 helmScanner := helm.New(options.ScannerWithEmbeddedPolicies(true), options.ScannerWithEmbeddedLibraries(true)) 122 123 testFs := os.DirFS(filepath.Join("testdata", test.chartName)) 124 results, err := helmScanner.ScanFS(context.TODO(), testFs, ".") 125 require.NoError(t, err) 126 require.NotNil(t, results) 127 128 failed := results.GetFailed() 129 assert.Equal(t, 14, len(failed)) 130 131 visited := make(map[string]bool) 132 var errorCodes []string 133 for _, result := range failed { 134 id := result.Flatten().RuleID 135 if _, exists := visited[id]; !exists { 136 visited[id] = true 137 errorCodes = append(errorCodes, id) 138 } 139 } 140 141 sort.Strings(errorCodes) 142 143 assert.Equal(t, []string{ 144 "AVD-KSV-0001", "AVD-KSV-0003", 145 "AVD-KSV-0011", "AVD-KSV-0012", "AVD-KSV-0014", 146 "AVD-KSV-0015", "AVD-KSV-0016", "AVD-KSV-0018", 147 "AVD-KSV-0020", "AVD-KSV-0021", "AVD-KSV-0030", 148 "AVD-KSV-0104", "AVD-KSV-0106", 149 "AVD-KSV-0117", 150 }, errorCodes) 151 } 152 } 153 154 func Test_helm_scanner_with_custom_policies(t *testing.T) { 155 regoRule := ` 156 package user.kubernetes.ID001 157 158 159 __rego_metadata__ := { 160 "id": "ID001", 161 "avd_id": "AVD-USR-ID001", 162 "title": "Services not allowed", 163 "severity": "LOW", 164 "description": "Services are not allowed because of some reasons.", 165 } 166 167 __rego_input__ := { 168 "selector": [ 169 {"type": "kubernetes"}, 170 ], 171 } 172 173 deny[res] { 174 input.kind == "Service" 175 msg := sprintf("Found service '%s' but services are not allowed", [input.metadata.name]) 176 res := result.new(msg, input) 177 } 178 ` 179 tests := []struct { 180 testName string 181 chartName string 182 path string 183 archiveName string 184 }{ 185 { 186 testName: "Parsing tarball 'mysql-8.8.26.tar'", 187 chartName: "mysql", 188 path: filepath.Join("testdata", "mysql-8.8.26.tar"), 189 archiveName: "mysql-8.8.26.tar", 190 }, 191 } 192 193 for _, test := range tests { 194 t.Run(test.testName, func(t *testing.T) { 195 t.Logf("Running test: %s", test.testName) 196 197 helmScanner := helm.New(options.ScannerWithEmbeddedPolicies(true), options.ScannerWithEmbeddedLibraries(true), 198 options.ScannerWithPolicyDirs("rules"), 199 options.ScannerWithPolicyNamespaces("user")) 200 201 testTemp := t.TempDir() 202 testFileName := filepath.Join(testTemp, test.archiveName) 203 require.NoError(t, copyArchive(test.path, testFileName)) 204 205 policyDirName := filepath.Join(testTemp, "rules") 206 require.NoError(t, os.Mkdir(policyDirName, 0o700)) 207 require.NoError(t, os.WriteFile(filepath.Join(policyDirName, "rule.rego"), []byte(regoRule), 0o600)) 208 209 testFs := os.DirFS(testTemp) 210 211 results, err := helmScanner.ScanFS(context.TODO(), testFs, ".") 212 require.NoError(t, err) 213 require.NotNil(t, results) 214 215 failed := results.GetFailed() 216 assert.Equal(t, 15, len(failed)) 217 218 visited := make(map[string]bool) 219 var errorCodes []string 220 for _, result := range failed { 221 id := result.Flatten().RuleID 222 if _, exists := visited[id]; !exists { 223 visited[id] = true 224 errorCodes = append(errorCodes, id) 225 } 226 } 227 assert.Len(t, errorCodes, 14) 228 229 sort.Strings(errorCodes) 230 231 assert.Equal(t, []string{ 232 "AVD-KSV-0001", "AVD-KSV-0003", 233 "AVD-KSV-0011", "AVD-KSV-0012", "AVD-KSV-0014", 234 "AVD-KSV-0015", "AVD-KSV-0016", "AVD-KSV-0018", 235 "AVD-KSV-0020", "AVD-KSV-0021", "AVD-KSV-0030", 236 "AVD-KSV-0104", "AVD-KSV-0106", "AVD-USR-ID001", 237 }, errorCodes) 238 }) 239 } 240 } 241 242 func copyArchive(src, dst string) error { 243 in, err := os.Open(src) 244 if err != nil { 245 return err 246 } 247 defer func() { _ = in.Close() }() 248 249 out, err := os.Create(dst) 250 if err != nil { 251 return err 252 } 253 defer func() { _ = out.Close() }() 254 255 if _, err := io.Copy(out, in); err != nil { 256 return err 257 } 258 return nil 259 } 260 261 func Test_helm_chart_with_templated_name(t *testing.T) { 262 helmScanner := helm.New(options.ScannerWithEmbeddedPolicies(true), options.ScannerWithEmbeddedLibraries(true)) 263 testFs := os.DirFS(filepath.Join("testdata", "templated-name")) 264 _, err := helmScanner.ScanFS(context.TODO(), testFs, ".") 265 require.NoError(t, err) 266 } 267 268 func TestCodeShouldNotBeMissing(t *testing.T) { 269 policy := `# METADATA 270 # title: "Test rego" 271 # description: "Test rego" 272 # scope: package 273 # schemas: 274 # - input: schema["kubernetes"] 275 # custom: 276 # id: ID001 277 # avd_id: AVD-USR-ID001 278 # severity: LOW 279 # input: 280 # selector: 281 # - type: kubernetes 282 package user.kubernetes.ID001 283 284 deny[res] { 285 input.spec.replicas == 3 286 res := result.new("Replicas are not allowed", input) 287 } 288 ` 289 helmScanner := helm.New( 290 options.ScannerWithEmbeddedPolicies(false), 291 options.ScannerWithEmbeddedLibraries(false), 292 options.ScannerWithPolicyNamespaces("user"), 293 options.ScannerWithPolicyReader(strings.NewReader(policy)), 294 ) 295 296 results, err := helmScanner.ScanFS(context.TODO(), os.DirFS("testdata/simmilar-templates"), ".") 297 require.NoError(t, err) 298 299 failedResults := results.GetFailed() 300 require.Len(t, failedResults, 1) 301 302 failed := failedResults[0] 303 code, err := failed.GetCode() 304 require.NoError(t, err) 305 assert.NotNil(t, code) 306 }