github.com/khulnasoft-lab/defsec@v1.0.5-0.20230827010352-5e9f46893d95/rules/cloud/policies/azure/database/no_public_firewall_access_test.go (about)

     1  package database
     2  
     3  import (
     4  	"testing"
     5  
     6  	defsecTypes "github.com/khulnasoft-lab/defsec/pkg/types"
     7  
     8  	"github.com/khulnasoft-lab/defsec/pkg/state"
     9  
    10  	"github.com/khulnasoft-lab/defsec/pkg/providers/azure/database"
    11  	"github.com/khulnasoft-lab/defsec/pkg/scan"
    12  
    13  	"github.com/stretchr/testify/assert"
    14  )
    15  
    16  func TestCheckNoPublicFirewallAccess(t *testing.T) {
    17  	tests := []struct {
    18  		name     string
    19  		input    database.Database
    20  		expected bool
    21  	}{
    22  		{
    23  			name: "MySQL server firewall allows public internet access",
    24  			input: database.Database{
    25  				MySQLServers: []database.MySQLServer{
    26  					{
    27  						Metadata: defsecTypes.NewTestMetadata(),
    28  						Server: database.Server{
    29  							Metadata: defsecTypes.NewTestMetadata(),
    30  							FirewallRules: []database.FirewallRule{
    31  								{
    32  									Metadata: defsecTypes.NewTestMetadata(),
    33  									StartIP:  defsecTypes.String("0.0.0.0", defsecTypes.NewTestMetadata()),
    34  									EndIP:    defsecTypes.String("255.255.255.255", defsecTypes.NewTestMetadata()),
    35  								},
    36  							},
    37  						},
    38  					},
    39  				},
    40  			},
    41  			expected: true,
    42  		},
    43  		{
    44  			name: "MySQL server firewall allows single public internet access",
    45  			input: database.Database{
    46  				MySQLServers: []database.MySQLServer{
    47  					{
    48  						Metadata: defsecTypes.NewTestMetadata(),
    49  						Server: database.Server{
    50  							Metadata: defsecTypes.NewTestMetadata(),
    51  							FirewallRules: []database.FirewallRule{
    52  								{
    53  									Metadata: defsecTypes.NewTestMetadata(),
    54  									StartIP:  defsecTypes.String("8.8.8.8", defsecTypes.NewTestMetadata()),
    55  									EndIP:    defsecTypes.String("8.8.8.8", defsecTypes.NewTestMetadata()),
    56  								},
    57  							},
    58  						},
    59  					},
    60  				},
    61  			},
    62  			expected: false,
    63  		},
    64  		{
    65  			name: "MS SQL server firewall allows public internet access",
    66  			input: database.Database{
    67  				MSSQLServers: []database.MSSQLServer{
    68  					{
    69  						Metadata: defsecTypes.NewTestMetadata(),
    70  						Server: database.Server{
    71  							Metadata: defsecTypes.NewTestMetadata(),
    72  							FirewallRules: []database.FirewallRule{
    73  								{
    74  									Metadata: defsecTypes.NewTestMetadata(),
    75  									StartIP:  defsecTypes.String("0.0.0.0", defsecTypes.NewTestMetadata()),
    76  									EndIP:    defsecTypes.String("255.255.255.255", defsecTypes.NewTestMetadata()),
    77  								},
    78  							},
    79  						},
    80  					},
    81  				},
    82  			},
    83  			expected: true,
    84  		},
    85  		{
    86  			name: "PostgreSQL server firewall allows public internet access",
    87  			input: database.Database{
    88  				PostgreSQLServers: []database.PostgreSQLServer{
    89  					{
    90  						Metadata: defsecTypes.NewTestMetadata(),
    91  						Server: database.Server{
    92  							Metadata: defsecTypes.NewTestMetadata(),
    93  							FirewallRules: []database.FirewallRule{
    94  								{
    95  									Metadata: defsecTypes.NewTestMetadata(),
    96  									StartIP:  defsecTypes.String("0.0.0.0", defsecTypes.NewTestMetadata()),
    97  									EndIP:    defsecTypes.String("255.255.255.255", defsecTypes.NewTestMetadata()),
    98  								},
    99  							},
   100  						},
   101  					},
   102  				},
   103  			},
   104  			expected: true,
   105  		},
   106  		{
   107  			name: "MariaDB server firewall allows public internet access",
   108  			input: database.Database{
   109  				MariaDBServers: []database.MariaDBServer{
   110  					{
   111  						Metadata: defsecTypes.NewTestMetadata(),
   112  						Server: database.Server{
   113  							Metadata: defsecTypes.NewTestMetadata(),
   114  							FirewallRules: []database.FirewallRule{
   115  								{
   116  									Metadata: defsecTypes.NewTestMetadata(),
   117  									StartIP:  defsecTypes.String("0.0.0.0", defsecTypes.NewTestMetadata()),
   118  									EndIP:    defsecTypes.String("255.255.255.255", defsecTypes.NewTestMetadata()),
   119  								},
   120  							},
   121  						},
   122  					},
   123  				},
   124  			},
   125  			expected: true,
   126  		},
   127  		{
   128  			name: "MySQL server firewall allows access to Azure services",
   129  			input: database.Database{
   130  				MySQLServers: []database.MySQLServer{
   131  					{
   132  						Metadata: defsecTypes.NewTestMetadata(),
   133  						Server: database.Server{
   134  							Metadata: defsecTypes.NewTestMetadata(),
   135  							FirewallRules: []database.FirewallRule{
   136  								{
   137  									Metadata: defsecTypes.NewTestMetadata(),
   138  									StartIP:  defsecTypes.String("0.0.0.0", defsecTypes.NewTestMetadata()),
   139  									EndIP:    defsecTypes.String("0.0.0.0", defsecTypes.NewTestMetadata()),
   140  								},
   141  							},
   142  						},
   143  					},
   144  				},
   145  			},
   146  			expected: false,
   147  		},
   148  		{
   149  			name: "MS SQL server firewall allows access to Azure services",
   150  			input: database.Database{
   151  				MSSQLServers: []database.MSSQLServer{
   152  					{
   153  						Metadata: defsecTypes.NewTestMetadata(),
   154  						Server: database.Server{
   155  							Metadata: defsecTypes.NewTestMetadata(),
   156  							FirewallRules: []database.FirewallRule{
   157  								{
   158  									Metadata: defsecTypes.NewTestMetadata(),
   159  									StartIP:  defsecTypes.String("0.0.0.0", defsecTypes.NewTestMetadata()),
   160  									EndIP:    defsecTypes.String("0.0.0.0", defsecTypes.NewTestMetadata()),
   161  								},
   162  							},
   163  						},
   164  					},
   165  				},
   166  			},
   167  			expected: false,
   168  		},
   169  		{
   170  			name: "PostgreSQL server firewall allows access to Azure services",
   171  			input: database.Database{
   172  				PostgreSQLServers: []database.PostgreSQLServer{
   173  					{
   174  						Metadata: defsecTypes.NewTestMetadata(),
   175  						Server: database.Server{
   176  							Metadata: defsecTypes.NewTestMetadata(),
   177  							FirewallRules: []database.FirewallRule{
   178  								{
   179  									Metadata: defsecTypes.NewTestMetadata(),
   180  									StartIP:  defsecTypes.String("0.0.0.0", defsecTypes.NewTestMetadata()),
   181  									EndIP:    defsecTypes.String("0.0.0.0", defsecTypes.NewTestMetadata()),
   182  								},
   183  							},
   184  						},
   185  					},
   186  				},
   187  			},
   188  			expected: false,
   189  		},
   190  		{
   191  			name: "MariaDB server firewall allows access to Azure services",
   192  			input: database.Database{
   193  				MariaDBServers: []database.MariaDBServer{
   194  					{
   195  						Metadata: defsecTypes.NewTestMetadata(),
   196  						Server: database.Server{
   197  							Metadata: defsecTypes.NewTestMetadata(),
   198  							FirewallRules: []database.FirewallRule{
   199  								{
   200  									Metadata: defsecTypes.NewTestMetadata(),
   201  									StartIP:  defsecTypes.String("0.0.0.0", defsecTypes.NewTestMetadata()),
   202  									EndIP:    defsecTypes.String("0.0.0.0", defsecTypes.NewTestMetadata()),
   203  								},
   204  							},
   205  						},
   206  					},
   207  				},
   208  			},
   209  			expected: false,
   210  		},
   211  	}
   212  	for _, test := range tests {
   213  		t.Run(test.name, func(t *testing.T) {
   214  			var testState state.State
   215  			testState.Azure.Database = test.input
   216  			results := CheckNoPublicFirewallAccess.Evaluate(&testState)
   217  			var found bool
   218  			for _, result := range results {
   219  				if result.Status() == scan.StatusFailed && result.Rule().LongID() == CheckNoPublicFirewallAccess.Rule().LongID() {
   220  					found = true
   221  				}
   222  			}
   223  			if test.expected {
   224  				assert.True(t, found, "Rule should have been found")
   225  			} else {
   226  				assert.False(t, found, "Rule should not have been found")
   227  			}
   228  		})
   229  	}
   230  }