github.com/khulnasoft-lab/defsec@v1.0.5-0.20230827010352-5e9f46893d95/pkg/scanners/helm/test/scanner_test.go (about)

     1  package test
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"os"
     7  	"path/filepath"
     8  	"sort"
     9  	"testing"
    10  
    11  	"github.com/khulnasoft-lab/defsec/pkg/scanners/options"
    12  
    13  	"github.com/stretchr/testify/assert"
    14  	"github.com/stretchr/testify/require"
    15  
    16  	"github.com/khulnasoft-lab/defsec/pkg/scanners/helm"
    17  )
    18  
    19  func Test_helm_scanner_with_archive(t *testing.T) {
    20  
    21  	tests := []struct {
    22  		testName    string
    23  		chartName   string
    24  		path        string
    25  		archiveName string
    26  	}{
    27  		{
    28  			testName:    "Parsing tarball 'mysql-8.8.26.tar'",
    29  			chartName:   "mysql",
    30  			path:        filepath.Join("testdata", "mysql-8.8.26.tar"),
    31  			archiveName: "mysql-8.8.26.tar",
    32  		},
    33  	}
    34  
    35  	for _, test := range tests {
    36  		t.Logf("Running test: %s", test.testName)
    37  
    38  		helmScanner := helm.New(options.ScannerWithEmbeddedPolicies(true), options.ScannerWithEmbeddedLibraries(true))
    39  
    40  		testTemp := t.TempDir()
    41  		testFileName := filepath.Join(testTemp, test.archiveName)
    42  		require.NoError(t, copyArchive(test.path, testFileName))
    43  
    44  		testFs := os.DirFS(testTemp)
    45  		results, err := helmScanner.ScanFS(context.TODO(), testFs, ".")
    46  		require.NoError(t, err)
    47  		require.NotNil(t, results)
    48  
    49  		failed := results.GetFailed()
    50  		assert.Equal(t, 19, len(failed))
    51  
    52  		visited := make(map[string]bool)
    53  		var errorCodes []string
    54  		for _, result := range failed {
    55  			id := result.Flatten().RuleID
    56  			if _, exists := visited[id]; !exists {
    57  				visited[id] = true
    58  				errorCodes = append(errorCodes, id)
    59  			}
    60  		}
    61  		assert.Len(t, errorCodes, 14)
    62  
    63  		sort.Strings(errorCodes)
    64  
    65  		assert.Equal(t, []string{
    66  			"AVD-KSV-0001", "AVD-KSV-0003",
    67  			"AVD-KSV-0011", "AVD-KSV-0012", "AVD-KSV-0014",
    68  			"AVD-KSV-0015", "AVD-KSV-0016", "AVD-KSV-0018",
    69  			"AVD-KSV-0020", "AVD-KSV-0021", "AVD-KSV-0030",
    70  			"AVD-KSV-0104", "AVD-KSV-0106", "AVD-KSV-0116",
    71  		}, errorCodes)
    72  	}
    73  }
    74  
    75  func Test_helm_scanner_with_missing_name_can_recover(t *testing.T) {
    76  
    77  	tests := []struct {
    78  		testName    string
    79  		chartName   string
    80  		path        string
    81  		archiveName string
    82  	}{
    83  		{
    84  			testName:    "Parsing tarball 'aws-cluster-autoscaler-bad.tar.gz'",
    85  			chartName:   "aws-cluster-autoscaler",
    86  			path:        filepath.Join("testdata", "aws-cluster-autoscaler-bad.tar.gz"),
    87  			archiveName: "aws-cluster-autoscaler-bad.tar.gz",
    88  		},
    89  	}
    90  
    91  	for _, test := range tests {
    92  		t.Logf("Running test: %s", test.testName)
    93  
    94  		helmScanner := helm.New(options.ScannerWithEmbeddedPolicies(true), options.ScannerWithEmbeddedLibraries(true))
    95  
    96  		testTemp := t.TempDir()
    97  		testFileName := filepath.Join(testTemp, test.archiveName)
    98  		require.NoError(t, copyArchive(test.path, testFileName))
    99  
   100  		testFs := os.DirFS(testTemp)
   101  		_, err := helmScanner.ScanFS(context.TODO(), testFs, ".")
   102  		require.NoError(t, err)
   103  	}
   104  }
   105  
   106  func Test_helm_scanner_with_dir(t *testing.T) {
   107  
   108  	tests := []struct {
   109  		testName  string
   110  		chartName string
   111  	}{
   112  		{
   113  			testName:  "Parsing directory testchart'",
   114  			chartName: "testchart",
   115  		},
   116  	}
   117  
   118  	for _, test := range tests {
   119  
   120  		t.Logf("Running test: %s", test.testName)
   121  
   122  		helmScanner := helm.New(options.ScannerWithEmbeddedPolicies(true), options.ScannerWithEmbeddedLibraries(true))
   123  
   124  		testFs := os.DirFS(filepath.Join("testdata", test.chartName))
   125  		results, err := helmScanner.ScanFS(context.TODO(), testFs, ".")
   126  		require.NoError(t, err)
   127  		require.NotNil(t, results)
   128  
   129  		failed := results.GetFailed()
   130  		assert.Equal(t, 16, len(failed))
   131  
   132  		visited := make(map[string]bool)
   133  		var errorCodes []string
   134  		for _, result := range failed {
   135  			id := result.Flatten().RuleID
   136  			if _, exists := visited[id]; !exists {
   137  				visited[id] = true
   138  				errorCodes = append(errorCodes, id)
   139  			}
   140  		}
   141  
   142  		sort.Strings(errorCodes)
   143  
   144  		assert.Equal(t, []string{
   145  			"AVD-KSV-0001", "AVD-KSV-0003",
   146  			"AVD-KSV-0011", "AVD-KSV-0012", "AVD-KSV-0014",
   147  			"AVD-KSV-0015", "AVD-KSV-0016", "AVD-KSV-0018",
   148  			"AVD-KSV-0020", "AVD-KSV-0021", "AVD-KSV-0030",
   149  			"AVD-KSV-0104", "AVD-KSV-0106", "AVD-KSV-0116",
   150  		}, errorCodes)
   151  	}
   152  }
   153  
   154  func Test_helm_scanner_with_custom_policies(t *testing.T) {
   155  	regoRule := `
   156  package user.kubernetes.ID001
   157  
   158  
   159  __rego_metadata__ := {
   160      "id": "ID001",
   161  	"avd_id": "AVD-USR-ID001",
   162      "title": "Services not allowed",
   163      "severity": "LOW",
   164      "description": "Services are not allowed because of some reasons.",
   165  }
   166  
   167  __rego_input__ := {
   168      "selector": [
   169          {"type": "kubernetes"},
   170      ],
   171  }
   172  
   173  deny[res] {
   174      input.kind == "Service"
   175      msg := sprintf("Found service '%s' but services are not allowed", [input.metadata.name])
   176      res := result.new(msg, input)
   177  }
   178  `
   179  	tests := []struct {
   180  		testName    string
   181  		chartName   string
   182  		path        string
   183  		archiveName string
   184  	}{
   185  		{
   186  			testName:    "Parsing tarball 'mysql-8.8.26.tar'",
   187  			chartName:   "mysql",
   188  			path:        filepath.Join("testdata", "mysql-8.8.26.tar"),
   189  			archiveName: "mysql-8.8.26.tar",
   190  		},
   191  	}
   192  
   193  	for _, test := range tests {
   194  		t.Run(test.testName, func(t *testing.T) {
   195  			t.Logf("Running test: %s", test.testName)
   196  
   197  			helmScanner := helm.New(options.ScannerWithEmbeddedPolicies(true), options.ScannerWithEmbeddedLibraries(true),
   198  				options.ScannerWithPolicyDirs("rules"),
   199  				options.ScannerWithPolicyNamespaces("user"))
   200  
   201  			testTemp := t.TempDir()
   202  			testFileName := filepath.Join(testTemp, test.archiveName)
   203  			require.NoError(t, copyArchive(test.path, testFileName))
   204  
   205  			policyDirName := filepath.Join(testTemp, "rules")
   206  			require.NoError(t, os.Mkdir(policyDirName, 0o700))
   207  			require.NoError(t, os.WriteFile(filepath.Join(policyDirName, "rule.rego"), []byte(regoRule), 0o600))
   208  
   209  			testFs := os.DirFS(testTemp)
   210  
   211  			results, err := helmScanner.ScanFS(context.TODO(), testFs, ".")
   212  			require.NoError(t, err)
   213  			require.NotNil(t, results)
   214  
   215  			failed := results.GetFailed()
   216  			assert.Equal(t, 21, len(failed))
   217  
   218  			visited := make(map[string]bool)
   219  			var errorCodes []string
   220  			for _, result := range failed {
   221  				id := result.Flatten().RuleID
   222  				if _, exists := visited[id]; !exists {
   223  					visited[id] = true
   224  					errorCodes = append(errorCodes, id)
   225  				}
   226  			}
   227  			assert.Len(t, errorCodes, 15)
   228  
   229  			sort.Strings(errorCodes)
   230  
   231  			assert.Equal(t, []string{
   232  				"AVD-KSV-0001", "AVD-KSV-0003",
   233  				"AVD-KSV-0011", "AVD-KSV-0012", "AVD-KSV-0014",
   234  				"AVD-KSV-0015", "AVD-KSV-0016", "AVD-KSV-0018",
   235  				"AVD-KSV-0020", "AVD-KSV-0021", "AVD-KSV-0030",
   236  				"AVD-KSV-0104", "AVD-KSV-0106", "AVD-KSV-0116", "AVD-USR-ID001",
   237  			}, errorCodes)
   238  		})
   239  	}
   240  }
   241  
   242  func copyArchive(src, dst string) error {
   243  	in, err := os.Open(src)
   244  	if err != nil {
   245  		return err
   246  	}
   247  	defer func() { _ = in.Close() }()
   248  
   249  	out, err := os.Create(dst)
   250  	if err != nil {
   251  		return err
   252  	}
   253  	defer func() { _ = out.Close() }()
   254  
   255  	if _, err := io.Copy(out, in); err != nil {
   256  		return err
   257  	}
   258  	return nil
   259  }
   260  
   261  func Test_helm_chart_with_templated_name(t *testing.T) {
   262  	helmScanner := helm.New(options.ScannerWithEmbeddedPolicies(true), options.ScannerWithEmbeddedLibraries(true))
   263  	testFs := os.DirFS(filepath.Join("testdata", "templated-name"))
   264  	_, err := helmScanner.ScanFS(context.TODO(), testFs, ".")
   265  	require.NoError(t, err)
   266  }