github.com/khulnasoft-lab/defsec@v1.0.5-0.20230827010352-5e9f46893d95/internal/rules/register_test.go (about)

     1  package rules
     2  
     3  import (
     4  	"fmt"
     5  	"testing"
     6  
     7  	"github.com/stretchr/testify/require"
     8  
     9  	"github.com/khulnasoft-lab/defsec/pkg/framework"
    10  	"github.com/khulnasoft-lab/defsec/pkg/scan"
    11  	"github.com/stretchr/testify/assert"
    12  )
    13  
    14  func Test_Reset(t *testing.T) {
    15  	rule := scan.Rule{}
    16  	_ = Register(rule, nil)
    17  	assert.Equal(t, 1, len(GetFrameworkRules()))
    18  	Reset()
    19  	assert.Equal(t, 0, len(GetFrameworkRules()))
    20  }
    21  
    22  func Test_Registration(t *testing.T) {
    23  	var tests = []struct {
    24  		name                 string
    25  		registeredFrameworks map[framework.Framework][]string
    26  		inputFrameworks      []framework.Framework
    27  		expected             bool
    28  	}{
    29  		{
    30  			name:     "rule without framework specified should be returned when no frameworks are requested",
    31  			expected: true,
    32  		},
    33  		{
    34  			name:            "rule without framework specified should not be returned when a specific framework is requested",
    35  			inputFrameworks: []framework.Framework{framework.CIS_AWS_1_2},
    36  			expected:        false,
    37  		},
    38  		{
    39  			name:            "rule without framework specified should be returned when the default framework is requested",
    40  			inputFrameworks: []framework.Framework{framework.Default},
    41  			expected:        true,
    42  		},
    43  		{
    44  			name:                 "rule with default framework specified should be returned when the default framework is requested",
    45  			registeredFrameworks: map[framework.Framework][]string{framework.Default: {"1.1"}},
    46  			inputFrameworks:      []framework.Framework{framework.Default},
    47  			expected:             true,
    48  		},
    49  		{
    50  			name:                 "rule with default framework specified should not be returned when a specific framework is requested",
    51  			registeredFrameworks: map[framework.Framework][]string{framework.Default: {"1.1"}},
    52  			inputFrameworks:      []framework.Framework{framework.CIS_AWS_1_2},
    53  			expected:             false,
    54  		},
    55  		{
    56  			name:                 "rule with specific framework specified should not be returned when a default framework is requested",
    57  			registeredFrameworks: map[framework.Framework][]string{framework.CIS_AWS_1_2: {"1.1"}},
    58  			inputFrameworks:      []framework.Framework{framework.Default},
    59  			expected:             false,
    60  		},
    61  		{
    62  			name:                 "rule with specific framework specified should be returned when the specific framework is requested",
    63  			registeredFrameworks: map[framework.Framework][]string{framework.CIS_AWS_1_2: {"1.1"}},
    64  			inputFrameworks:      []framework.Framework{framework.CIS_AWS_1_2},
    65  			expected:             true,
    66  		},
    67  		{
    68  			name:                 "rule with multiple frameworks specified should be returned when the specific framework is requested",
    69  			registeredFrameworks: map[framework.Framework][]string{framework.CIS_AWS_1_2: {"1.1"}, "blah": {"1.2"}},
    70  			inputFrameworks:      []framework.Framework{framework.CIS_AWS_1_2},
    71  			expected:             true,
    72  		},
    73  		{
    74  			name:                 "rule with multiple frameworks specified should be returned only once when multiple matching frameworks are requested",
    75  			registeredFrameworks: map[framework.Framework][]string{framework.CIS_AWS_1_2: {"1.1"}, "blah": {"1.2"}, "something": {"1.3"}},
    76  			inputFrameworks:      []framework.Framework{framework.CIS_AWS_1_2, "blah", "other"},
    77  			expected:             true,
    78  		},
    79  	}
    80  
    81  	for i, test := range tests {
    82  		t.Run(test.name, func(t *testing.T) {
    83  			Reset()
    84  			rule := scan.Rule{
    85  				AVDID:      fmt.Sprintf("%d-%s", i, test.name),
    86  				Frameworks: test.registeredFrameworks,
    87  			}
    88  			_ = Register(rule, nil)
    89  			var found bool
    90  			for _, matchedRule := range GetFrameworkRules(test.inputFrameworks...) {
    91  				if matchedRule.Rule().AVDID == rule.AVDID {
    92  					assert.False(t, found, "rule should not be returned more than once")
    93  					found = true
    94  				}
    95  			}
    96  			assert.Equal(t, test.expected, found, "rule should be returned if it matches any of the input frameworks")
    97  		})
    98  	}
    99  }
   100  
   101  func Test_Deregistration(t *testing.T) {
   102  	Reset()
   103  	registrationA := Register(scan.Rule{
   104  		AVDID: "A",
   105  	}, nil)
   106  	registrationB := Register(scan.Rule{
   107  		AVDID: "B",
   108  	}, nil)
   109  	assert.Equal(t, 2, len(GetFrameworkRules()))
   110  	Deregister(registrationA)
   111  	actual := GetFrameworkRules()
   112  	require.Equal(t, 1, len(actual))
   113  	assert.Equal(t, "B", actual[0].Rule().AVDID)
   114  	Deregister(registrationB)
   115  	assert.Equal(t, 0, len(GetFrameworkRules()))
   116  }
   117  
   118  func Test_DeregistrationMultipleFrameworks(t *testing.T) {
   119  	Reset()
   120  	registrationA := Register(scan.Rule{
   121  		AVDID: "A",
   122  	}, nil)
   123  	registrationB := Register(scan.Rule{
   124  		AVDID: "B",
   125  		Frameworks: map[framework.Framework][]string{
   126  			"a":               nil,
   127  			"b":               nil,
   128  			"c":               nil,
   129  			framework.Default: nil,
   130  		},
   131  	}, nil)
   132  	assert.Equal(t, 2, len(GetFrameworkRules()))
   133  	Deregister(registrationA)
   134  	actual := GetFrameworkRules()
   135  	require.Equal(t, 1, len(actual))
   136  	assert.Equal(t, "B", actual[0].Rule().AVDID)
   137  	Deregister(registrationB)
   138  	assert.Equal(t, 0, len(GetFrameworkRules()))
   139  }