github.com/rajeev159/opa@v0.45.0/topdown/aggregates.go (about) 1 // Copyright 2016 The OPA Authors. All rights reserved. 2 // Use of this source code is governed by an Apache2 3 // license that can be found in the LICENSE file. 4 5 package topdown 6 7 import ( 8 "math/big" 9 10 "github.com/open-policy-agent/opa/ast" 11 "github.com/open-policy-agent/opa/topdown/builtins" 12 ) 13 14 func builtinCount(a ast.Value) (ast.Value, error) { 15 switch a := a.(type) { 16 case *ast.Array: 17 return ast.IntNumberTerm(a.Len()).Value, nil 18 case ast.Object: 19 return ast.IntNumberTerm(a.Len()).Value, nil 20 case ast.Set: 21 return ast.IntNumberTerm(a.Len()).Value, nil 22 case ast.String: 23 return ast.IntNumberTerm(len([]rune(a))).Value, nil 24 } 25 return nil, builtins.NewOperandTypeErr(1, a, "array", "object", "set", "string") 26 } 27 28 func builtinSum(a ast.Value) (ast.Value, error) { 29 switch a := a.(type) { 30 case *ast.Array: 31 sum := big.NewFloat(0) 32 err := a.Iter(func(x *ast.Term) error { 33 n, ok := x.Value.(ast.Number) 34 if !ok { 35 return builtins.NewOperandElementErr(1, a, x.Value, "number") 36 } 37 sum = new(big.Float).Add(sum, builtins.NumberToFloat(n)) 38 return nil 39 }) 40 return builtins.FloatToNumber(sum), err 41 case ast.Set: 42 sum := big.NewFloat(0) 43 err := a.Iter(func(x *ast.Term) error { 44 n, ok := x.Value.(ast.Number) 45 if !ok { 46 return builtins.NewOperandElementErr(1, a, x.Value, "number") 47 } 48 sum = new(big.Float).Add(sum, builtins.NumberToFloat(n)) 49 return nil 50 }) 51 return builtins.FloatToNumber(sum), err 52 } 53 return nil, builtins.NewOperandTypeErr(1, a, "set", "array") 54 } 55 56 func builtinProduct(a ast.Value) (ast.Value, error) { 57 switch a := a.(type) { 58 case *ast.Array: 59 product := big.NewFloat(1) 60 err := a.Iter(func(x *ast.Term) error { 61 n, ok := x.Value.(ast.Number) 62 if !ok { 63 return builtins.NewOperandElementErr(1, a, x.Value, "number") 64 } 65 product = new(big.Float).Mul(product, builtins.NumberToFloat(n)) 66 return nil 67 }) 68 return builtins.FloatToNumber(product), err 69 case ast.Set: 70 product := big.NewFloat(1) 71 err := a.Iter(func(x *ast.Term) error { 72 n, ok := x.Value.(ast.Number) 73 if !ok { 74 return builtins.NewOperandElementErr(1, a, x.Value, "number") 75 } 76 product = new(big.Float).Mul(product, builtins.NumberToFloat(n)) 77 return nil 78 }) 79 return builtins.FloatToNumber(product), err 80 } 81 return nil, builtins.NewOperandTypeErr(1, a, "set", "array") 82 } 83 84 func builtinMax(a ast.Value) (ast.Value, error) { 85 switch a := a.(type) { 86 case *ast.Array: 87 if a.Len() == 0 { 88 return nil, BuiltinEmpty{} 89 } 90 var max = ast.Value(ast.Null{}) 91 a.Foreach(func(x *ast.Term) { 92 if ast.Compare(max, x.Value) <= 0 { 93 max = x.Value 94 } 95 }) 96 return max, nil 97 case ast.Set: 98 if a.Len() == 0 { 99 return nil, BuiltinEmpty{} 100 } 101 max, err := a.Reduce(ast.NullTerm(), func(max *ast.Term, elem *ast.Term) (*ast.Term, error) { 102 if ast.Compare(max, elem) <= 0 { 103 return elem, nil 104 } 105 return max, nil 106 }) 107 return max.Value, err 108 } 109 110 return nil, builtins.NewOperandTypeErr(1, a, "set", "array") 111 } 112 113 func builtinMin(a ast.Value) (ast.Value, error) { 114 switch a := a.(type) { 115 case *ast.Array: 116 if a.Len() == 0 { 117 return nil, BuiltinEmpty{} 118 } 119 min := a.Elem(0).Value 120 a.Foreach(func(x *ast.Term) { 121 if ast.Compare(min, x.Value) >= 0 { 122 min = x.Value 123 } 124 }) 125 return min, nil 126 case ast.Set: 127 if a.Len() == 0 { 128 return nil, BuiltinEmpty{} 129 } 130 min, err := a.Reduce(ast.NullTerm(), func(min *ast.Term, elem *ast.Term) (*ast.Term, error) { 131 // The null term is considered to be less than any other term, 132 // so in order for min of a set to make sense, we need to check 133 // for it. 134 if min.Value.Compare(ast.Null{}) == 0 { 135 return elem, nil 136 } 137 138 if ast.Compare(min, elem) >= 0 { 139 return elem, nil 140 } 141 return min, nil 142 }) 143 return min.Value, err 144 } 145 146 return nil, builtins.NewOperandTypeErr(1, a, "set", "array") 147 } 148 149 func builtinSort(a ast.Value) (ast.Value, error) { 150 switch a := a.(type) { 151 case *ast.Array: 152 return a.Sorted(), nil 153 case ast.Set: 154 return a.Sorted(), nil 155 } 156 return nil, builtins.NewOperandTypeErr(1, a, "set", "array") 157 } 158 159 func builtinAll(a ast.Value) (ast.Value, error) { 160 switch val := a.(type) { 161 case ast.Set: 162 res := true 163 match := ast.BooleanTerm(true) 164 val.Until(func(term *ast.Term) bool { 165 if !match.Equal(term) { 166 res = false 167 return true 168 } 169 return false 170 }) 171 return ast.Boolean(res), nil 172 case *ast.Array: 173 res := true 174 match := ast.BooleanTerm(true) 175 val.Until(func(term *ast.Term) bool { 176 if !match.Equal(term) { 177 res = false 178 return true 179 } 180 return false 181 }) 182 return ast.Boolean(res), nil 183 default: 184 return nil, builtins.NewOperandTypeErr(1, a, "array", "set") 185 } 186 } 187 188 func builtinAny(a ast.Value) (ast.Value, error) { 189 switch val := a.(type) { 190 case ast.Set: 191 res := val.Len() > 0 && val.Contains(ast.BooleanTerm(true)) 192 return ast.Boolean(res), nil 193 case *ast.Array: 194 res := false 195 match := ast.BooleanTerm(true) 196 val.Until(func(term *ast.Term) bool { 197 if match.Equal(term) { 198 res = true 199 return true 200 } 201 return false 202 }) 203 return ast.Boolean(res), nil 204 default: 205 return nil, builtins.NewOperandTypeErr(1, a, "array", "set") 206 } 207 } 208 209 func builtinMember(_ BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error { 210 containee := args[0] 211 switch c := args[1].Value.(type) { 212 case ast.Set: 213 return iter(ast.BooleanTerm(c.Contains(containee))) 214 case *ast.Array: 215 ret := false 216 c.Until(func(v *ast.Term) bool { 217 if v.Value.Compare(containee.Value) == 0 { 218 ret = true 219 } 220 return ret 221 }) 222 return iter(ast.BooleanTerm(ret)) 223 case ast.Object: 224 ret := false 225 c.Until(func(_, v *ast.Term) bool { 226 if v.Value.Compare(containee.Value) == 0 { 227 ret = true 228 } 229 return ret 230 }) 231 return iter(ast.BooleanTerm(ret)) 232 } 233 return iter(ast.BooleanTerm(false)) 234 } 235 236 func builtinMemberWithKey(_ BuiltinContext, args []*ast.Term, iter func(*ast.Term) error) error { 237 key, val := args[0], args[1] 238 switch c := args[2].Value.(type) { 239 case interface{ Get(*ast.Term) *ast.Term }: 240 ret := false 241 if act := c.Get(key); act != nil { 242 ret = act.Value.Compare(val.Value) == 0 243 } 244 return iter(ast.BooleanTerm(ret)) 245 } 246 return iter(ast.BooleanTerm(false)) 247 } 248 249 func init() { 250 RegisterFunctionalBuiltin1(ast.Count.Name, builtinCount) 251 RegisterFunctionalBuiltin1(ast.Sum.Name, builtinSum) 252 RegisterFunctionalBuiltin1(ast.Product.Name, builtinProduct) 253 RegisterFunctionalBuiltin1(ast.Max.Name, builtinMax) 254 RegisterFunctionalBuiltin1(ast.Min.Name, builtinMin) 255 RegisterFunctionalBuiltin1(ast.Sort.Name, builtinSort) 256 RegisterFunctionalBuiltin1(ast.Any.Name, builtinAny) 257 RegisterFunctionalBuiltin1(ast.All.Name, builtinAll) 258 RegisterBuiltinFunc(ast.Member.Name, builtinMember) 259 RegisterBuiltinFunc(ast.MemberWithKey.Name, builtinMemberWithKey) 260 }