
     1  package azure
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"strconv"
     7  	"strings"
     8  	"testing"
    10  	""
    11  	""
    12  	""
    13  )
// NsgRuleSummaryList holds a collection of NsgRuleSummary rules, such as the combined "default"
// (platform-provided) and "custom" (user-defined) rules of a network security group as returned by
// GetAllNSGRules / GetAllNSGRulesE. Use FindRuleByName to look up an individual rule.
type NsgRuleSummaryList struct {
	SummarizedRules []NsgRuleSummary
}
// NsgRuleSummary is a string-based (non-pointer) summary of an NSG rule with several helper methods attached
// to help with verification of rule configuration.
//
// Field notes:
//   - SourcePortRange / DestinationPortRange are interpreted by AllowsSourcePort / AllowsDestinationPort
//     as "*" (all ports), a single port such as "22", or a "low-high" range such as "80-90".
//   - Access is compared against "Allow" by the Allows* helpers; any other value is treated as not allowing.
//   - Protocol, Access and Direction are the string forms of the corresponding SDK enum values.
//   - Pointer-valued SDK fields are flattened through safePtrToString / safePtrToInt32 (defined elsewhere;
//     presumably nil maps to "" / 0 — confirm against those helpers).
type NsgRuleSummary struct {
	Name                     string
	Description              string
	Protocol                 string
	SourcePortRange          string
	DestinationPortRange     string
	SourceAddressPrefix      string
	DestinationAddressPrefix string
	Access                   string
	Priority                 int32
	Direction                string
}
    35  // GetDefaultNsgRulesClient returns a rules client which can be used to read the list of *default* security rules
    36  // defined on an network security group. Note that the "default" rules are those provided implicitly
    37  // by the Azure platform.
    38  // This function would fail the test if there is an error.
    39  func GetDefaultNsgRulesClient(t *testing.T, subscriptionID string) network.DefaultSecurityRulesClient {
    40  	client, err := GetDefaultNsgRulesClientE(subscriptionID)
    41  	require.NoError(t, err)
    42  	return client
    43  }
    45  // GetDefaultNsgRulesClientE returns a rules client which can be used to read the list of *default* security rules
    46  // defined on an network security group. Note that the "default" rules are those provided implicitly
    47  // by the Azure platform.
    48  func GetDefaultNsgRulesClientE(subscriptionID string) (network.DefaultSecurityRulesClient, error) {
    49  	// Get new default client from client factory
    50  	nsgClient, err := CreateNsgDefaultRulesClientE(subscriptionID)
    51  	if err != nil {
    52  		return network.DefaultSecurityRulesClient{}, err
    53  	}
    55  	// Get an authorizer
    56  	auth, err := NewAuthorizer()
    57  	if err != nil {
    58  		return network.DefaultSecurityRulesClient{}, err
    59  	}
    61  	nsgClient.Authorizer = *auth
    62  	return *nsgClient, nil
    63  }
    65  // GetCustomNsgRulesClient returns a rules client which can be used to read the list of *custom* security rules
    66  // defined on an network security group. Note that the "custom" rules are those defined by
    67  // end users.
    68  // This function would fail the test if there is an error.
    69  func GetCustomNsgRulesClient(t *testing.T, subscriptionID string) network.SecurityRulesClient {
    70  	client, err := GetCustomNsgRulesClientE(subscriptionID)
    71  	require.NoError(t, err)
    72  	return client
    73  }
    75  // GetCustomNsgRulesClientE returns a rules client which can be used to read the list of *custom* security rules
    76  // defined on an network security group. Note that the "custom" rules are those defined by
    77  // end users.
    78  func GetCustomNsgRulesClientE(subscriptionID string) (network.SecurityRulesClient, error) {
    79  	// Get new custom rules client from client factory
    80  	nsgClient, err := CreateNsgCustomRulesClientE(subscriptionID)
    81  	if err != nil {
    82  		return network.SecurityRulesClient{}, err
    83  	}
    85  	// Get an authorizer
    86  	auth, err := NewAuthorizer()
    87  	if err != nil {
    88  		return network.SecurityRulesClient{}, err
    89  	}
    91  	nsgClient.Authorizer = *auth
    92  	return *nsgClient, nil
    93  }
    95  // GetAllNSGRules returns an NsgRuleSummaryList instance containing the combined "default" and "custom" rules from a network
    96  // security group.
    97  // This function would fail the test if there is an error.
    98  func GetAllNSGRules(t *testing.T, resourceGroupName, nsgName, subscriptionID string) NsgRuleSummaryList {
    99  	results, err := GetAllNSGRulesE(resourceGroupName, nsgName, subscriptionID)
   100  	require.NoError(t, err)
   101  	return results
   102  }
   104  // GetAllNSGRulesE returns an NsgRuleSummaryList instance containing the combined "default" and "custom" rules from a network
   105  // security group.
   106  func GetAllNSGRulesE(resourceGroupName, nsgName, subscriptionID string) (NsgRuleSummaryList, error) {
   107  	defaultRulesClient, err := GetDefaultNsgRulesClientE(subscriptionID)
   108  	if err != nil {
   109  		return NsgRuleSummaryList{}, err
   110  	}
   112  	// Get a client instance
   113  	customRulesClient, err := GetCustomNsgRulesClientE(subscriptionID)
   114  	if err != nil {
   115  		return NsgRuleSummaryList{}, err
   116  	}
   118  	// Read all default (platform) rules.
   119  	defaultRuleList, err := defaultRulesClient.ListComplete(context.Background(), resourceGroupName, nsgName)
   120  	if err != nil {
   121  		return NsgRuleSummaryList{}, err
   122  	}
   124  	// Read any custom (user provided) rules
   125  	customRuleList, err := customRulesClient.ListComplete(context.Background(), resourceGroupName, nsgName)
   126  	if err != nil {
   127  		return NsgRuleSummaryList{}, err
   128  	}
   130  	// Convert the default list to our summary type
   131  	boundDefaultRules, err := bindRuleList(defaultRuleList)
   132  	if err != nil {
   133  		return NsgRuleSummaryList{}, err
   134  	}
   136  	// Convert the custom list to our summary type
   137  	boundCustomRules, err := bindRuleList(customRuleList)
   138  	if err != nil {
   139  		return NsgRuleSummaryList{}, err
   140  	}
   142  	// Join the summarized lists and wrap in NsgRuleSummaryList struct
   143  	allRules := append(boundDefaultRules, boundCustomRules...)
   144  	ruleList := NsgRuleSummaryList{}
   145  	ruleList.SummarizedRules = allRules
   146  	return ruleList, nil
   147  }
   149  // bindRuleList takes a raw list of security rules from the SDK and converts them into a string-based
   150  // summary struct.
   151  func bindRuleList(source network.SecurityRuleListResultIterator) ([]NsgRuleSummary, error) {
   152  	rules := make([]NsgRuleSummary, 0)
   153  	for source.NotDone() {
   154  		v := source.Value()
   155  		rules = append(rules, convertToNsgRuleSummary(v.Name, v.SecurityRulePropertiesFormat))
   156  		err := source.NextWithContext(context.Background())
   157  		if err != nil {
   158  			return []NsgRuleSummary{}, err
   159  		}
   160  	}
   161  	return rules, nil
   162  }
   164  // convertToNsgRuleSummary converts the raw SDK security rule type into a summarized struct, flattening the
   165  // rules properties and name into a single, string-based struct.
   166  func convertToNsgRuleSummary(name *string, rule *network.SecurityRulePropertiesFormat) NsgRuleSummary {
   167  	summary := NsgRuleSummary{}
   168  	summary.Description = safePtrToString(rule.Description)
   169  	summary.Name = safePtrToString(name)
   170  	summary.Protocol = string(rule.Protocol)
   171  	summary.SourcePortRange = safePtrToString(rule.SourcePortRange)
   172  	summary.DestinationPortRange = safePtrToString(rule.DestinationPortRange)
   173  	summary.SourceAddressPrefix = safePtrToString(rule.SourceAddressPrefix)
   174  	summary.DestinationAddressPrefix = safePtrToString(rule.DestinationAddressPrefix)
   175  	summary.Access = string(rule.Access)
   176  	summary.Priority = safePtrToInt32(rule.Priority)
   177  	summary.Direction = string(rule.Direction)
   178  	return summary
   179  }
   181  // FindRuleByName looks for a matching rule by name within the current collection of rules.
   182  func (summarizedRules *NsgRuleSummaryList) FindRuleByName(name string) NsgRuleSummary {
   183  	for _, r := range summarizedRules.SummarizedRules {
   184  		if r.Name == name {
   185  			return r
   186  		}
   187  	}
   189  	return NsgRuleSummary{}
   190  }
   192  // AllowsDestinationPort checks to see if the rule allows a specific destination port. This is helpful when verifying
   193  // that a given rule is configured properly for a given port.
   194  func (summarizedRule *NsgRuleSummary) AllowsDestinationPort(t *testing.T, port string) bool {
   195  	allowed, err := portRangeAllowsPort(summarizedRule.DestinationPortRange, port)
   196  	assert.NoError(t, err)
   197  	return allowed && (summarizedRule.Access == "Allow")
   198  }
   200  // AllowsSourcePort checks to see if the rule allows a specific source port. This is helpful when verifying
   201  // that a given rule is configured properly for a given port.
   202  func (summarizedRule *NsgRuleSummary) AllowsSourcePort(t *testing.T, port string) bool {
   203  	allowed, err := portRangeAllowsPort(summarizedRule.SourcePortRange, port)
   204  	assert.NoError(t, err)
   205  	return allowed && (summarizedRule.Access == "Allow")
   206  }
   208  // portRangeAllowsPort is the internal impelmentation of AllowsSourcePort and AllowsDestinationPort.
   209  func portRangeAllowsPort(portRange string, port string) (bool, error) {
   210  	if portRange == "*" {
   211  		return true, nil
   212  	}
   214  	// Decode the provided port range
   215  	low, high, parseErr := parsePortRangeString(portRange)
   216  	if parseErr != nil {
   217  		return false, parseErr
   218  	}
   220  	// Decode user-provided port
   221  	portAsInt, parseErr := strconv.ParseInt(port, 10, 16)
   222  	if (parseErr != nil) && (port != "*") {
   223  		return false, parseErr
   224  	}
   226  	// If the user wants to check "all", make sure we parsed input range to include all ports.
   227  	if (port == "*") && (low == 0) && (high == 65535) {
   228  		return true, nil
   229  	}
   231  	// Evaluate and return
   232  	return ((uint16(portAsInt) >= low) && (uint16(portAsInt) <= high)), nil
   233  }
   235  // parsePortRangeString decodes a range string ("2-100") or a single digit ("22") and returns
   236  // a tuple in [low, hi] form. Note that if a single digit is supplied, both members of the
   237  // return tuple will be the same value (e.g., "22" returns (22, 22))
   238  func parsePortRangeString(rangeString string) (uint16, uint16, error) {
   239  	// An asterisk means all ports
   240  	if rangeString == "*" {
   241  		return uint16(0), uint16(65535), nil
   242  	}
   244  	// Check for range string that contains hyphen separator
   245  	if !strings.Contains(rangeString, "-") {
   246  		val, parseErr := strconv.ParseInt(rangeString, 10, 16)
   247  		if parseErr != nil {
   248  			return 0, 0, parseErr
   249  		}
   250  		return uint16(val), uint16(val), nil
   251  	}
   253  	// Split the rang into parts and validate
   254  	parts := strings.Split(rangeString, "-")
   255  	if len(parts) != 2 {
   256  		return 0, 0, fmt.Errorf("Invalid port range specified; must be of the format '{low port}-{high port}'")
   257  	}
   259  	// Assume the low port is listed first; parse it
   260  	lowVal, parseErr := strconv.ParseInt(parts[0], 10, 16)
   261  	if parseErr != nil {
   262  		return 0, 0, parseErr
   263  	}
   265  	// Assume the hi port is listed first; parse it
   266  	highVal, parseErr := strconv.ParseInt(parts[1], 10, 16)
   267  	if parseErr != nil {
   268  		return 0, 0, parseErr
   269  	}
   271  	// Normalize ordering in the case that low and hi were reversed.
   272  	// This should _never_ happen, as the Azure API's won't allow it, but
   273  	// we shouldn't fail if it's the case.
   274  	if lowVal > highVal {
   275  		temp := lowVal
   276  		lowVal = highVal
   277  		highVal = temp
   278  	}
   280  	// Return values
   281  	return uint16(lowVal), uint16(highVal), nil
   282  }