github.com/aquasecurity/trivy-iac@v0.8.1-0.20240127024015-3d8e412cf0ab/pkg/scanners/helm/test/scanner_test.go (about)

     1  package test
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"os"
     7  	"path/filepath"
     8  	"sort"
     9  	"strings"
    10  	"testing"
    11  
    12  	"github.com/aquasecurity/defsec/pkg/scanners/options"
    13  	"github.com/aquasecurity/trivy-iac/pkg/scanners/helm"
    14  	"github.com/stretchr/testify/assert"
    15  	"github.com/stretchr/testify/require"
    16  )
    17  
    18  func Test_helm_scanner_with_archive(t *testing.T) {
    19  
    20  	tests := []struct {
    21  		testName    string
    22  		chartName   string
    23  		path        string
    24  		archiveName string
    25  	}{
    26  		{
    27  			testName:    "Parsing tarball 'mysql-8.8.26.tar'",
    28  			chartName:   "mysql",
    29  			path:        filepath.Join("testdata", "mysql-8.8.26.tar"),
    30  			archiveName: "mysql-8.8.26.tar",
    31  		},
    32  	}
    33  
    34  	for _, test := range tests {
    35  		t.Logf("Running test: %s", test.testName)
    36  
    37  		helmScanner := helm.New(options.ScannerWithEmbeddedPolicies(true), options.ScannerWithEmbeddedLibraries(true))
    38  
    39  		testTemp := t.TempDir()
    40  		testFileName := filepath.Join(testTemp, test.archiveName)
    41  		require.NoError(t, copyArchive(test.path, testFileName))
    42  
    43  		testFs := os.DirFS(testTemp)
    44  		results, err := helmScanner.ScanFS(context.TODO(), testFs, ".")
    45  		require.NoError(t, err)
    46  		require.NotNil(t, results)
    47  
    48  		failed := results.GetFailed()
    49  		assert.Equal(t, 13, len(failed))
    50  
    51  		visited := make(map[string]bool)
    52  		var errorCodes []string
    53  		for _, result := range failed {
    54  			id := result.Flatten().RuleID
    55  			if _, exists := visited[id]; !exists {
    56  				visited[id] = true
    57  				errorCodes = append(errorCodes, id)
    58  			}
    59  		}
    60  		assert.Len(t, errorCodes, 13)
    61  
    62  		sort.Strings(errorCodes)
    63  
    64  		assert.Equal(t, []string{
    65  			"AVD-KSV-0001", "AVD-KSV-0003",
    66  			"AVD-KSV-0011", "AVD-KSV-0012", "AVD-KSV-0014",
    67  			"AVD-KSV-0015", "AVD-KSV-0016", "AVD-KSV-0018",
    68  			"AVD-KSV-0020", "AVD-KSV-0021", "AVD-KSV-0030",
    69  			"AVD-KSV-0104", "AVD-KSV-0106",
    70  		}, errorCodes)
    71  	}
    72  }
    73  
    74  func Test_helm_scanner_with_missing_name_can_recover(t *testing.T) {
    75  
    76  	tests := []struct {
    77  		testName    string
    78  		chartName   string
    79  		path        string
    80  		archiveName string
    81  	}{
    82  		{
    83  			testName:    "Parsing tarball 'aws-cluster-autoscaler-bad.tar.gz'",
    84  			chartName:   "aws-cluster-autoscaler",
    85  			path:        filepath.Join("testdata", "aws-cluster-autoscaler-bad.tar.gz"),
    86  			archiveName: "aws-cluster-autoscaler-bad.tar.gz",
    87  		},
    88  	}
    89  
    90  	for _, test := range tests {
    91  		t.Logf("Running test: %s", test.testName)
    92  
    93  		helmScanner := helm.New(options.ScannerWithEmbeddedPolicies(true), options.ScannerWithEmbeddedLibraries(true))
    94  
    95  		testTemp := t.TempDir()
    96  		testFileName := filepath.Join(testTemp, test.archiveName)
    97  		require.NoError(t, copyArchive(test.path, testFileName))
    98  
    99  		testFs := os.DirFS(testTemp)
   100  		_, err := helmScanner.ScanFS(context.TODO(), testFs, ".")
   101  		require.NoError(t, err)
   102  	}
   103  }
   104  
   105  func Test_helm_scanner_with_dir(t *testing.T) {
   106  
   107  	tests := []struct {
   108  		testName  string
   109  		chartName string
   110  	}{
   111  		{
   112  			testName:  "Parsing directory testchart'",
   113  			chartName: "testchart",
   114  		},
   115  	}
   116  
   117  	for _, test := range tests {
   118  
   119  		t.Logf("Running test: %s", test.testName)
   120  
   121  		helmScanner := helm.New(options.ScannerWithEmbeddedPolicies(true), options.ScannerWithEmbeddedLibraries(true))
   122  
   123  		testFs := os.DirFS(filepath.Join("testdata", test.chartName))
   124  		results, err := helmScanner.ScanFS(context.TODO(), testFs, ".")
   125  		require.NoError(t, err)
   126  		require.NotNil(t, results)
   127  
   128  		failed := results.GetFailed()
   129  		assert.Equal(t, 14, len(failed))
   130  
   131  		visited := make(map[string]bool)
   132  		var errorCodes []string
   133  		for _, result := range failed {
   134  			id := result.Flatten().RuleID
   135  			if _, exists := visited[id]; !exists {
   136  				visited[id] = true
   137  				errorCodes = append(errorCodes, id)
   138  			}
   139  		}
   140  
   141  		sort.Strings(errorCodes)
   142  
   143  		assert.Equal(t, []string{
   144  			"AVD-KSV-0001", "AVD-KSV-0003",
   145  			"AVD-KSV-0011", "AVD-KSV-0012", "AVD-KSV-0014",
   146  			"AVD-KSV-0015", "AVD-KSV-0016", "AVD-KSV-0018",
   147  			"AVD-KSV-0020", "AVD-KSV-0021", "AVD-KSV-0030",
   148  			"AVD-KSV-0104", "AVD-KSV-0106",
   149  			"AVD-KSV-0117",
   150  		}, errorCodes)
   151  	}
   152  }
   153  
   154  func Test_helm_scanner_with_custom_policies(t *testing.T) {
   155  	regoRule := `
   156  package user.kubernetes.ID001
   157  
   158  
   159  __rego_metadata__ := {
   160      "id": "ID001",
   161  	"avd_id": "AVD-USR-ID001",
   162      "title": "Services not allowed",
   163      "severity": "LOW",
   164      "description": "Services are not allowed because of some reasons.",
   165  }
   166  
   167  __rego_input__ := {
   168      "selector": [
   169          {"type": "kubernetes"},
   170      ],
   171  }
   172  
   173  deny[res] {
   174      input.kind == "Service"
   175      msg := sprintf("Found service '%s' but services are not allowed", [input.metadata.name])
   176      res := result.new(msg, input)
   177  }
   178  `
   179  	tests := []struct {
   180  		testName    string
   181  		chartName   string
   182  		path        string
   183  		archiveName string
   184  	}{
   185  		{
   186  			testName:    "Parsing tarball 'mysql-8.8.26.tar'",
   187  			chartName:   "mysql",
   188  			path:        filepath.Join("testdata", "mysql-8.8.26.tar"),
   189  			archiveName: "mysql-8.8.26.tar",
   190  		},
   191  	}
   192  
   193  	for _, test := range tests {
   194  		t.Run(test.testName, func(t *testing.T) {
   195  			t.Logf("Running test: %s", test.testName)
   196  
   197  			helmScanner := helm.New(options.ScannerWithEmbeddedPolicies(true), options.ScannerWithEmbeddedLibraries(true),
   198  				options.ScannerWithPolicyDirs("rules"),
   199  				options.ScannerWithPolicyNamespaces("user"))
   200  
   201  			testTemp := t.TempDir()
   202  			testFileName := filepath.Join(testTemp, test.archiveName)
   203  			require.NoError(t, copyArchive(test.path, testFileName))
   204  
   205  			policyDirName := filepath.Join(testTemp, "rules")
   206  			require.NoError(t, os.Mkdir(policyDirName, 0o700))
   207  			require.NoError(t, os.WriteFile(filepath.Join(policyDirName, "rule.rego"), []byte(regoRule), 0o600))
   208  
   209  			testFs := os.DirFS(testTemp)
   210  
   211  			results, err := helmScanner.ScanFS(context.TODO(), testFs, ".")
   212  			require.NoError(t, err)
   213  			require.NotNil(t, results)
   214  
   215  			failed := results.GetFailed()
   216  			assert.Equal(t, 15, len(failed))
   217  
   218  			visited := make(map[string]bool)
   219  			var errorCodes []string
   220  			for _, result := range failed {
   221  				id := result.Flatten().RuleID
   222  				if _, exists := visited[id]; !exists {
   223  					visited[id] = true
   224  					errorCodes = append(errorCodes, id)
   225  				}
   226  			}
   227  			assert.Len(t, errorCodes, 14)
   228  
   229  			sort.Strings(errorCodes)
   230  
   231  			assert.Equal(t, []string{
   232  				"AVD-KSV-0001", "AVD-KSV-0003",
   233  				"AVD-KSV-0011", "AVD-KSV-0012", "AVD-KSV-0014",
   234  				"AVD-KSV-0015", "AVD-KSV-0016", "AVD-KSV-0018",
   235  				"AVD-KSV-0020", "AVD-KSV-0021", "AVD-KSV-0030",
   236  				"AVD-KSV-0104", "AVD-KSV-0106", "AVD-USR-ID001",
   237  			}, errorCodes)
   238  		})
   239  	}
   240  }
   241  
   242  func copyArchive(src, dst string) error {
   243  	in, err := os.Open(src)
   244  	if err != nil {
   245  		return err
   246  	}
   247  	defer func() { _ = in.Close() }()
   248  
   249  	out, err := os.Create(dst)
   250  	if err != nil {
   251  		return err
   252  	}
   253  	defer func() { _ = out.Close() }()
   254  
   255  	if _, err := io.Copy(out, in); err != nil {
   256  		return err
   257  	}
   258  	return nil
   259  }
   260  
   261  func Test_helm_chart_with_templated_name(t *testing.T) {
   262  	helmScanner := helm.New(options.ScannerWithEmbeddedPolicies(true), options.ScannerWithEmbeddedLibraries(true))
   263  	testFs := os.DirFS(filepath.Join("testdata", "templated-name"))
   264  	_, err := helmScanner.ScanFS(context.TODO(), testFs, ".")
   265  	require.NoError(t, err)
   266  }
   267  
   268  func TestCodeShouldNotBeMissing(t *testing.T) {
   269  	policy := `# METADATA
   270  # title: "Test rego"
   271  # description: "Test rego"
   272  # scope: package
   273  # schemas:
   274  # - input: schema["kubernetes"]
   275  # custom:
   276  #   id: ID001
   277  #   avd_id: AVD-USR-ID001
   278  #   severity: LOW
   279  #   input:
   280  #     selector:
   281  #     - type: kubernetes
   282  package user.kubernetes.ID001
   283  
   284  deny[res] {
   285      input.spec.replicas == 3
   286      res := result.new("Replicas are not allowed", input)
   287  }
   288  `
   289  	helmScanner := helm.New(
   290  		options.ScannerWithEmbeddedPolicies(false),
   291  		options.ScannerWithEmbeddedLibraries(false),
   292  		options.ScannerWithPolicyNamespaces("user"),
   293  		options.ScannerWithPolicyReader(strings.NewReader(policy)),
   294  	)
   295  
   296  	results, err := helmScanner.ScanFS(context.TODO(), os.DirFS("testdata/simmilar-templates"), ".")
   297  	require.NoError(t, err)
   298  
   299  	failedResults := results.GetFailed()
   300  	require.Len(t, failedResults, 1)
   301  
   302  	failed := failedResults[0]
   303  	code, err := failed.GetCode()
   304  	require.NoError(t, err)
   305  	assert.NotNil(t, code)
   306  }