github.com/darmach/terratest@v0.34.8-0.20210517103231-80931f95e3ff/modules/azure/nsg.go (about) 1 package azure 2 3 import ( 4 "context" 5 "fmt" 6 "strconv" 7 "strings" 8 "testing" 9 10 "github.com/Azure/azure-sdk-for-go/services/network/mgmt/2019-09-01/network" 11 "github.com/stretchr/testify/assert" 12 "github.com/stretchr/testify/require" 13 ) 14 15 // NsgRuleSummaryList holds a colleciton of NsgRuleSummary rules 16 type NsgRuleSummaryList struct { 17 SummarizedRules []NsgRuleSummary 18 } 19 20 // NsgRuleSummary is a string-based (non-pointer) summary of an NSG rule with several helper methods attached 21 // to help with verification of rule configuratoin. 22 type NsgRuleSummary struct { 23 Name string 24 Description string 25 Protocol string 26 SourcePortRange string 27 DestinationPortRange string 28 SourceAddressPrefix string 29 DestinationAddressPrefix string 30 Access string 31 Priority int32 32 Direction string 33 } 34 35 // GetDefaultNsgRulesClient returns a rules client which can be used to read the list of *default* security rules 36 // defined on an network security group. Note that the "default" rules are those provided implicitly 37 // by the Azure platform. 38 // This function would fail the test if there is an error. 39 func GetDefaultNsgRulesClient(t *testing.T, subscriptionID string) network.DefaultSecurityRulesClient { 40 client, err := GetDefaultNsgRulesClientE(subscriptionID) 41 require.NoError(t, err) 42 return client 43 } 44 45 // GetDefaultNsgRulesClientE returns a rules client which can be used to read the list of *default* security rules 46 // defined on an network security group. Note that the "default" rules are those provided implicitly 47 // by the Azure platform. 48 func GetDefaultNsgRulesClientE(subscriptionID string) (network.DefaultSecurityRulesClient, error) { 49 // Get new default client from client factory 50 nsgClient, err := CreateNsgDefaultRulesClientE(subscriptionID) 51 if err != nil { 52 return network.DefaultSecurityRulesClient{}, err 53 } 54 55 // Get an authorizer 56 auth, err := NewAuthorizer() 57 if err != nil { 58 return network.DefaultSecurityRulesClient{}, err 59 } 60 61 nsgClient.Authorizer = *auth 62 return *nsgClient, nil 63 } 64 65 // GetCustomNsgRulesClient returns a rules client which can be used to read the list of *custom* security rules 66 // defined on an network security group. Note that the "custom" rules are those defined by 67 // end users. 68 // This function would fail the test if there is an error. 69 func GetCustomNsgRulesClient(t *testing.T, subscriptionID string) network.SecurityRulesClient { 70 client, err := GetCustomNsgRulesClientE(subscriptionID) 71 require.NoError(t, err) 72 return client 73 } 74 75 // GetCustomNsgRulesClientE returns a rules client which can be used to read the list of *custom* security rules 76 // defined on an network security group. Note that the "custom" rules are those defined by 77 // end users. 78 func GetCustomNsgRulesClientE(subscriptionID string) (network.SecurityRulesClient, error) { 79 // Get new custom rules client from client factory 80 nsgClient, err := CreateNsgCustomRulesClientE(subscriptionID) 81 if err != nil { 82 return network.SecurityRulesClient{}, err 83 } 84 85 // Get an authorizer 86 auth, err := NewAuthorizer() 87 if err != nil { 88 return network.SecurityRulesClient{}, err 89 } 90 91 nsgClient.Authorizer = *auth 92 return *nsgClient, nil 93 } 94 95 // GetAllNSGRules returns an NsgRuleSummaryList instance containing the combined "default" and "custom" rules from a network 96 // security group. 97 // This function would fail the test if there is an error. 98 func GetAllNSGRules(t *testing.T, resourceGroupName, nsgName, subscriptionID string) NsgRuleSummaryList { 99 results, err := GetAllNSGRulesE(resourceGroupName, nsgName, subscriptionID) 100 require.NoError(t, err) 101 return results 102 } 103 104 // GetAllNSGRulesE returns an NsgRuleSummaryList instance containing the combined "default" and "custom" rules from a network 105 // security group. 106 func GetAllNSGRulesE(resourceGroupName, nsgName, subscriptionID string) (NsgRuleSummaryList, error) { 107 defaultRulesClient, err := GetDefaultNsgRulesClientE(subscriptionID) 108 if err != nil { 109 return NsgRuleSummaryList{}, err 110 } 111 112 // Get a client instance 113 customRulesClient, err := GetCustomNsgRulesClientE(subscriptionID) 114 if err != nil { 115 return NsgRuleSummaryList{}, err 116 } 117 118 // Read all default (platform) rules. 119 defaultRuleList, err := defaultRulesClient.ListComplete(context.Background(), resourceGroupName, nsgName) 120 if err != nil { 121 return NsgRuleSummaryList{}, err 122 } 123 124 // Read any custom (user provided) rules 125 customRuleList, err := customRulesClient.ListComplete(context.Background(), resourceGroupName, nsgName) 126 if err != nil { 127 return NsgRuleSummaryList{}, err 128 } 129 130 // Convert the default list to our summary type 131 boundDefaultRules, err := bindRuleList(defaultRuleList) 132 if err != nil { 133 return NsgRuleSummaryList{}, err 134 } 135 136 // Convert the custom list to our summary type 137 boundCustomRules, err := bindRuleList(customRuleList) 138 if err != nil { 139 return NsgRuleSummaryList{}, err 140 } 141 142 // Join the summarized lists and wrap in NsgRuleSummaryList struct 143 allRules := append(boundDefaultRules, boundCustomRules...) 144 ruleList := NsgRuleSummaryList{} 145 ruleList.SummarizedRules = allRules 146 return ruleList, nil 147 } 148 149 // bindRuleList takes a raw list of security rules from the SDK and converts them into a string-based 150 // summary struct. 151 func bindRuleList(source network.SecurityRuleListResultIterator) ([]NsgRuleSummary, error) { 152 rules := make([]NsgRuleSummary, 0) 153 for source.NotDone() { 154 v := source.Value() 155 rules = append(rules, convertToNsgRuleSummary(v.Name, v.SecurityRulePropertiesFormat)) 156 err := source.NextWithContext(context.Background()) 157 if err != nil { 158 return []NsgRuleSummary{}, err 159 } 160 } 161 return rules, nil 162 } 163 164 // convertToNsgRuleSummary converts the raw SDK security rule type into a summarized struct, flattening the 165 // rules properties and name into a single, string-based struct. 166 func convertToNsgRuleSummary(name *string, rule *network.SecurityRulePropertiesFormat) NsgRuleSummary { 167 summary := NsgRuleSummary{} 168 summary.Description = safePtrToString(rule.Description) 169 summary.Name = safePtrToString(name) 170 summary.Protocol = string(rule.Protocol) 171 summary.SourcePortRange = safePtrToString(rule.SourcePortRange) 172 summary.DestinationPortRange = safePtrToString(rule.DestinationPortRange) 173 summary.SourceAddressPrefix = safePtrToString(rule.SourceAddressPrefix) 174 summary.DestinationAddressPrefix = safePtrToString(rule.DestinationAddressPrefix) 175 summary.Access = string(rule.Access) 176 summary.Priority = safePtrToInt32(rule.Priority) 177 summary.Direction = string(rule.Direction) 178 return summary 179 } 180 181 // FindRuleByName looks for a matching rule by name within the current collection of rules. 182 func (summarizedRules *NsgRuleSummaryList) FindRuleByName(name string) NsgRuleSummary { 183 for _, r := range summarizedRules.SummarizedRules { 184 if r.Name == name { 185 return r 186 } 187 } 188 189 return NsgRuleSummary{} 190 } 191 192 // AllowsDestinationPort checks to see if the rule allows a specific destination port. This is helpful when verifying 193 // that a given rule is configured properly for a given port. 194 func (summarizedRule *NsgRuleSummary) AllowsDestinationPort(t *testing.T, port string) bool { 195 allowed, err := portRangeAllowsPort(summarizedRule.DestinationPortRange, port) 196 assert.NoError(t, err) 197 return allowed && (summarizedRule.Access == "Allow") 198 } 199 200 // AllowsSourcePort checks to see if the rule allows a specific source port. This is helpful when verifying 201 // that a given rule is configured properly for a given port. 202 func (summarizedRule *NsgRuleSummary) AllowsSourcePort(t *testing.T, port string) bool { 203 allowed, err := portRangeAllowsPort(summarizedRule.SourcePortRange, port) 204 assert.NoError(t, err) 205 return allowed && (summarizedRule.Access == "Allow") 206 } 207 208 // portRangeAllowsPort is the internal impelmentation of AllowsSourcePort and AllowsDestinationPort. 209 func portRangeAllowsPort(portRange string, port string) (bool, error) { 210 if portRange == "*" { 211 return true, nil 212 } 213 214 // Decode the provided port range 215 low, high, parseErr := parsePortRangeString(portRange) 216 if parseErr != nil { 217 return false, parseErr 218 } 219 220 // Decode user-provided port 221 portAsInt, parseErr := strconv.ParseInt(port, 10, 16) 222 if (parseErr != nil) && (port != "*") { 223 return false, parseErr 224 } 225 226 // If the user wants to check "all", make sure we parsed input range to include all ports. 227 if (port == "*") && (low == 0) && (high == 65535) { 228 return true, nil 229 } 230 231 // Evaluate and return 232 return ((uint16(portAsInt) >= low) && (uint16(portAsInt) <= high)), nil 233 } 234 235 // parsePortRangeString decodes a range string ("2-100") or a single digit ("22") and returns 236 // a tuple in [low, hi] form. Note that if a single digit is supplied, both members of the 237 // return tuple will be the same value (e.g., "22" returns (22, 22)) 238 func parsePortRangeString(rangeString string) (uint16, uint16, error) { 239 // An asterisk means all ports 240 if rangeString == "*" { 241 return uint16(0), uint16(65535), nil 242 } 243 244 // Check for range string that contains hyphen separator 245 if !strings.Contains(rangeString, "-") { 246 val, parseErr := strconv.ParseInt(rangeString, 10, 16) 247 if parseErr != nil { 248 return 0, 0, parseErr 249 } 250 return uint16(val), uint16(val), nil 251 } 252 253 // Split the rang into parts and validate 254 parts := strings.Split(rangeString, "-") 255 if len(parts) != 2 { 256 return 0, 0, fmt.Errorf("Invalid port range specified; must be of the format '{low port}-{high port}'") 257 } 258 259 // Assume the low port is listed first; parse it 260 lowVal, parseErr := strconv.ParseInt(parts[0], 10, 16) 261 if parseErr != nil { 262 return 0, 0, parseErr 263 } 264 265 // Assume the hi port is listed first; parse it 266 highVal, parseErr := strconv.ParseInt(parts[1], 10, 16) 267 if parseErr != nil { 268 return 0, 0, parseErr 269 } 270 271 // Normalize ordering in the case that low and hi were reversed. 272 // This should _never_ happen, as the Azure API's won't allow it, but 273 // we shouldn't fail if it's the case. 274 if lowVal > highVal { 275 temp := lowVal 276 lowVal = highVal 277 highVal = temp 278 } 279 280 // Return values 281 return uint16(lowVal), uint16(highVal), nil 282 }