github.com/hashicorp/vault/sdk@v0.11.0/helper/policyutil/policyutil.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package policyutil
     5  
     6  import (
     7  	"sort"
     8  	"strings"
     9  
    10  	"github.com/hashicorp/go-secure-stdlib/strutil"
    11  )
    12  
// Named boolean values. Presumably intended as self-documenting arguments for
// the addDefault parameter of SanitizePolicies; they are not referenced by
// the code visible in this file, so confirm against callers.
const (
	// AddDefaultPolicy is the boolean value true.
	AddDefaultPolicy      = true
	// DoNotAddDefaultPolicy is the boolean value false.
	DoNotAddDefaultPolicy = false
)
    17  
    18  // ParsePolicies parses a comma-delimited list of policies.
    19  // The resulting collection will have no duplicate elements.
    20  // If 'root' policy was present in the list of policies, then
    21  // all other policies will be ignored, the result will contain
    22  // just the 'root'. In cases where 'root' is not present, if
    23  // 'default' policy is not already present, it will be added.
    24  func ParsePolicies(policiesRaw interface{}) []string {
    25  	if policiesRaw == nil {
    26  		return []string{"default"}
    27  	}
    28  
    29  	var policies []string
    30  	switch policiesRaw.(type) {
    31  	case string:
    32  		if policiesRaw.(string) == "" {
    33  			return []string{}
    34  		}
    35  		policies = strings.Split(policiesRaw.(string), ",")
    36  	case []string:
    37  		policies = policiesRaw.([]string)
    38  	}
    39  
    40  	return SanitizePolicies(policies, false)
    41  }
    42  
    43  // SanitizePolicies performs the common input validation tasks
    44  // which are performed on the list of policies across Vault.
    45  // The resulting collection will have no duplicate elements.
    46  // If 'root' policy was present in the list of policies, then
    47  // all other policies will be ignored, the result will contain
    48  // just the 'root'. In cases where 'root' is not present, if
    49  // 'default' policy is not already present, it will be added
    50  // if addDefault is set to true.
    51  func SanitizePolicies(policies []string, addDefault bool) []string {
    52  	defaultFound := false
    53  	for i, p := range policies {
    54  		policies[i] = strings.ToLower(strings.TrimSpace(p))
    55  		// Eliminate unnamed policies.
    56  		if policies[i] == "" {
    57  			continue
    58  		}
    59  
    60  		// If 'root' policy is present, ignore all other policies.
    61  		if policies[i] == "root" {
    62  			policies = []string{"root"}
    63  			defaultFound = true
    64  			break
    65  		}
    66  		if policies[i] == "default" {
    67  			defaultFound = true
    68  		}
    69  	}
    70  
    71  	// Always add 'default' except only if the policies contain 'root'.
    72  	if addDefault && (len(policies) == 0 || !defaultFound) {
    73  		policies = append(policies, "default")
    74  	}
    75  
    76  	return strutil.RemoveDuplicates(policies, true)
    77  }
    78  
    79  // EquivalentPolicies checks whether the given policy sets are equivalent, as in,
    80  // they contain the same values. The benefit of this method is that it leaves
    81  // the "default" policy out of its comparisons as it may be added later by core
    82  // after a set of policies has been saved by a backend.
    83  func EquivalentPolicies(a, b []string) bool {
    84  	switch {
    85  	case a == nil && b == nil:
    86  		return true
    87  	case a == nil && len(b) == 1 && b[0] == "default":
    88  		return true
    89  	case b == nil && len(a) == 1 && a[0] == "default":
    90  		return true
    91  	case a == nil || b == nil:
    92  		return false
    93  	}
    94  
    95  	// First we'll build maps to ensure unique values and filter default
    96  	mapA := map[string]bool{}
    97  	mapB := map[string]bool{}
    98  	for _, keyA := range a {
    99  		if keyA == "default" {
   100  			continue
   101  		}
   102  		mapA[keyA] = true
   103  	}
   104  	for _, keyB := range b {
   105  		if keyB == "default" {
   106  			continue
   107  		}
   108  		mapB[keyB] = true
   109  	}
   110  
   111  	// Now we'll build our checking slices
   112  	var sortedA, sortedB []string
   113  	for keyA := range mapA {
   114  		sortedA = append(sortedA, keyA)
   115  	}
   116  	for keyB := range mapB {
   117  		sortedB = append(sortedB, keyB)
   118  	}
   119  	sort.Strings(sortedA)
   120  	sort.Strings(sortedB)
   121  
   122  	// Finally, compare
   123  	if len(sortedA) != len(sortedB) {
   124  		return false
   125  	}
   126  
   127  	for i := range sortedA {
   128  		if sortedA[i] != sortedB[i] {
   129  			return false
   130  		}
   131  	}
   132  
   133  	return true
   134  }