github.com/rajeev159/opa@v0.45.0/topdown/aggregates.go (about)

     1  // Copyright 2016 The OPA Authors.  All rights reserved.
     2  // Use of this source code is governed by an Apache2
     3  // license that can be found in the LICENSE file.
     4  
     5  package topdown
     6  
     7  import (
     8  	"math/big"
     9  
    10  	"github.com/open-policy-agent/opa/ast"
    11  	"github.com/open-policy-agent/opa/topdown/builtins"
    12  )
    13  
    14  func builtinCount(a ast.Value) (ast.Value, error) {
    15  	switch a := a.(type) {
    16  	case *ast.Array:
    17  		return ast.IntNumberTerm(a.Len()).Value, nil
    18  	case ast.Object:
    19  		return ast.IntNumberTerm(a.Len()).Value, nil
    20  	case ast.Set:
    21  		return ast.IntNumberTerm(a.Len()).Value, nil
    22  	case ast.String:
    23  		return ast.IntNumberTerm(len([]rune(a))).Value, nil
    24  	}
    25  	return nil, builtins.NewOperandTypeErr(1, a, "array", "object", "set", "string")
    26  }
    27  
    28  func builtinSum(a ast.Value) (ast.Value, error) {
    29  	switch a := a.(type) {
    30  	case *ast.Array:
    31  		sum := big.NewFloat(0)
    32  		err := a.Iter(func(x *ast.Term) error {
    33  			n, ok := x.Value.(ast.Number)
    34  			if !ok {
    35  				return builtins.NewOperandElementErr(1, a, x.Value, "number")
    36  			}
    37  			sum = new(big.Float).Add(sum, builtins.NumberToFloat(n))
    38  			return nil
    39  		})
    40  		return builtins.FloatToNumber(sum), err
    41  	case ast.Set:
    42  		sum := big.NewFloat(0)
    43  		err := a.Iter(func(x *ast.Term) error {
    44  			n, ok := x.Value.(ast.Number)
    45  			if !ok {
    46  				return builtins.NewOperandElementErr(1, a, x.Value, "number")
    47  			}
    48  			sum = new(big.Float).Add(sum, builtins.NumberToFloat(n))
    49  			return nil
    50  		})
    51  		return builtins.FloatToNumber(sum), err
    52  	}
    53  	return nil, builtins.NewOperandTypeErr(1, a, "set", "array")
    54  }
    55  
    56  func builtinProduct(a ast.Value) (ast.Value, error) {
    57  	switch a := a.(type) {
    58  	case *ast.Array:
    59  		product := big.NewFloat(1)
    60  		err := a.Iter(func(x *ast.Term) error {
    61  			n, ok := x.Value.(ast.Number)
    62  			if !ok {
    63  				return builtins.NewOperandElementErr(1, a, x.Value, "number")
    64  			}
    65  			product = new(big.Float).Mul(product, builtins.NumberToFloat(n))
    66  			return nil
    67  		})
    68  		return builtins.FloatToNumber(product), err
    69  	case ast.Set:
    70  		product := big.NewFloat(1)
    71  		err := a.Iter(func(x *ast.Term) error {
    72  			n, ok := x.Value.(ast.Number)
    73  			if !ok {
    74  				return builtins.NewOperandElementErr(1, a, x.Value, "number")
    75  			}
    76  			product = new(big.Float).Mul(product, builtins.NumberToFloat(n))
    77  			return nil
    78  		})
    79  		return builtins.FloatToNumber(product), err
    80  	}
    81  	return nil, builtins.NewOperandTypeErr(1, a, "set", "array")
    82  }
    83  
    84  func builtinMax(a ast.Value) (ast.Value, error) {
    85  	switch a := a.(type) {
    86  	case *ast.Array:
    87  		if a.Len() == 0 {
    88  			return nil, BuiltinEmpty{}
    89  		}
    90  		var max = ast.Value(ast.Null{})
    91  		a.Foreach(func(x *ast.Term) {
    92  			if ast.Compare(max, x.Value) <= 0 {
    93  				max = x.Value
    94  			}
    95  		})
    96  		return max, nil
    97  	case ast.Set:
    98  		if a.Len() == 0 {
    99  			return nil, BuiltinEmpty{}
   100  		}
   101  		max, err := a.Reduce(ast.NullTerm(), func(max *ast.Term, elem *ast.Term) (*ast.Term, error) {
   102  			if ast.Compare(max, elem) <= 0 {
   103  				return elem, nil
   104  			}
   105  			return max, nil
   106  		})
   107  		return max.Value, err
   108  	}
   109  
   110  	return nil, builtins.NewOperandTypeErr(1, a, "set", "array")
   111  }
   112  
   113  func builtinMin(a ast.Value) (ast.Value, error) {
   114  	switch a := a.(type) {
   115  	case *ast.Array:
   116  		if a.Len() == 0 {
   117  			return nil, BuiltinEmpty{}
   118  		}
   119  		min := a.Elem(0).Value
   120  		a.Foreach(func(x *ast.Term) {
   121  			if ast.Compare(min, x.Value) >= 0 {
   122  				min = x.Value
   123  			}
   124  		})
   125  		return min, nil
   126  	case ast.Set:
   127  		if a.Len() == 0 {
   128  			return nil, BuiltinEmpty{}
   129  		}
   130  		min, err := a.Reduce(ast.NullTerm(), func(min *ast.Term, elem *ast.Term) (*ast.Term, error) {
   131  			// The null term is considered to be less than any other term,
   132  			// so in order for min of a set to make sense, we need to check
   133  			// for it.
   134  			if min.Value.Compare(ast.Null{}) == 0 {
   135  				return elem, nil
   136  			}
   137  
   138  			if ast.Compare(min, elem) >= 0 {
   139  				return elem, nil
   140  			}
   141  			return min, nil
   142  		})
   143  		return min.Value, err
   144  	}
   145  
   146  	return nil, builtins.NewOperandTypeErr(1, a, "set", "array")
   147  }
   148  
   149  func builtinSort(a ast.Value) (ast.Value, error) {
   150  	switch a := a.(type) {
   151  	case *ast.Array:
   152  		return a.Sorted(), nil
   153  	case ast.Set:
   154  		return a.Sorted(), nil
   155  	}
   156  	return nil, builtins.NewOperandTypeErr(1, a, "set", "array")
   157  }
   158  
   159  func builtinAll(a ast.Value) (ast.Value, error) {
   160  	switch val := a.(type) {
   161  	case ast.Set:
   162  		res := true
   163  		match := ast.BooleanTerm(true)
   164  		val.Until(func(term *ast.Term) bool {
   165  			if !match.Equal(term) {
   166  				res = false
   167  				return true
   168  			}
   169  			return false
   170  		})
   171  		return ast.Boolean(res), nil
   172  	case *ast.Array:
   173  		res := true
   174  		match := ast.BooleanTerm(true)
   175  		val.Until(func(term *ast.Term) bool {
   176  			if !match.Equal(term) {
   177  				res = false
   178  				return true
   179  			}
   180  			return false
   181  		})
   182  		return ast.Boolean(res), nil
   183  	default:
   184  		return nil, builtins.NewOperandTypeErr(1, a, "array", "set")
   185  	}
   186  }
   187  
   188  func builtinAny(a ast.Value) (ast.Value, error) {
   189  	switch val := a.(type) {
   190  	case ast.Set:
   191  		res := val.Len() > 0 && val.Contains(ast.BooleanTerm(true))
   192  		return ast.Boolean(res), nil
   193  	case *ast.Array:
   194  		res := false
   195  		match := ast.BooleanTerm(true)
   196  		val.Until(func(term *ast.Term) bool {
   197  			if match.Equal(term) {
   198  				res = true
   199  				return true
   200  			}
   201  			return false
   202  		})
   203  		return ast.Boolean(res), nil
   204  	default:
   205  		return nil, builtins.NewOperandTypeErr(1, a, "array", "set")
   206  	}
   207  }
   208  
   209  func builtinMember(_ BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
   210  	containee := args[0]
   211  	switch c := args[1].Value.(type) {
   212  	case ast.Set:
   213  		return iter(ast.BooleanTerm(c.Contains(containee)))
   214  	case *ast.Array:
   215  		ret := false
   216  		c.Until(func(v *ast.Term) bool {
   217  			if v.Value.Compare(containee.Value) == 0 {
   218  				ret = true
   219  			}
   220  			return ret
   221  		})
   222  		return iter(ast.BooleanTerm(ret))
   223  	case ast.Object:
   224  		ret := false
   225  		c.Until(func(_, v *ast.Term) bool {
   226  			if v.Value.Compare(containee.Value) == 0 {
   227  				ret = true
   228  			}
   229  			return ret
   230  		})
   231  		return iter(ast.BooleanTerm(ret))
   232  	}
   233  	return iter(ast.BooleanTerm(false))
   234  }
   235  
   236  func builtinMemberWithKey(_ BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error {
   237  	key, val := args[0], args[1]
   238  	switch c := args[2].Value.(type) {
   239  	case interface{ Get(*ast.Term) *ast.Term }:
   240  		ret := false
   241  		if act := c.Get(key); act != nil {
   242  			ret = act.Value.Compare(val.Value) == 0
   243  		}
   244  		return iter(ast.BooleanTerm(ret))
   245  	}
   246  	return iter(ast.BooleanTerm(false))
   247  }
   248  
   249  func init() {
   250  	RegisterFunctionalBuiltin1(ast.Count.Name, builtinCount)
   251  	RegisterFunctionalBuiltin1(ast.Sum.Name, builtinSum)
   252  	RegisterFunctionalBuiltin1(ast.Product.Name, builtinProduct)
   253  	RegisterFunctionalBuiltin1(ast.Max.Name, builtinMax)
   254  	RegisterFunctionalBuiltin1(ast.Min.Name, builtinMin)
   255  	RegisterFunctionalBuiltin1(ast.Sort.Name, builtinSort)
   256  	RegisterFunctionalBuiltin1(ast.Any.Name, builtinAny)
   257  	RegisterFunctionalBuiltin1(ast.All.Name, builtinAll)
   258  	RegisterBuiltinFunc(ast.Member.Name, builtinMember)
   259  	RegisterBuiltinFunc(ast.MemberWithKey.Name, builtinMemberWithKey)
   260  }