k8s.io/apiserver@v0.31.1/pkg/cel/library/cost.go (about) 1 /* 2 Copyright 2022 The Kubernetes Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package library 18 19 import ( 20 "fmt" 21 "math" 22 23 "github.com/google/cel-go/checker" 24 "github.com/google/cel-go/common" 25 "github.com/google/cel-go/common/ast" 26 "github.com/google/cel-go/common/types" 27 "github.com/google/cel-go/common/types/ref" 28 "github.com/google/cel-go/common/types/traits" 29 30 "k8s.io/apiserver/pkg/cel" 31 ) 32 33 // panicOnUnknown makes cost estimate functions panic on unrecognized functions. 34 // This is only set to true for unit tests. 35 var panicOnUnknown = false 36 37 // builtInFunctions is a list of functions used in cost tests that are not handled by CostEstimator. 38 var knownUnhandledFunctions = map[string]bool{ 39 "uint": true, 40 "duration": true, 41 "bytes": true, 42 "timestamp": true, 43 "value": true, 44 "_==_": true, 45 "_&&_": true, 46 "_>_": true, 47 "!_": true, 48 "strings.quote": true, 49 } 50 51 // CostEstimator implements CEL's interpretable.ActualCostEstimator and checker.CostEstimator. 52 type CostEstimator struct { 53 // SizeEstimator provides a CostEstimator.EstimateSize that this CostEstimator will delegate size estimation 54 // calculations to if the size is not well known (i.e. a constant). 55 SizeEstimator checker.CostEstimator 56 } 57 58 const ( 59 // shortest repeatable selector requirement that allocates a values slice is 2 characters: k, 60 selectorLengthToRequirementCount = float64(.5) 61 // the expensive parts to represent each requirement are a struct and a values slice 62 costPerRequirement = float64(common.ListCreateBaseCost + common.StructCreateBaseCost) 63 ) 64 65 // a selector consists of a list of requirements held in a slice 66 var baseSelectorCost = checker.CostEstimate{Min: common.ListCreateBaseCost, Max: common.ListCreateBaseCost} 67 68 func selectorCostEstimate(selectorLength checker.SizeEstimate) checker.CostEstimate { 69 parseCost := selectorLength.MultiplyByCostFactor(common.StringTraversalCostFactor) 70 71 requirementCount := selectorLength.MultiplyByCostFactor(selectorLengthToRequirementCount) 72 requirementCost := requirementCount.MultiplyByCostFactor(costPerRequirement) 73 74 return baseSelectorCost.Add(parseCost).Add(requirementCost) 75 } 76 77 func (l *CostEstimator) CallCost(function, overloadId string, args []ref.Val, result ref.Val) *uint64 { 78 switch function { 79 case "check": 80 // An authorization check has a fixed cost 81 // This cost is set to allow for only two authorization checks per expression 82 cost := uint64(350000) 83 return &cost 84 case "serviceAccount", "path", "group", "resource", "subresource", "namespace", "name", "allowed", "reason", "error", "errored": 85 // All authorization builder and accessor functions have a nominal cost 86 cost := uint64(1) 87 return &cost 88 case "fieldSelector", "labelSelector": 89 // field and label selector parse is a string parse into a structured set of requirements 90 if len(args) >= 2 { 91 selectorLength := actualSize(args[1]) 92 cost := selectorCostEstimate(checker.SizeEstimate{Min: selectorLength, Max: selectorLength}) 93 return &cost.Max 94 } 95 case "isSorted", "sum", "max", "min", "indexOf", "lastIndexOf": 96 var cost uint64 97 if len(args) > 0 { 98 cost += traversalCost(args[0]) // these O(n) operations all cost roughly the cost of a single traversal 99 } 100 return &cost 101 case "url", "lowerAscii", "upperAscii", "substring", "trim": 102 if len(args) >= 1 { 103 cost := uint64(math.Ceil(float64(actualSize(args[0])) * common.StringTraversalCostFactor)) 104 return &cost 105 } 106 case "replace", "split": 107 if len(args) >= 1 { 108 // cost is the traversal plus the construction of the result 109 cost := uint64(math.Ceil(float64(actualSize(args[0])) * 2 * common.StringTraversalCostFactor)) 110 return &cost 111 } 112 case "join": 113 if len(args) >= 1 { 114 cost := uint64(math.Ceil(float64(actualSize(result)) * 2 * common.StringTraversalCostFactor)) 115 return &cost 116 } 117 case "find", "findAll": 118 if len(args) >= 2 { 119 strCost := uint64(math.Ceil((1.0 + float64(actualSize(args[0]))) * common.StringTraversalCostFactor)) 120 // We don't know how many expressions are in the regex, just the string length (a huge 121 // improvement here would be to somehow get a count the number of expressions in the regex or 122 // how many states are in the regex state machine and use that to measure regex cost). 123 // For now, we're making a guess that each expression in a regex is typically at least 4 chars 124 // in length. 125 regexCost := uint64(math.Ceil(float64(actualSize(args[1])) * common.RegexStringLengthCostFactor)) 126 cost := strCost * regexCost 127 return &cost 128 } 129 case "cidr", "isIP", "isCIDR": 130 // IP and CIDR parsing is a string traversal. 131 if len(args) >= 1 { 132 cost := uint64(math.Ceil(float64(actualSize(args[0])) * common.StringTraversalCostFactor)) 133 return &cost 134 } 135 case "ip": 136 // IP and CIDR parsing is a string traversal. 137 if len(args) >= 1 { 138 if overloadId == "cidr_ip" { 139 // The IP member of the CIDR object is just accessing a field. 140 // Nominal cost. 141 cost := uint64(1) 142 return &cost 143 } 144 145 cost := uint64(math.Ceil(float64(actualSize(args[0])) * common.StringTraversalCostFactor)) 146 return &cost 147 } 148 case "ip.isCanonical": 149 if len(args) >= 1 { 150 // We have to parse the string and then compare the parsed string to the original string. 151 // So we double the cost of parsing the string. 152 cost := uint64(math.Ceil(float64(actualSize(args[0])) * 2 * common.StringTraversalCostFactor)) 153 return &cost 154 } 155 case "masked", "prefixLength", "family", "isUnspecified", "isLoopback", "isLinkLocalMulticast", "isLinkLocalUnicast", "isGlobalUnicast": 156 // IP and CIDR accessors are nominal cost. 157 cost := uint64(1) 158 return &cost 159 case "containsIP": 160 if len(args) >= 2 { 161 cidrSize := actualSize(args[0]) 162 otherSize := actualSize(args[1]) 163 164 // This is the base cost of comparing two byte lists. 165 // We will compare only up to the length of the CIDR prefix in bytes, so use the cidrSize twice. 166 cost := uint64(math.Ceil(float64(cidrSize+cidrSize) * common.StringTraversalCostFactor)) 167 168 if overloadId == "cidr_contains_ip_string" { 169 // If we are comparing a string, we must parse the string to into the right type, so add the cost of traversing the string again. 170 cost += uint64(math.Ceil(float64(otherSize) * common.StringTraversalCostFactor)) 171 172 } 173 174 return &cost 175 } 176 case "containsCIDR": 177 if len(args) >= 2 { 178 cidrSize := actualSize(args[0]) 179 otherSize := actualSize(args[1]) 180 181 // This is the base cost of comparing two byte lists. 182 // We will compare only up to the length of the CIDR prefix in bytes, so use the cidrSize twice. 183 cost := uint64(math.Ceil(float64(cidrSize+cidrSize) * common.StringTraversalCostFactor)) 184 185 // As we are comparing if a CIDR is within another CIDR, we first mask the base CIDR and 186 // also compare the CIDR bits. 187 // This has an additional cost of the length of the IP being traversed again, plus 1. 188 cost += uint64(math.Ceil(float64(cidrSize)*common.StringTraversalCostFactor)) + 1 189 190 if overloadId == "cidr_contains_cidr_string" { 191 // If we are comparing a string, we must parse the string to into the right type, so add the cost of traversing the string again. 192 cost += uint64(math.Ceil(float64(otherSize) * common.StringTraversalCostFactor)) 193 } 194 195 return &cost 196 } 197 case "quantity", "isQuantity": 198 if len(args) >= 1 { 199 cost := uint64(math.Ceil(float64(actualSize(args[0])) * common.StringTraversalCostFactor)) 200 return &cost 201 } 202 case "validate": 203 if len(args) >= 2 { 204 format, isFormat := args[0].Value().(*cel.Format) 205 if isFormat { 206 strSize := actualSize(args[1]) 207 208 // Dont have access to underlying regex, estimate a long regexp 209 regexSize := format.MaxRegexSize 210 211 // Copied from CEL implementation for regex cost 212 // 213 // https://swtch.com/~rsc/regexp/regexp1.html applies to RE2 implementation supported by CEL 214 // Add one to string length for purposes of cost calculation to prevent product of string and regex to be 0 215 // in case where string is empty but regex is still expensive. 216 strCost := uint64(math.Ceil((1.0 + float64(strSize)) * common.StringTraversalCostFactor)) 217 // We don't know how many expressions are in the regex, just the string length (a huge 218 // improvement here would be to somehow get a count the number of expressions in the regex or 219 // how many states are in the regex state machine and use that to measure regex cost). 220 // For now, we're making a guess that each expression in a regex is typically at least 4 chars 221 // in length. 222 regexCost := uint64(math.Ceil(float64(regexSize) * common.RegexStringLengthCostFactor)) 223 cost := strCost * regexCost 224 return &cost 225 } 226 } 227 case "format.named": 228 // Simply dictionary lookup 229 cost := uint64(1) 230 return &cost 231 case "sign", "asInteger", "isInteger", "asApproximateFloat", "isGreaterThan", "isLessThan", "compareTo", "add", "sub": 232 cost := uint64(1) 233 return &cost 234 case "getScheme", "getHostname", "getHost", "getPort", "getEscapedPath", "getQuery": 235 // url accessors 236 cost := uint64(1) 237 return &cost 238 } 239 if panicOnUnknown && !knownUnhandledFunctions[function] { 240 panic(fmt.Errorf("CallCost: unhandled function %q or args %v", function, args)) 241 } 242 return nil 243 } 244 245 func (l *CostEstimator) EstimateCallCost(function, overloadId string, target *checker.AstNode, args []checker.AstNode) *checker.CallEstimate { 246 // WARNING: Any changes to this code impact API compatibility! The estimated cost is used to determine which CEL rules may be written to a 247 // CRD and any change (cost increases and cost decreases) are breaking. 248 switch function { 249 case "check": 250 // An authorization check has a fixed cost 251 // This cost is set to allow for only two authorization checks per expression 252 return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 350000, Max: 350000}} 253 case "serviceAccount", "path", "group", "resource", "subresource", "namespace", "name", "allowed", "reason", "error", "errored": 254 // All authorization builder and accessor functions have a nominal cost 255 return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}} 256 case "fieldSelector", "labelSelector": 257 // field and label selector parse is a string parse into a structured set of requirements 258 if len(args) == 1 { 259 return &checker.CallEstimate{CostEstimate: selectorCostEstimate(l.sizeEstimate(args[0]))} 260 } 261 case "isSorted", "sum", "max", "min", "indexOf", "lastIndexOf": 262 if target != nil { 263 // Charge 1 cost for comparing each element in the list 264 elCost := checker.CostEstimate{Min: 1, Max: 1} 265 // If the list contains strings or bytes, add the cost of traversing all the strings/bytes as a way 266 // of estimating the additional comparison cost. 267 if elNode := l.listElementNode(*target); elNode != nil { 268 k := elNode.Type().Kind() 269 if k == types.StringKind || k == types.BytesKind { 270 sz := l.sizeEstimate(elNode) 271 elCost = elCost.Add(sz.MultiplyByCostFactor(common.StringTraversalCostFactor)) 272 } 273 return &checker.CallEstimate{CostEstimate: l.sizeEstimate(*target).MultiplyByCost(elCost)} 274 } else { // the target is a string, which is supported by indexOf and lastIndexOf 275 return &checker.CallEstimate{CostEstimate: l.sizeEstimate(*target).MultiplyByCostFactor(common.StringTraversalCostFactor)} 276 } 277 } 278 case "url": 279 if len(args) == 1 { 280 sz := l.sizeEstimate(args[0]) 281 return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor)} 282 } 283 case "lowerAscii", "upperAscii", "substring", "trim": 284 if target != nil { 285 sz := l.sizeEstimate(*target) 286 return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor), ResultSize: &sz} 287 } 288 case "replace": 289 if target != nil && len(args) >= 2 { 290 sz := l.sizeEstimate(*target) 291 toReplaceSz := l.sizeEstimate(args[0]) 292 replaceWithSz := l.sizeEstimate(args[1]) 293 294 var replaceCount, retainedSz checker.SizeEstimate 295 // find the longest replacement: 296 if toReplaceSz.Min == 0 { 297 // if the string being replaced is empty, replace surrounds all characters in the input string with the replacement. 298 if sz.Max < math.MaxUint64 { 299 replaceCount.Max = sz.Max + 1 300 } else { 301 replaceCount.Max = sz.Max 302 } 303 // Include the length of the longest possible original string length. 304 retainedSz.Max = sz.Max 305 } else if replaceWithSz.Max <= toReplaceSz.Min { 306 // If the replacement does not make the result longer, use the original string length. 307 replaceCount.Max = 0 308 retainedSz.Max = sz.Max 309 } else { 310 // Replace the smallest possible substrings with the largest possible replacement 311 // as many times as possible. 312 replaceCount.Max = uint64(math.Ceil(float64(sz.Max) / float64(toReplaceSz.Min))) 313 } 314 315 // find the shortest replacement: 316 if toReplaceSz.Max == 0 { 317 // if the string being replaced is empty, replace surrounds all characters in the input string with the replacement. 318 if sz.Min < math.MaxUint64 { 319 replaceCount.Min = sz.Min + 1 320 } else { 321 replaceCount.Min = sz.Min 322 } 323 // Include the length of the shortest possible original string length. 324 retainedSz.Min = sz.Min 325 } else if toReplaceSz.Max <= replaceWithSz.Min { 326 // If the replacement does not make the result shorter, use the original string length. 327 replaceCount.Min = 0 328 retainedSz.Min = sz.Min 329 } else { 330 // Replace the largest possible substrings being with the smallest possible replacement 331 // as many times as possible. 332 replaceCount.Min = uint64(math.Ceil(float64(sz.Min) / float64(toReplaceSz.Max))) 333 } 334 size := replaceCount.Multiply(replaceWithSz).Add(retainedSz) 335 336 // cost is the traversal plus the construction of the result 337 return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(2 * common.StringTraversalCostFactor), ResultSize: &size} 338 } 339 case "split": 340 if target != nil { 341 sz := l.sizeEstimate(*target) 342 343 // Worst case size is where is that a separator of "" is used, and each char is returned as a list element. 344 max := sz.Max 345 if len(args) > 1 { 346 if v := args[1].Expr().AsLiteral(); v != nil { 347 if i, ok := v.Value().(int64); ok { 348 max = uint64(i) 349 } 350 } 351 } 352 // Cost is the traversal plus the construction of the result. 353 return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(2 * common.StringTraversalCostFactor), ResultSize: &checker.SizeEstimate{Min: 0, Max: max}} 354 } 355 case "join": 356 if target != nil { 357 var sz checker.SizeEstimate 358 listSize := l.sizeEstimate(*target) 359 if elNode := l.listElementNode(*target); elNode != nil { 360 elemSize := l.sizeEstimate(elNode) 361 sz = listSize.Multiply(elemSize) 362 } 363 364 if len(args) > 0 { 365 sepSize := l.sizeEstimate(args[0]) 366 minSeparators := uint64(0) 367 maxSeparators := uint64(0) 368 if listSize.Min > 0 { 369 minSeparators = listSize.Min - 1 370 } 371 if listSize.Max > 0 { 372 maxSeparators = listSize.Max - 1 373 } 374 sz = sz.Add(sepSize.Multiply(checker.SizeEstimate{Min: minSeparators, Max: maxSeparators})) 375 } 376 377 return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor), ResultSize: &sz} 378 } 379 case "find", "findAll": 380 if target != nil && len(args) >= 1 { 381 sz := l.sizeEstimate(*target) 382 // Add one to string length for purposes of cost calculation to prevent product of string and regex to be 0 383 // in case where string is empty but regex is still expensive. 384 strCost := sz.Add(checker.SizeEstimate{Min: 1, Max: 1}).MultiplyByCostFactor(common.StringTraversalCostFactor) 385 // We don't know how many expressions are in the regex, just the string length (a huge 386 // improvement here would be to somehow get a count the number of expressions in the regex or 387 // how many states are in the regex state machine and use that to measure regex cost). 388 // For now, we're making a guess that each expression in a regex is typically at least 4 chars 389 // in length. 390 regexCost := l.sizeEstimate(args[0]).MultiplyByCostFactor(common.RegexStringLengthCostFactor) 391 // worst case size of result is that every char is returned as separate find result. 392 return &checker.CallEstimate{CostEstimate: strCost.Multiply(regexCost), ResultSize: &checker.SizeEstimate{Min: 0, Max: sz.Max}} 393 } 394 case "cidr", "isIP", "isCIDR": 395 if target != nil { 396 sz := l.sizeEstimate(args[0]) 397 return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor)} 398 } 399 case "ip": 400 if target != nil && len(args) >= 1 { 401 if overloadId == "cidr_ip" { 402 // The IP member of the CIDR object is just accessing a field. 403 // Nominal cost. 404 return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}} 405 } 406 407 sz := l.sizeEstimate(args[0]) 408 return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor)} 409 } else if target != nil { 410 // The IP member of a CIDR is a just accessing a field, nominal cost. 411 return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}} 412 } 413 case "ip.isCanonical": 414 if target != nil && len(args) >= 1 { 415 sz := l.sizeEstimate(args[0]) 416 // We have to parse the string and then compare the parsed string to the original string. 417 // So we double the cost of parsing the string. 418 return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(2 * common.StringTraversalCostFactor)} 419 } 420 case "masked", "prefixLength", "family", "isUnspecified", "isLoopback", "isLinkLocalMulticast", "isLinkLocalUnicast", "isGlobalUnicast": 421 // IP and CIDR accessors are nominal cost. 422 return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}} 423 case "containsIP": 424 if target != nil && len(args) >= 1 { 425 // The base cost of the function is the cost of comparing two byte lists. 426 // The byte lists will be either ipv4 or ipv6 so will have a length of 4, or 16 bytes. 427 sz := checker.SizeEstimate{Min: 4, Max: 16} 428 429 // We have to compare the two strings to determine if the CIDR/IP is in the other CIDR. 430 ipCompCost := sz.Add(sz).MultiplyByCostFactor(common.StringTraversalCostFactor) 431 432 if overloadId == "cidr_contains_ip_string" { 433 // If we are comparing a string, we must parse the string to into the right type, so add the cost of traversing the string again. 434 ipCompCost = ipCompCost.Add(checker.CostEstimate(l.sizeEstimate(args[0])).MultiplyByCostFactor(common.StringTraversalCostFactor)) 435 } 436 437 return &checker.CallEstimate{CostEstimate: ipCompCost} 438 } 439 case "containsCIDR": 440 if target != nil && len(args) >= 1 { 441 // The base cost of the function is the cost of comparing two byte lists. 442 // The byte lists will be either ipv4 or ipv6 so will have a length of 4, or 16 bytes. 443 sz := checker.SizeEstimate{Min: 4, Max: 16} 444 445 // We have to compare the two strings to determine if the CIDR/IP is in the other CIDR. 446 ipCompCost := sz.Add(sz).MultiplyByCostFactor(common.StringTraversalCostFactor) 447 448 // As we are comparing if a CIDR is within another CIDR, we first mask the base CIDR and 449 // also compare the CIDR bits. 450 // This has an additional cost of the length of the IP being traversed again, plus 1. 451 ipCompCost = ipCompCost.Add(sz.MultiplyByCostFactor(common.StringTraversalCostFactor)) 452 ipCompCost = ipCompCost.Add(checker.CostEstimate{Min: 1, Max: 1}) 453 454 if overloadId == "cidr_contains_cidr_string" { 455 // If we are comparing a string, we must parse the string to into the right type, so add the cost of traversing the string again. 456 ipCompCost = ipCompCost.Add(checker.CostEstimate(l.sizeEstimate(args[0])).MultiplyByCostFactor(common.StringTraversalCostFactor)) 457 } 458 459 return &checker.CallEstimate{CostEstimate: ipCompCost} 460 } 461 case "quantity", "isQuantity": 462 if target != nil { 463 sz := l.sizeEstimate(args[0]) 464 return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor)} 465 } 466 case "validate": 467 if target != nil { 468 sz := l.sizeEstimate(args[0]) 469 return &checker.CallEstimate{CostEstimate: sz.MultiplyByCostFactor(common.StringTraversalCostFactor).MultiplyByCostFactor(cel.MaxNameFormatRegexSize * common.RegexStringLengthCostFactor)} 470 } 471 case "format.named": 472 return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}} 473 case "sign", "asInteger", "isInteger", "asApproximateFloat", "isGreaterThan", "isLessThan", "compareTo", "add", "sub": 474 return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}} 475 case "getScheme", "getHostname", "getHost", "getPort", "getEscapedPath", "getQuery": 476 // url accessors 477 return &checker.CallEstimate{CostEstimate: checker.CostEstimate{Min: 1, Max: 1}} 478 } 479 if panicOnUnknown && !knownUnhandledFunctions[function] { 480 panic(fmt.Errorf("EstimateCallCost: unhandled function %q, target %v, args %v", function, target, args)) 481 } 482 return nil 483 } 484 485 func actualSize(value ref.Val) uint64 { 486 if sz, ok := value.(traits.Sizer); ok { 487 return uint64(sz.Size().(types.Int)) 488 } 489 if panicOnUnknown { 490 // debug.PrintStack() 491 panic(fmt.Errorf("actualSize: non-sizer type %T", value)) 492 } 493 return 1 494 } 495 496 func (l *CostEstimator) sizeEstimate(t checker.AstNode) checker.SizeEstimate { 497 if sz := t.ComputedSize(); sz != nil { 498 return *sz 499 } 500 if sz := l.EstimateSize(t); sz != nil { 501 return *sz 502 } 503 return checker.SizeEstimate{Min: 0, Max: math.MaxUint64} 504 } 505 506 func (l *CostEstimator) listElementNode(list checker.AstNode) checker.AstNode { 507 if params := list.Type().Parameters(); len(params) > 0 { 508 lt := params[0] 509 nodePath := list.Path() 510 if nodePath != nil { 511 // Provide path if we have it so that a OpenAPIv3 maxLength validation can be looked up, if it exists 512 // for this node. 513 path := make([]string, len(nodePath)+1) 514 copy(path, nodePath) 515 path[len(nodePath)] = "@items" 516 return &itemsNode{path: path, t: lt, expr: nil} 517 } else { 518 // Provide just the type if no path is available so that worst case size can be looked up based on type. 519 return &itemsNode{t: lt, expr: nil} 520 } 521 } 522 return nil 523 } 524 525 func (l *CostEstimator) EstimateSize(element checker.AstNode) *checker.SizeEstimate { 526 if l.SizeEstimator != nil { 527 return l.SizeEstimator.EstimateSize(element) 528 } 529 return nil 530 } 531 532 type itemsNode struct { 533 path []string 534 t *types.Type 535 expr ast.Expr 536 } 537 538 func (i *itemsNode) Path() []string { 539 return i.path 540 } 541 542 func (i *itemsNode) Type() *types.Type { 543 return i.t 544 } 545 546 func (i *itemsNode) Expr() ast.Expr { 547 return i.expr 548 } 549 550 func (i *itemsNode) ComputedSize() *checker.SizeEstimate { 551 return nil 552 } 553 554 var _ checker.AstNode = (*itemsNode)(nil) 555 556 // traversalCost computes the cost of traversing a ref.Val as a data tree. 557 func traversalCost(v ref.Val) uint64 { 558 // TODO: This could potentially be optimized by sampling maps and lists instead of traversing. 559 switch vt := v.(type) { 560 case types.String: 561 return uint64(float64(len(string(vt))) * common.StringTraversalCostFactor) 562 case types.Bytes: 563 return uint64(float64(len([]byte(vt))) * common.StringTraversalCostFactor) 564 case traits.Lister: 565 cost := uint64(0) 566 for it := vt.Iterator(); it.HasNext() == types.True; { 567 i := it.Next() 568 cost += traversalCost(i) 569 } 570 return cost 571 case traits.Mapper: // maps and objects 572 cost := uint64(0) 573 for it := vt.Iterator(); it.HasNext() == types.True; { 574 k := it.Next() 575 cost += traversalCost(k) + traversalCost(vt.Get(k)) 576 } 577 return cost 578 default: 579 return 1 580 } 581 }