github.com/rajeev159/opa@v0.45.0/ast/term.go (about) 1 // Copyright 2016 The OPA Authors. All rights reserved. 2 // Use of this source code is governed by an Apache2 3 // license that can be found in the LICENSE file. 4 5 // nolint: deadcode // Public API. 6 package ast 7 8 import ( 9 "bytes" 10 "encoding/json" 11 "fmt" 12 "io" 13 "math" 14 "math/big" 15 "net/url" 16 "regexp" 17 "sort" 18 "strconv" 19 "strings" 20 21 "github.com/OneOfOne/xxhash" 22 "github.com/pkg/errors" 23 24 "github.com/open-policy-agent/opa/ast/location" 25 "github.com/open-policy-agent/opa/util" 26 ) 27 28 var errFindNotFound = fmt.Errorf("find: not found") 29 30 // Location records a position in source code. 31 type Location = location.Location 32 33 // NewLocation returns a new Location object. 34 func NewLocation(text []byte, file string, row int, col int) *Location { 35 return location.NewLocation(text, file, row, col) 36 } 37 38 // Value declares the common interface for all Term values. Every kind of Term value 39 // in the language is represented as a type that implements this interface: 40 // 41 // - Null, Boolean, Number, String 42 // - Object, Array, Set 43 // - Variables, References 44 // - Array, Set, and Object Comprehensions 45 // - Calls 46 type Value interface { 47 Compare(other Value) int // Compare returns <0, 0, or >0 if this Value is less than, equal to, or greater than other, respectively. 48 Find(path Ref) (Value, error) // Find returns value referred to by path or an error if path is not found. 49 Hash() int // Returns hash code of the value. 50 IsGround() bool // IsGround returns true if this value is not a variable or contains no variables. 51 String() string // String returns a human readable string representation of the value. 52 } 53 54 // InterfaceToValue converts a native Go value x to a Value. 55 func InterfaceToValue(x interface{}) (Value, error) { 56 switch x := x.(type) { 57 case nil: 58 return Null{}, nil 59 case bool: 60 return Boolean(x), nil 61 case json.Number: 62 return Number(x), nil 63 case int64: 64 return int64Number(x), nil 65 case uint64: 66 return uint64Number(x), nil 67 case float64: 68 return floatNumber(x), nil 69 case int: 70 return intNumber(x), nil 71 case string: 72 return String(x), nil 73 case []interface{}: 74 r := make([]*Term, len(x)) 75 for i, e := range x { 76 e, err := InterfaceToValue(e) 77 if err != nil { 78 return nil, err 79 } 80 r[i] = &Term{Value: e} 81 } 82 return NewArray(r...), nil 83 case map[string]interface{}: 84 r := newobject(len(x)) 85 for k, v := range x { 86 k, err := InterfaceToValue(k) 87 if err != nil { 88 return nil, err 89 } 90 v, err := InterfaceToValue(v) 91 if err != nil { 92 return nil, err 93 } 94 r.Insert(NewTerm(k), NewTerm(v)) 95 } 96 return r, nil 97 case map[string]string: 98 r := newobject(len(x)) 99 for k, v := range x { 100 k, err := InterfaceToValue(k) 101 if err != nil { 102 return nil, err 103 } 104 v, err := InterfaceToValue(v) 105 if err != nil { 106 return nil, err 107 } 108 r.Insert(NewTerm(k), NewTerm(v)) 109 } 110 return r, nil 111 default: 112 ptr := util.Reference(x) 113 if err := util.RoundTrip(ptr); err != nil { 114 return nil, fmt.Errorf("ast: interface conversion: %w", err) 115 } 116 return InterfaceToValue(*ptr) 117 } 118 } 119 120 // ValueFromReader returns an AST value from a JSON serialized value in the reader. 121 func ValueFromReader(r io.Reader) (Value, error) { 122 var x interface{} 123 if err := util.NewJSONDecoder(r).Decode(&x); err != nil { 124 return nil, err 125 } 126 return InterfaceToValue(x) 127 } 128 129 // As converts v into a Go native type referred to by x. 130 func As(v Value, x interface{}) error { 131 return util.NewJSONDecoder(bytes.NewBufferString(v.String())).Decode(x) 132 } 133 134 // Resolver defines the interface for resolving references to native Go values. 135 type Resolver interface { 136 Resolve(ref Ref) (interface{}, error) 137 } 138 139 // ValueResolver defines the interface for resolving references to AST values. 140 type ValueResolver interface { 141 Resolve(ref Ref) (Value, error) 142 } 143 144 // UnknownValueErr indicates a ValueResolver was unable to resolve a reference 145 // because the reference refers to an unknown value. 146 type UnknownValueErr struct{} 147 148 func (UnknownValueErr) Error() string { 149 return "unknown value" 150 } 151 152 // IsUnknownValueErr returns true if the err is an UnknownValueErr. 153 func IsUnknownValueErr(err error) bool { 154 _, ok := err.(UnknownValueErr) 155 return ok 156 } 157 158 type illegalResolver struct{} 159 160 func (illegalResolver) Resolve(ref Ref) (interface{}, error) { 161 return nil, fmt.Errorf("illegal value: %v", ref) 162 } 163 164 // ValueToInterface returns the Go representation of an AST value. The AST 165 // value should not contain any values that require evaluation (e.g., vars, 166 // comprehensions, etc.) 167 func ValueToInterface(v Value, resolver Resolver) (interface{}, error) { 168 return valueToInterface(v, resolver, JSONOpt{}) 169 } 170 171 func valueToInterface(v Value, resolver Resolver, opt JSONOpt) (interface{}, error) { 172 switch v := v.(type) { 173 case Null: 174 return nil, nil 175 case Boolean: 176 return bool(v), nil 177 case Number: 178 return json.Number(v), nil 179 case String: 180 return string(v), nil 181 case *Array: 182 buf := []interface{}{} 183 for i := 0; i < v.Len(); i++ { 184 x1, err := valueToInterface(v.Elem(i).Value, resolver, opt) 185 if err != nil { 186 return nil, err 187 } 188 buf = append(buf, x1) 189 } 190 return buf, nil 191 case *object: 192 buf := make(map[string]interface{}, v.Len()) 193 err := v.Iter(func(k, v *Term) error { 194 ki, err := valueToInterface(k.Value, resolver, opt) 195 if err != nil { 196 return err 197 } 198 var str string 199 var ok bool 200 if str, ok = ki.(string); !ok { 201 var buf bytes.Buffer 202 if err := json.NewEncoder(&buf).Encode(ki); err != nil { 203 return err 204 } 205 str = strings.TrimSpace(buf.String()) 206 } 207 vi, err := valueToInterface(v.Value, resolver, opt) 208 if err != nil { 209 return err 210 } 211 buf[str] = vi 212 return nil 213 }) 214 if err != nil { 215 return nil, err 216 } 217 return buf, nil 218 case Set: 219 buf := []interface{}{} 220 iter := func(x *Term) error { 221 x1, err := valueToInterface(x.Value, resolver, opt) 222 if err != nil { 223 return err 224 } 225 buf = append(buf, x1) 226 return nil 227 } 228 var err error 229 if opt.SortSets { 230 err = v.Sorted().Iter(iter) 231 } else { 232 err = v.Iter(iter) 233 } 234 if err != nil { 235 return nil, err 236 } 237 return buf, nil 238 case Ref: 239 return resolver.Resolve(v) 240 default: 241 return nil, fmt.Errorf("%v requires evaluation", TypeName(v)) 242 } 243 } 244 245 // JSON returns the JSON representation of v. The value must not contain any 246 // refs or terms that require evaluation (e.g., vars, comprehensions, etc.) 247 func JSON(v Value) (interface{}, error) { 248 return JSONWithOpt(v, JSONOpt{}) 249 } 250 251 // JSONOpt defines parameters for AST to JSON conversion. 252 type JSONOpt struct { 253 SortSets bool // sort sets before serializing (this makes conversion more expensive) 254 } 255 256 // JSONWithOpt returns the JSON representation of v. The value must not contain any 257 // refs or terms that require evaluation (e.g., vars, comprehensions, etc.) 258 func JSONWithOpt(v Value, opt JSONOpt) (interface{}, error) { 259 return valueToInterface(v, illegalResolver{}, opt) 260 } 261 262 // MustJSON returns the JSON representation of v. The value must not contain any 263 // refs or terms that require evaluation (e.g., vars, comprehensions, etc.) If 264 // the conversion fails, this function will panic. This function is mostly for 265 // test purposes. 266 func MustJSON(v Value) interface{} { 267 r, err := JSON(v) 268 if err != nil { 269 panic(err) 270 } 271 return r 272 } 273 274 // MustInterfaceToValue converts a native Go value x to a Value. If the 275 // conversion fails, this function will panic. This function is mostly for test 276 // purposes. 277 func MustInterfaceToValue(x interface{}) Value { 278 v, err := InterfaceToValue(x) 279 if err != nil { 280 panic(err) 281 } 282 return v 283 } 284 285 // Term is an argument to a function. 286 type Term struct { 287 Value Value `json:"value"` // the value of the Term as represented in Go 288 Location *Location `json:"-"` // the location of the Term in the source 289 } 290 291 // NewTerm returns a new Term object. 292 func NewTerm(v Value) *Term { 293 return &Term{ 294 Value: v, 295 } 296 } 297 298 // SetLocation updates the term's Location and returns the term itself. 299 func (term *Term) SetLocation(loc *Location) *Term { 300 term.Location = loc 301 return term 302 } 303 304 // Loc returns the Location of term. 305 func (term *Term) Loc() *Location { 306 if term == nil { 307 return nil 308 } 309 return term.Location 310 } 311 312 // SetLoc sets the location on term. 313 func (term *Term) SetLoc(loc *Location) { 314 term.SetLocation(loc) 315 } 316 317 // Copy returns a deep copy of term. 318 func (term *Term) Copy() *Term { 319 320 if term == nil { 321 return nil 322 } 323 324 cpy := *term 325 326 switch v := term.Value.(type) { 327 case Null, Boolean, Number, String, Var: 328 cpy.Value = v 329 case Ref: 330 cpy.Value = v.Copy() 331 case *Array: 332 cpy.Value = v.Copy() 333 case Set: 334 cpy.Value = v.Copy() 335 case *object: 336 cpy.Value = v.Copy() 337 case *ArrayComprehension: 338 cpy.Value = v.Copy() 339 case *ObjectComprehension: 340 cpy.Value = v.Copy() 341 case *SetComprehension: 342 cpy.Value = v.Copy() 343 case Call: 344 cpy.Value = v.Copy() 345 } 346 347 return &cpy 348 } 349 350 // Equal returns true if this term equals the other term. Equality is 351 // defined for each kind of term. 352 func (term *Term) Equal(other *Term) bool { 353 if term == nil && other != nil { 354 return false 355 } 356 if term != nil && other == nil { 357 return false 358 } 359 if term == other { 360 return true 361 } 362 363 // TODO(tsandall): This early-exit avoids allocations for types that have 364 // Equal() functions that just use == underneath. We should revisit the 365 // other types and implement Equal() functions that do not require 366 // allocations. 367 switch v := term.Value.(type) { 368 case Null: 369 return v.Equal(other.Value) 370 case Boolean: 371 return v.Equal(other.Value) 372 case Number: 373 return v.Equal(other.Value) 374 case String: 375 return v.Equal(other.Value) 376 case Var: 377 return v.Equal(other.Value) 378 } 379 380 return term.Value.Compare(other.Value) == 0 381 } 382 383 // Get returns a value referred to by name from the term. 384 func (term *Term) Get(name *Term) *Term { 385 switch v := term.Value.(type) { 386 case *Array: 387 return v.Get(name) 388 case *object: 389 return v.Get(name) 390 case Set: 391 if v.Contains(name) { 392 return name 393 } 394 } 395 return nil 396 } 397 398 // Hash returns the hash code of the Term's Value. Its Location 399 // is ignored. 400 func (term *Term) Hash() int { 401 return term.Value.Hash() 402 } 403 404 // IsGround returns true if this term's Value is ground. 405 func (term *Term) IsGround() bool { 406 return term.Value.IsGround() 407 } 408 409 // MarshalJSON returns the JSON encoding of the term. 410 // 411 // Specialized marshalling logic is required to include a type hint for Value. 412 func (term *Term) MarshalJSON() ([]byte, error) { 413 d := map[string]interface{}{ 414 "type": TypeName(term.Value), 415 "value": term.Value, 416 } 417 return json.Marshal(d) 418 } 419 420 func (term *Term) String() string { 421 return term.Value.String() 422 } 423 424 // UnmarshalJSON parses the byte array and stores the result in term. 425 // Specialized unmarshalling is required to handle Value. 426 func (term *Term) UnmarshalJSON(bs []byte) error { 427 v := map[string]interface{}{} 428 if err := util.UnmarshalJSON(bs, &v); err != nil { 429 return err 430 } 431 val, err := unmarshalValue(v) 432 if err != nil { 433 return err 434 } 435 term.Value = val 436 return nil 437 } 438 439 // Vars returns a VarSet with variables contained in this term. 440 func (term *Term) Vars() VarSet { 441 vis := &VarVisitor{vars: VarSet{}} 442 vis.Walk(term) 443 return vis.vars 444 } 445 446 // IsConstant returns true if the AST value is constant. 447 func IsConstant(v Value) bool { 448 found := false 449 vis := GenericVisitor{ 450 func(x interface{}) bool { 451 switch x.(type) { 452 case Var, Ref, *ArrayComprehension, *ObjectComprehension, *SetComprehension, Call: 453 found = true 454 return true 455 } 456 return false 457 }, 458 } 459 vis.Walk(v) 460 return !found 461 } 462 463 // IsComprehension returns true if the supplied value is a comprehension. 464 func IsComprehension(x Value) bool { 465 switch x.(type) { 466 case *ArrayComprehension, *ObjectComprehension, *SetComprehension: 467 return true 468 } 469 return false 470 } 471 472 // ContainsRefs returns true if the Value v contains refs. 473 func ContainsRefs(v interface{}) bool { 474 found := false 475 WalkRefs(v, func(Ref) bool { 476 found = true 477 return found 478 }) 479 return found 480 } 481 482 // ContainsComprehensions returns true if the Value v contains comprehensions. 483 func ContainsComprehensions(v interface{}) bool { 484 found := false 485 WalkClosures(v, func(x interface{}) bool { 486 switch x.(type) { 487 case *ArrayComprehension, *ObjectComprehension, *SetComprehension: 488 found = true 489 return found 490 } 491 return found 492 }) 493 return found 494 } 495 496 // ContainsClosures returns true if the Value v contains closures. 497 func ContainsClosures(v interface{}) bool { 498 found := false 499 WalkClosures(v, func(x interface{}) bool { 500 switch x.(type) { 501 case *ArrayComprehension, *ObjectComprehension, *SetComprehension, *Every: 502 found = true 503 return found 504 } 505 return found 506 }) 507 return found 508 } 509 510 // IsScalar returns true if the AST value is a scalar. 511 func IsScalar(v Value) bool { 512 switch v.(type) { 513 case String: 514 return true 515 case Number: 516 return true 517 case Boolean: 518 return true 519 case Null: 520 return true 521 } 522 return false 523 } 524 525 // Null represents the null value defined by JSON. 526 type Null struct{} 527 528 // NullTerm creates a new Term with a Null value. 529 func NullTerm() *Term { 530 return &Term{Value: Null{}} 531 } 532 533 // Equal returns true if the other term Value is also Null. 534 func (null Null) Equal(other Value) bool { 535 switch other.(type) { 536 case Null: 537 return true 538 default: 539 return false 540 } 541 } 542 543 // Compare compares null to other, return <0, 0, or >0 if it is less than, equal to, 544 // or greater than other. 545 func (null Null) Compare(other Value) int { 546 return Compare(null, other) 547 } 548 549 // Find returns the current value or a not found error. 550 func (null Null) Find(path Ref) (Value, error) { 551 if len(path) == 0 { 552 return null, nil 553 } 554 return nil, errFindNotFound 555 } 556 557 // Hash returns the hash code for the Value. 558 func (null Null) Hash() int { 559 return 0 560 } 561 562 // IsGround always returns true. 563 func (Null) IsGround() bool { 564 return true 565 } 566 567 func (null Null) String() string { 568 return "null" 569 } 570 571 // Boolean represents a boolean value defined by JSON. 572 type Boolean bool 573 574 // BooleanTerm creates a new Term with a Boolean value. 575 func BooleanTerm(b bool) *Term { 576 return &Term{Value: Boolean(b)} 577 } 578 579 // Equal returns true if the other Value is a Boolean and is equal. 580 func (bol Boolean) Equal(other Value) bool { 581 switch other := other.(type) { 582 case Boolean: 583 return bol == other 584 default: 585 return false 586 } 587 } 588 589 // Compare compares bol to other, return <0, 0, or >0 if it is less than, equal to, 590 // or greater than other. 591 func (bol Boolean) Compare(other Value) int { 592 return Compare(bol, other) 593 } 594 595 // Find returns the current value or a not found error. 596 func (bol Boolean) Find(path Ref) (Value, error) { 597 if len(path) == 0 { 598 return bol, nil 599 } 600 return nil, errFindNotFound 601 } 602 603 // Hash returns the hash code for the Value. 604 func (bol Boolean) Hash() int { 605 if bol { 606 return 1 607 } 608 return 0 609 } 610 611 // IsGround always returns true. 612 func (Boolean) IsGround() bool { 613 return true 614 } 615 616 func (bol Boolean) String() string { 617 return strconv.FormatBool(bool(bol)) 618 } 619 620 // Number represents a numeric value as defined by JSON. 621 type Number json.Number 622 623 // NumberTerm creates a new Term with a Number value. 624 func NumberTerm(n json.Number) *Term { 625 return &Term{Value: Number(n)} 626 } 627 628 // IntNumberTerm creates a new Term with an integer Number value. 629 func IntNumberTerm(i int) *Term { 630 return &Term{Value: Number(strconv.Itoa(i))} 631 } 632 633 // UIntNumberTerm creates a new Term with an unsigned integer Number value. 634 func UIntNumberTerm(u uint64) *Term { 635 return &Term{Value: uint64Number(u)} 636 } 637 638 // FloatNumberTerm creates a new Term with a floating point Number value. 639 func FloatNumberTerm(f float64) *Term { 640 s := strconv.FormatFloat(f, 'g', -1, 64) 641 return &Term{Value: Number(s)} 642 } 643 644 // Equal returns true if the other Value is a Number and is equal. 645 func (num Number) Equal(other Value) bool { 646 switch other := other.(type) { 647 case Number: 648 return Compare(num, other) == 0 649 default: 650 return false 651 } 652 } 653 654 // Compare compares num to other, return <0, 0, or >0 if it is less than, equal to, 655 // or greater than other. 656 func (num Number) Compare(other Value) int { 657 return Compare(num, other) 658 } 659 660 // Find returns the current value or a not found error. 661 func (num Number) Find(path Ref) (Value, error) { 662 if len(path) == 0 { 663 return num, nil 664 } 665 return nil, errFindNotFound 666 } 667 668 // Hash returns the hash code for the Value. 669 func (num Number) Hash() int { 670 f, err := json.Number(num).Float64() 671 if err != nil { 672 bs := []byte(num) 673 h := xxhash.Checksum64(bs) 674 return int(h) 675 } 676 return int(f) 677 } 678 679 // Int returns the int representation of num if possible. 680 func (num Number) Int() (int, bool) { 681 i64, ok := num.Int64() 682 return int(i64), ok 683 } 684 685 // Int64 returns the int64 representation of num if possible. 686 func (num Number) Int64() (int64, bool) { 687 i, err := json.Number(num).Int64() 688 if err != nil { 689 return 0, false 690 } 691 return i, true 692 } 693 694 // Float64 returns the float64 representation of num if possible. 695 func (num Number) Float64() (float64, bool) { 696 f, err := json.Number(num).Float64() 697 if err != nil { 698 return 0, false 699 } 700 return f, true 701 } 702 703 // IsGround always returns true. 704 func (Number) IsGround() bool { 705 return true 706 } 707 708 // MarshalJSON returns JSON encoded bytes representing num. 709 func (num Number) MarshalJSON() ([]byte, error) { 710 return json.Marshal(json.Number(num)) 711 } 712 713 func (num Number) String() string { 714 return string(num) 715 } 716 717 func intNumber(i int) Number { 718 return Number(strconv.Itoa(i)) 719 } 720 721 func int64Number(i int64) Number { 722 return Number(strconv.FormatInt(i, 10)) 723 } 724 725 func uint64Number(u uint64) Number { 726 return Number(strconv.FormatUint(u, 10)) 727 } 728 729 func floatNumber(f float64) Number { 730 return Number(strconv.FormatFloat(f, 'g', -1, 64)) 731 } 732 733 // String represents a string value as defined by JSON. 734 type String string 735 736 // StringTerm creates a new Term with a String value. 737 func StringTerm(s string) *Term { 738 return &Term{Value: String(s)} 739 } 740 741 // Equal returns true if the other Value is a String and is equal. 742 func (str String) Equal(other Value) bool { 743 switch other := other.(type) { 744 case String: 745 return str == other 746 default: 747 return false 748 } 749 } 750 751 // Compare compares str to other, return <0, 0, or >0 if it is less than, equal to, 752 // or greater than other. 753 func (str String) Compare(other Value) int { 754 return Compare(str, other) 755 } 756 757 // Find returns the current value or a not found error. 758 func (str String) Find(path Ref) (Value, error) { 759 if len(path) == 0 { 760 return str, nil 761 } 762 return nil, errFindNotFound 763 } 764 765 // IsGround always returns true. 766 func (String) IsGround() bool { 767 return true 768 } 769 770 func (str String) String() string { 771 return strconv.Quote(string(str)) 772 } 773 774 // Hash returns the hash code for the Value. 775 func (str String) Hash() int { 776 h := xxhash.ChecksumString64S(string(str), hashSeed0) 777 return int(h) 778 } 779 780 // Var represents a variable as defined by the language. 781 type Var string 782 783 // VarTerm creates a new Term with a Variable value. 784 func VarTerm(v string) *Term { 785 return &Term{Value: Var(v)} 786 } 787 788 // Equal returns true if the other Value is a Variable and has the same value 789 // (name). 790 func (v Var) Equal(other Value) bool { 791 switch other := other.(type) { 792 case Var: 793 return v == other 794 default: 795 return false 796 } 797 } 798 799 // Compare compares v to other, return <0, 0, or >0 if it is less than, equal to, 800 // or greater than other. 801 func (v Var) Compare(other Value) int { 802 return Compare(v, other) 803 } 804 805 // Find returns the current value or a not found error. 806 func (v Var) Find(path Ref) (Value, error) { 807 if len(path) == 0 { 808 return v, nil 809 } 810 return nil, errFindNotFound 811 } 812 813 // Hash returns the hash code for the Value. 814 func (v Var) Hash() int { 815 h := xxhash.ChecksumString64S(string(v), hashSeed0) 816 return int(h) 817 } 818 819 // IsGround always returns false. 820 func (Var) IsGround() bool { 821 return false 822 } 823 824 // IsWildcard returns true if this is a wildcard variable. 825 func (v Var) IsWildcard() bool { 826 return strings.HasPrefix(string(v), WildcardPrefix) 827 } 828 829 // IsGenerated returns true if this variable was generated during compilation. 830 func (v Var) IsGenerated() bool { 831 return strings.HasPrefix(string(v), "__local") 832 } 833 834 func (v Var) String() string { 835 // Special case for wildcard so that string representation is parseable. The 836 // parser mangles wildcard variables to make their names unique and uses an 837 // illegal variable name character (WildcardPrefix) to avoid conflicts. When 838 // we serialize the variable here, we need to make sure it's parseable. 839 if v.IsWildcard() { 840 return Wildcard.String() 841 } 842 return string(v) 843 } 844 845 // Ref represents a reference as defined by the language. 846 type Ref []*Term 847 848 // EmptyRef returns a new, empty reference. 849 func EmptyRef() Ref { 850 return Ref([]*Term{}) 851 } 852 853 // PtrRef returns a new reference against the head for the pointer 854 // s. Path components in the pointer are unescaped. 855 func PtrRef(head *Term, s string) (Ref, error) { 856 s = strings.Trim(s, "/") 857 if s == "" { 858 return Ref{head}, nil 859 } 860 parts := strings.Split(s, "/") 861 if max := math.MaxInt32; len(parts) >= max { 862 return nil, fmt.Errorf("path too long: %s, %d > %d (max)", s, len(parts), max) 863 } 864 ref := make(Ref, uint(len(parts))+1) 865 ref[0] = head 866 for i := 0; i < len(parts); i++ { 867 var err error 868 parts[i], err = url.PathUnescape(parts[i]) 869 if err != nil { 870 return nil, err 871 } 872 ref[i+1] = StringTerm(parts[i]) 873 } 874 return ref, nil 875 } 876 877 // RefTerm creates a new Term with a Ref value. 878 func RefTerm(r ...*Term) *Term { 879 return &Term{Value: Ref(r)} 880 } 881 882 // Append returns a copy of ref with the term appended to the end. 883 func (ref Ref) Append(term *Term) Ref { 884 n := len(ref) 885 dst := make(Ref, n+1) 886 copy(dst, ref) 887 dst[n] = term 888 return dst 889 } 890 891 // Insert returns a copy of the ref with x inserted at pos. If pos < len(ref), 892 // existing elements are shifted to the right. If pos > len(ref)+1 this 893 // function panics. 894 func (ref Ref) Insert(x *Term, pos int) Ref { 895 if pos == len(ref) { 896 return ref.Append(x) 897 } else if pos > len(ref)+1 { 898 panic("illegal index") 899 } 900 cpy := make(Ref, len(ref)+1) 901 for i := 0; i < pos; i++ { 902 cpy[i] = ref[i] 903 } 904 cpy[pos] = x 905 for i := pos; i < len(ref); i++ { 906 cpy[i+1] = ref[i] 907 } 908 return cpy 909 } 910 911 // Extend returns a copy of ref with the terms from other appended. The head of 912 // other will be converted to a string. 913 func (ref Ref) Extend(other Ref) Ref { 914 dst := make(Ref, len(ref)+len(other)) 915 copy(dst, ref) 916 917 head := other[0].Copy() 918 head.Value = String(head.Value.(Var)) 919 offset := len(ref) 920 dst[offset] = head 921 for i := range other[1:] { 922 dst[offset+i+1] = other[i+1] 923 } 924 return dst 925 } 926 927 // Concat returns a ref with the terms appended. 928 func (ref Ref) Concat(terms []*Term) Ref { 929 if len(terms) == 0 { 930 return ref 931 } 932 cpy := make(Ref, len(ref)+len(terms)) 933 copy(cpy, ref) 934 935 for i := range terms { 936 cpy[len(ref)+i] = terms[i] 937 } 938 return cpy 939 } 940 941 // Dynamic returns the offset of the first non-constant operand of ref. 942 func (ref Ref) Dynamic() int { 943 switch ref[0].Value.(type) { 944 case Call: 945 return 0 946 } 947 for i := 1; i < len(ref); i++ { 948 if !IsConstant(ref[i].Value) { 949 return i 950 } 951 } 952 return -1 953 } 954 955 // Copy returns a deep copy of ref. 956 func (ref Ref) Copy() Ref { 957 return termSliceCopy(ref) 958 } 959 960 // Equal returns true if ref is equal to other. 961 func (ref Ref) Equal(other Value) bool { 962 return Compare(ref, other) == 0 963 } 964 965 // Compare compares ref to other, return <0, 0, or >0 if it is less than, equal to, 966 // or greater than other. 967 func (ref Ref) Compare(other Value) int { 968 return Compare(ref, other) 969 } 970 971 // Find returns the current value or a "not found" error. 972 func (ref Ref) Find(path Ref) (Value, error) { 973 if len(path) == 0 { 974 return ref, nil 975 } 976 return nil, errFindNotFound 977 } 978 979 // Hash returns the hash code for the Value. 980 func (ref Ref) Hash() int { 981 return termSliceHash(ref) 982 } 983 984 // HasPrefix returns true if the other ref is a prefix of this ref. 985 func (ref Ref) HasPrefix(other Ref) bool { 986 if len(other) > len(ref) { 987 return false 988 } 989 for i := range other { 990 if !ref[i].Equal(other[i]) { 991 return false 992 } 993 } 994 return true 995 } 996 997 // ConstantPrefix returns the constant portion of the ref starting from the head. 998 func (ref Ref) ConstantPrefix() Ref { 999 ref = ref.Copy() 1000 1001 i := ref.Dynamic() 1002 if i < 0 { 1003 return ref 1004 } 1005 return ref[:i] 1006 } 1007 1008 // GroundPrefix returns the ground portion of the ref starting from the head. By 1009 // definition, the head of the reference is always ground. 1010 func (ref Ref) GroundPrefix() Ref { 1011 prefix := make(Ref, 0, len(ref)) 1012 1013 for i, x := range ref { 1014 if i > 0 && !x.IsGround() { 1015 break 1016 } 1017 prefix = append(prefix, x) 1018 } 1019 1020 return prefix 1021 } 1022 1023 // IsGround returns true if all of the parts of the Ref are ground. 1024 func (ref Ref) IsGround() bool { 1025 if len(ref) == 0 { 1026 return true 1027 } 1028 return termSliceIsGround(ref[1:]) 1029 } 1030 1031 // IsNested returns true if this ref contains other Refs. 1032 func (ref Ref) IsNested() bool { 1033 for _, x := range ref { 1034 if _, ok := x.Value.(Ref); ok { 1035 return true 1036 } 1037 } 1038 return false 1039 } 1040 1041 // Ptr returns a slash-separated path string for this ref. If the ref 1042 // contains non-string terms this function returns an error. Path 1043 // components are escaped. 1044 func (ref Ref) Ptr() (string, error) { 1045 parts := make([]string, 0, len(ref)-1) 1046 for _, term := range ref[1:] { 1047 if str, ok := term.Value.(String); ok { 1048 parts = append(parts, url.PathEscape(string(str))) 1049 } else { 1050 return "", fmt.Errorf("invalid path value type") 1051 } 1052 } 1053 return strings.Join(parts, "/"), nil 1054 } 1055 1056 var varRegexp = regexp.MustCompile("^[[:alpha:]_][[:alpha:][:digit:]_]*$") 1057 1058 func (ref Ref) String() string { 1059 if len(ref) == 0 { 1060 return "" 1061 } 1062 buf := []string{ref[0].Value.String()} 1063 path := ref[1:] 1064 for _, p := range path { 1065 switch p := p.Value.(type) { 1066 case String: 1067 str := string(p) 1068 if varRegexp.MatchString(str) && len(buf) > 0 && !IsKeyword(str) { 1069 buf = append(buf, "."+str) 1070 } else { 1071 buf = append(buf, "["+p.String()+"]") 1072 } 1073 default: 1074 buf = append(buf, "["+p.String()+"]") 1075 } 1076 } 1077 return strings.Join(buf, "") 1078 } 1079 1080 // OutputVars returns a VarSet containing variables that would be bound by evaluating 1081 // this expression in isolation. 1082 func (ref Ref) OutputVars() VarSet { 1083 vis := NewVarVisitor().WithParams(VarVisitorParams{SkipRefHead: true}) 1084 vis.Walk(ref) 1085 return vis.Vars() 1086 } 1087 1088 func (ref Ref) toArray() *Array { 1089 a := NewArray() 1090 for _, term := range ref { 1091 if _, ok := term.Value.(String); ok { 1092 a = a.Append(term) 1093 } else { 1094 a = a.Append(StringTerm(term.Value.String())) 1095 } 1096 } 1097 return a 1098 } 1099 1100 // QueryIterator defines the interface for querying AST documents with references. 1101 type QueryIterator func(map[Var]Value, Value) error 1102 1103 // ArrayTerm creates a new Term with an Array value. 1104 func ArrayTerm(a ...*Term) *Term { 1105 return NewTerm(NewArray(a...)) 1106 } 1107 1108 // NewArray creates an Array with the terms provided. The array will 1109 // use the provided term slice. 1110 func NewArray(a ...*Term) *Array { 1111 hs := make([]int, len(a)) 1112 for i, e := range a { 1113 hs[i] = e.Value.Hash() 1114 } 1115 arr := &Array{elems: a, hashs: hs, ground: termSliceIsGround(a)} 1116 arr.rehash() 1117 return arr 1118 } 1119 1120 // Array represents an array as defined by the language. Arrays are similar to the 1121 // same types as defined by JSON with the exception that they can contain Vars 1122 // and References. 1123 type Array struct { 1124 elems []*Term 1125 hashs []int // element hashes 1126 hash int 1127 ground bool 1128 } 1129 1130 // Copy returns a deep copy of arr. 1131 func (arr *Array) Copy() *Array { 1132 cpy := make([]int, len(arr.elems)) 1133 copy(cpy, arr.hashs) 1134 return &Array{ 1135 elems: termSliceCopy(arr.elems), 1136 hashs: cpy, 1137 hash: arr.hash, 1138 ground: arr.IsGround()} 1139 } 1140 1141 // Equal returns true if arr is equal to other. 1142 func (arr *Array) Equal(other Value) bool { 1143 return Compare(arr, other) == 0 1144 } 1145 1146 // Compare compares arr to other, return <0, 0, or >0 if it is less than, equal to, 1147 // or greater than other. 1148 func (arr *Array) Compare(other Value) int { 1149 return Compare(arr, other) 1150 } 1151 1152 // Find returns the value at the index or an out-of-range error. 1153 func (arr *Array) Find(path Ref) (Value, error) { 1154 if len(path) == 0 { 1155 return arr, nil 1156 } 1157 num, ok := path[0].Value.(Number) 1158 if !ok { 1159 return nil, errFindNotFound 1160 } 1161 i, ok := num.Int() 1162 if !ok { 1163 return nil, errFindNotFound 1164 } 1165 if i < 0 || i >= arr.Len() { 1166 return nil, errFindNotFound 1167 } 1168 return arr.Elem(i).Value.Find(path[1:]) 1169 } 1170 1171 // Get returns the element at pos or nil if not possible. 1172 func (arr *Array) Get(pos *Term) *Term { 1173 num, ok := pos.Value.(Number) 1174 if !ok { 1175 return nil 1176 } 1177 1178 i, ok := num.Int() 1179 if !ok { 1180 return nil 1181 } 1182 1183 if i >= 0 && i < len(arr.elems) { 1184 return arr.elems[i] 1185 } 1186 1187 return nil 1188 } 1189 1190 // Sorted returns a new Array that contains the sorted elements of arr. 1191 func (arr *Array) Sorted() *Array { 1192 cpy := make([]*Term, len(arr.elems)) 1193 for i := range cpy { 1194 cpy[i] = arr.elems[i] 1195 } 1196 sort.Sort(termSlice(cpy)) 1197 a := NewArray(cpy...) 1198 a.hashs = arr.hashs 1199 return a 1200 } 1201 1202 // Hash returns the hash code for the Value. 1203 func (arr *Array) Hash() int { 1204 return arr.hash 1205 } 1206 1207 // IsGround returns true if all of the Array elements are ground. 1208 func (arr *Array) IsGround() bool { 1209 return arr.ground 1210 } 1211 1212 // MarshalJSON returns JSON encoded bytes representing arr. 1213 func (arr *Array) MarshalJSON() ([]byte, error) { 1214 if len(arr.elems) == 0 { 1215 return []byte(`[]`), nil 1216 } 1217 return json.Marshal(arr.elems) 1218 } 1219 1220 func (arr *Array) String() string { 1221 var b strings.Builder 1222 b.WriteRune('[') 1223 for i, e := range arr.elems { 1224 if i > 0 { 1225 b.WriteString(", ") 1226 } 1227 b.WriteString(e.String()) 1228 } 1229 b.WriteRune(']') 1230 return b.String() 1231 } 1232 1233 // Len returns the number of elements in the array. 1234 func (arr *Array) Len() int { 1235 return len(arr.elems) 1236 } 1237 1238 // Elem returns the element i of arr. 1239 func (arr *Array) Elem(i int) *Term { 1240 return arr.elems[i] 1241 } 1242 1243 // rehash updates the cached hash of arr. 1244 func (arr *Array) rehash() { 1245 arr.hash = 0 1246 for _, h := range arr.hashs { 1247 arr.hash += h 1248 } 1249 } 1250 1251 // set sets the element i of arr. 1252 func (arr *Array) set(i int, v *Term) { 1253 arr.ground = arr.ground && v.IsGround() 1254 arr.elems[i] = v 1255 arr.hashs[i] = v.Value.Hash() 1256 } 1257 1258 // Slice returns a slice of arr starting from i index to j. -1 1259 // indicates the end of the array. The returned value array is not a 1260 // copy and any modifications to either of arrays may be reflected to 1261 // the other. 1262 func (arr *Array) Slice(i, j int) *Array { 1263 var elems []*Term 1264 var hashs []int 1265 if j == -1 { 1266 elems = arr.elems[i:] 1267 hashs = arr.hashs[i:] 1268 } else { 1269 elems = arr.elems[i:j] 1270 hashs = arr.hashs[i:j] 1271 } 1272 // If arr is ground, the slice is, too. 1273 // If it's not, the slice could still be. 1274 gr := arr.ground || termSliceIsGround(elems) 1275 1276 s := &Array{elems: elems, hashs: hashs, ground: gr} 1277 s.rehash() 1278 return s 1279 } 1280 1281 // Iter calls f on each element in arr. If f returns an error, 1282 // iteration stops and the return value is the error. 1283 func (arr *Array) Iter(f func(*Term) error) error { 1284 for i := range arr.elems { 1285 if err := f(arr.elems[i]); err != nil { 1286 return err 1287 } 1288 } 1289 return nil 1290 } 1291 1292 // Until calls f on each element in arr. If f returns true, iteration stops. 1293 func (arr *Array) Until(f func(*Term) bool) bool { 1294 err := arr.Iter(func(t *Term) error { 1295 if f(t) { 1296 return errStop 1297 } 1298 return nil 1299 }) 1300 return err != nil 1301 } 1302 1303 // Foreach calls f on each element in arr. 1304 func (arr *Array) Foreach(f func(*Term)) { 1305 _ = arr.Iter(func(t *Term) error { 1306 f(t) 1307 return nil 1308 }) // ignore error 1309 } 1310 1311 // Append appends a term to arr, returning the appended array. 1312 func (arr *Array) Append(v *Term) *Array { 1313 cpy := *arr 1314 cpy.elems = append(arr.elems, v) 1315 cpy.hashs = append(arr.hashs, v.Value.Hash()) 1316 cpy.hash = arr.hash + v.Value.Hash() 1317 cpy.ground = arr.ground && v.IsGround() 1318 return &cpy 1319 } 1320 1321 // Set represents a set as defined by the language. 1322 type Set interface { 1323 Value 1324 Len() int 1325 Copy() Set 1326 Diff(Set) Set 1327 Intersect(Set) Set 1328 Union(Set) Set 1329 Add(*Term) 1330 Iter(func(*Term) error) error 1331 Until(func(*Term) bool) bool 1332 Foreach(func(*Term)) 1333 Contains(*Term) bool 1334 Map(func(*Term) (*Term, error)) (Set, error) 1335 Reduce(*Term, func(*Term, *Term) (*Term, error)) (*Term, error) 1336 Sorted() *Array 1337 Slice() []*Term 1338 } 1339 1340 // NewSet returns a new Set containing t. 1341 func NewSet(t ...*Term) Set { 1342 s := newset(len(t)) 1343 for i := range t { 1344 s.Add(t[i]) 1345 } 1346 return s 1347 } 1348 1349 func newset(n int) *set { 1350 var keys []*Term 1351 if n > 0 { 1352 keys = make([]*Term, 0, n) 1353 } 1354 return &set{ 1355 elems: make(map[int]*Term, n), 1356 keys: keys, 1357 hash: 0, 1358 ground: true, 1359 } 1360 } 1361 1362 // SetTerm returns a new Term representing a set containing terms t. 1363 func SetTerm(t ...*Term) *Term { 1364 set := NewSet(t...) 1365 return &Term{ 1366 Value: set, 1367 } 1368 } 1369 1370 type set struct { 1371 elems map[int]*Term 1372 keys []*Term 1373 hash int 1374 ground bool 1375 } 1376 1377 // Copy returns a deep copy of s. 1378 func (s *set) Copy() Set { 1379 cpy := newset(s.Len()) 1380 s.Foreach(func(x *Term) { 1381 cpy.Add(x.Copy()) 1382 }) 1383 cpy.hash = s.hash 1384 cpy.ground = s.ground 1385 return cpy 1386 } 1387 1388 // IsGround returns true if all terms in s are ground. 1389 func (s *set) IsGround() bool { 1390 return s.ground 1391 } 1392 1393 // Hash returns a hash code for s. 1394 func (s *set) Hash() int { 1395 return s.hash 1396 } 1397 1398 func (s *set) String() string { 1399 if s.Len() == 0 { 1400 return "set()" 1401 } 1402 var b strings.Builder 1403 b.WriteRune('{') 1404 for i := range s.keys { 1405 if i > 0 { 1406 b.WriteString(", ") 1407 } 1408 b.WriteString(s.keys[i].Value.String()) 1409 } 1410 b.WriteRune('}') 1411 return b.String() 1412 } 1413 1414 // Compare compares s to other, return <0, 0, or >0 if it is less than, equal to, 1415 // or greater than other. 1416 func (s *set) Compare(other Value) int { 1417 o1 := sortOrder(s) 1418 o2 := sortOrder(other) 1419 if o1 < o2 { 1420 return -1 1421 } else if o1 > o2 { 1422 return 1 1423 } 1424 t := other.(*set) 1425 return termSliceCompare(s.keys, t.keys) 1426 } 1427 1428 // Find returns the set or dereferences the element itself. 1429 func (s *set) Find(path Ref) (Value, error) { 1430 if len(path) == 0 { 1431 return s, nil 1432 } 1433 if !s.Contains(path[0]) { 1434 return nil, errFindNotFound 1435 } 1436 return path[0].Value.Find(path[1:]) 1437 } 1438 1439 // Diff returns elements in s that are not in other. 1440 func (s *set) Diff(other Set) Set { 1441 r := NewSet() 1442 s.Foreach(func(x *Term) { 1443 if !other.Contains(x) { 1444 r.Add(x) 1445 } 1446 }) 1447 return r 1448 } 1449 1450 // Intersect returns the set containing elements in both s and other. 1451 func (s *set) Intersect(other Set) Set { 1452 o := other.(*set) 1453 n, m := s.Len(), o.Len() 1454 ss := s 1455 so := o 1456 if m < n { 1457 ss = o 1458 so = s 1459 n = m 1460 } 1461 1462 r := newset(n) 1463 ss.Foreach(func(x *Term) { 1464 if so.Contains(x) { 1465 r.Add(x) 1466 } 1467 }) 1468 return r 1469 } 1470 1471 // Union returns the set containing all elements of s and other. 1472 func (s *set) Union(other Set) Set { 1473 r := NewSet() 1474 s.Foreach(func(x *Term) { 1475 r.Add(x) 1476 }) 1477 other.Foreach(func(x *Term) { 1478 r.Add(x) 1479 }) 1480 return r 1481 } 1482 1483 // Add updates s to include t. 1484 func (s *set) Add(t *Term) { 1485 s.insert(t) 1486 } 1487 1488 // Iter calls f on each element in s. If f returns an error, iteration stops 1489 // and the return value is the error. 1490 func (s *set) Iter(f func(*Term) error) error { 1491 for i := range s.keys { 1492 if err := f(s.keys[i]); err != nil { 1493 return err 1494 } 1495 } 1496 return nil 1497 } 1498 1499 var errStop = errors.New("stop") 1500 1501 // Until calls f on each element in s. If f returns true, iteration stops. 1502 func (s *set) Until(f func(*Term) bool) bool { 1503 err := s.Iter(func(t *Term) error { 1504 if f(t) { 1505 return errStop 1506 } 1507 return nil 1508 }) 1509 return err != nil 1510 } 1511 1512 // Foreach calls f on each element in s. 1513 func (s *set) Foreach(f func(*Term)) { 1514 _ = s.Iter(func(t *Term) error { 1515 f(t) 1516 return nil 1517 }) // ignore error 1518 } 1519 1520 // Map returns a new Set obtained by applying f to each value in s. 1521 func (s *set) Map(f func(*Term) (*Term, error)) (Set, error) { 1522 set := NewSet() 1523 err := s.Iter(func(x *Term) error { 1524 term, err := f(x) 1525 if err != nil { 1526 return err 1527 } 1528 set.Add(term) 1529 return nil 1530 }) 1531 if err != nil { 1532 return nil, err 1533 } 1534 return set, nil 1535 } 1536 1537 // Reduce returns a Term produced by applying f to each value in s. The first 1538 // argument to f is the reduced value (starting with i) and the second argument 1539 // to f is the element in s. 1540 func (s *set) Reduce(i *Term, f func(*Term, *Term) (*Term, error)) (*Term, error) { 1541 err := s.Iter(func(x *Term) error { 1542 var err error 1543 i, err = f(i, x) 1544 if err != nil { 1545 return err 1546 } 1547 return nil 1548 }) 1549 return i, err 1550 } 1551 1552 // Contains returns true if t is in s. 1553 func (s *set) Contains(t *Term) bool { 1554 return s.get(t) != nil 1555 } 1556 1557 // Len returns the number of elements in the set. 1558 func (s *set) Len() int { 1559 return len(s.keys) 1560 } 1561 1562 // MarshalJSON returns JSON encoded bytes representing s. 1563 func (s *set) MarshalJSON() ([]byte, error) { 1564 if s.keys == nil { 1565 return []byte(`[]`), nil 1566 } 1567 return json.Marshal(s.keys) 1568 } 1569 1570 // Sorted returns an Array that contains the sorted elements of s. 1571 func (s *set) Sorted() *Array { 1572 cpy := make([]*Term, len(s.keys)) 1573 copy(cpy, s.keys) 1574 sort.Sort(termSlice(cpy)) 1575 return NewArray(cpy...) 1576 } 1577 1578 // Slice returns a slice of terms contained in the set. 1579 func (s *set) Slice() []*Term { 1580 return s.keys 1581 } 1582 1583 func (s *set) insert(x *Term) { 1584 hash := x.Hash() 1585 insertHash := hash 1586 // This `equal` utility is duplicated and manually inlined a number of 1587 // time in this file. Inlining it avoids heap allocations, so it makes 1588 // a big performance difference: some operations like lookup become twice 1589 // as slow without it. 1590 var equal func(v Value) bool 1591 1592 switch x := x.Value.(type) { 1593 case Null, Boolean, String, Var: 1594 equal = func(y Value) bool { return x == y } 1595 case Number: 1596 if xi, err := json.Number(x).Int64(); err == nil { 1597 equal = func(y Value) bool { 1598 if y, ok := y.(Number); ok { 1599 if yi, err := json.Number(y).Int64(); err == nil { 1600 return xi == yi 1601 } 1602 } 1603 1604 return false 1605 } 1606 break 1607 } 1608 1609 // We use big.Rat for comparing big numbers. 1610 // It replaces big.Float due to following reason: 1611 // big.Float comes with a default precision of 64, and setting a 1612 // larger precision results in more memory being allocated 1613 // (regardless of the actual number we are parsing with SetString). 1614 // 1615 // Note: If we're so close to zero that big.Float says we are zero, do 1616 // *not* big.Rat).SetString on the original string it'll potentially 1617 // take very long. 1618 var a *big.Rat 1619 fa, ok := new(big.Float).SetString(string(x)) 1620 if !ok { 1621 panic("illegal value") 1622 } 1623 if fa.IsInt() { 1624 if i, _ := fa.Int64(); i == 0 { 1625 a = new(big.Rat).SetInt64(0) 1626 } 1627 } 1628 if a == nil { 1629 a, ok = new(big.Rat).SetString(string(x)) 1630 if !ok { 1631 panic("illegal value") 1632 } 1633 } 1634 1635 equal = func(b Value) bool { 1636 if bNum, ok := b.(Number); ok { 1637 var b *big.Rat 1638 fb, ok := new(big.Float).SetString(string(bNum)) 1639 if !ok { 1640 panic("illegal value") 1641 } 1642 if fb.IsInt() { 1643 if i, _ := fb.Int64(); i == 0 { 1644 b = new(big.Rat).SetInt64(0) 1645 } 1646 } 1647 if b == nil { 1648 b, ok = new(big.Rat).SetString(string(bNum)) 1649 if !ok { 1650 panic("illegal value") 1651 } 1652 } 1653 1654 return a.Cmp(b) == 0 1655 } 1656 1657 return false 1658 } 1659 default: 1660 equal = func(y Value) bool { return Compare(x, y) == 0 } 1661 } 1662 1663 for curr, ok := s.elems[insertHash]; ok; { 1664 if equal(curr.Value) { 1665 return 1666 } 1667 1668 insertHash++ 1669 curr, ok = s.elems[insertHash] 1670 } 1671 1672 s.elems[insertHash] = x 1673 i := sort.Search(len(s.keys), func(i int) bool { return Compare(x, s.keys[i]) < 0 }) 1674 if i < len(s.keys) { 1675 // insert at position `i`: 1676 s.keys = append(s.keys, nil) // add some space 1677 copy(s.keys[i+1:], s.keys[i:]) // move things over 1678 s.keys[i] = x // drop it in position 1679 } else { 1680 s.keys = append(s.keys, x) 1681 } 1682 1683 s.hash += hash 1684 s.ground = s.ground && x.IsGround() 1685 } 1686 1687 func (s *set) get(x *Term) *Term { 1688 hash := x.Hash() 1689 // This `equal` utility is duplicated and manually inlined a number of 1690 // time in this file. Inlining it avoids heap allocations, so it makes 1691 // a big performance difference: some operations like lookup become twice 1692 // as slow without it. 1693 var equal func(v Value) bool 1694 1695 switch x := x.Value.(type) { 1696 case Null, Boolean, String, Var: 1697 equal = func(y Value) bool { return x == y } 1698 case Number: 1699 if xi, err := json.Number(x).Int64(); err == nil { 1700 equal = func(y Value) bool { 1701 if y, ok := y.(Number); ok { 1702 if yi, err := json.Number(y).Int64(); err == nil { 1703 return xi == yi 1704 } 1705 } 1706 1707 return false 1708 } 1709 break 1710 } 1711 1712 // We use big.Rat for comparing big numbers. 1713 // It replaces big.Float due to following reason: 1714 // big.Float comes with a default precision of 64, and setting a 1715 // larger precision results in more memory being allocated 1716 // (regardless of the actual number we are parsing with SetString). 1717 // 1718 // Note: If we're so close to zero that big.Float says we are zero, do 1719 // *not* big.Rat).SetString on the original string it'll potentially 1720 // take very long. 1721 var a *big.Rat 1722 fa, ok := new(big.Float).SetString(string(x)) 1723 if !ok { 1724 panic("illegal value") 1725 } 1726 if fa.IsInt() { 1727 if i, _ := fa.Int64(); i == 0 { 1728 a = new(big.Rat).SetInt64(0) 1729 } 1730 } 1731 if a == nil { 1732 a, ok = new(big.Rat).SetString(string(x)) 1733 if !ok { 1734 panic("illegal value") 1735 } 1736 } 1737 1738 equal = func(b Value) bool { 1739 if bNum, ok := b.(Number); ok { 1740 var b *big.Rat 1741 fb, ok := new(big.Float).SetString(string(bNum)) 1742 if !ok { 1743 panic("illegal value") 1744 } 1745 if fb.IsInt() { 1746 if i, _ := fb.Int64(); i == 0 { 1747 b = new(big.Rat).SetInt64(0) 1748 } 1749 } 1750 if b == nil { 1751 b, ok = new(big.Rat).SetString(string(bNum)) 1752 if !ok { 1753 panic("illegal value") 1754 } 1755 } 1756 1757 return a.Cmp(b) == 0 1758 } 1759 return false 1760 1761 } 1762 1763 default: 1764 equal = func(y Value) bool { return Compare(x, y) == 0 } 1765 } 1766 1767 for curr, ok := s.elems[hash]; ok; { 1768 if equal(curr.Value) { 1769 return curr 1770 } 1771 1772 hash++ 1773 curr, ok = s.elems[hash] 1774 } 1775 return nil 1776 } 1777 1778 // Object represents an object as defined by the language. 1779 type Object interface { 1780 Value 1781 Len() int 1782 Get(*Term) *Term 1783 Copy() Object 1784 Insert(*Term, *Term) 1785 Iter(func(*Term, *Term) error) error 1786 Until(func(*Term, *Term) bool) bool 1787 Foreach(func(*Term, *Term)) 1788 Map(func(*Term, *Term) (*Term, *Term, error)) (Object, error) 1789 Diff(other Object) Object 1790 Intersect(other Object) [][3]*Term 1791 Merge(other Object) (Object, bool) 1792 MergeWith(other Object, conflictResolver func(v1, v2 *Term) (*Term, bool)) (Object, bool) 1793 Filter(filter Object) (Object, error) 1794 Keys() []*Term 1795 Elem(i int) (*Term, *Term) 1796 get(k *Term) *objectElem // To prevent external implementations 1797 } 1798 1799 // NewObject creates a new Object with t. 1800 func NewObject(t ...[2]*Term) Object { 1801 obj := newobject(len(t)) 1802 for i := range t { 1803 obj.Insert(t[i][0], t[i][1]) 1804 } 1805 return obj 1806 } 1807 1808 // ObjectTerm creates a new Term with an Object value. 1809 func ObjectTerm(o ...[2]*Term) *Term { 1810 return &Term{Value: NewObject(o...)} 1811 } 1812 1813 type object struct { 1814 elems map[int]*objectElem 1815 keys objectElemSlice 1816 ground int // number of key and value grounds. Counting is 1817 // required to support insert's key-value replace. 1818 hash int 1819 } 1820 1821 func newobject(n int) *object { 1822 var keys objectElemSlice 1823 if n > 0 { 1824 keys = make(objectElemSlice, 0, n) 1825 } 1826 return &object{ 1827 elems: make(map[int]*objectElem, n), 1828 keys: keys, 1829 ground: 0, 1830 hash: 0, 1831 } 1832 } 1833 1834 type objectElem struct { 1835 key *Term 1836 value *Term 1837 next *objectElem 1838 } 1839 1840 type objectElemSlice []*objectElem 1841 1842 func (s objectElemSlice) Less(i, j int) bool { return Compare(s[i].key.Value, s[j].key.Value) < 0 } 1843 func (s objectElemSlice) Swap(i, j int) { x := s[i]; s[i] = s[j]; s[j] = x } 1844 func (s objectElemSlice) Len() int { return len(s) } 1845 1846 // Item is a helper for constructing an tuple containing two Terms 1847 // representing a key/value pair in an Object. 1848 func Item(key, value *Term) [2]*Term { 1849 return [2]*Term{key, value} 1850 } 1851 1852 // Compare compares obj to other, return <0, 0, or >0 if it is less than, equal to, 1853 // or greater than other. 1854 func (obj *object) Compare(other Value) int { 1855 o1 := sortOrder(obj) 1856 o2 := sortOrder(other) 1857 if o1 < o2 { 1858 return -1 1859 } else if o2 < o1 { 1860 return 1 1861 } 1862 a := obj 1863 b := other.(*object) 1864 minLen := len(a.keys) 1865 if len(b.keys) < len(a.keys) { 1866 minLen = len(b.keys) 1867 } 1868 for i := 0; i < minLen; i++ { 1869 keysCmp := Compare(a.keys[i].key, b.keys[i].key) 1870 if keysCmp < 0 { 1871 return -1 1872 } 1873 if keysCmp > 0 { 1874 return 1 1875 } 1876 valA := a.keys[i].value 1877 valB := b.keys[i].value 1878 valCmp := Compare(valA, valB) 1879 if valCmp != 0 { 1880 return valCmp 1881 } 1882 } 1883 if len(a.keys) < len(b.keys) { 1884 return -1 1885 } 1886 if len(b.keys) < len(a.keys) { 1887 return 1 1888 } 1889 return 0 1890 } 1891 1892 // Find returns the value at the key or undefined. 1893 func (obj *object) Find(path Ref) (Value, error) { 1894 if len(path) == 0 { 1895 return obj, nil 1896 } 1897 value := obj.Get(path[0]) 1898 if value == nil { 1899 return nil, errFindNotFound 1900 } 1901 return value.Value.Find(path[1:]) 1902 } 1903 1904 func (obj *object) Insert(k, v *Term) { 1905 obj.insert(k, v) 1906 } 1907 1908 // Get returns the value of k in obj if k exists, otherwise nil. 1909 func (obj *object) Get(k *Term) *Term { 1910 if elem := obj.get(k); elem != nil { 1911 return elem.value 1912 } 1913 return nil 1914 } 1915 1916 // Hash returns the hash code for the Value. 1917 func (obj *object) Hash() int { 1918 return obj.hash 1919 } 1920 1921 // IsGround returns true if all of the Object key/value pairs are ground. 1922 func (obj *object) IsGround() bool { 1923 return obj.ground == 2*len(obj.keys) 1924 } 1925 1926 // Copy returns a deep copy of obj. 1927 func (obj *object) Copy() Object { 1928 cpy, _ := obj.Map(func(k, v *Term) (*Term, *Term, error) { 1929 return k.Copy(), v.Copy(), nil 1930 }) 1931 cpy.(*object).hash = obj.hash 1932 return cpy 1933 } 1934 1935 // Diff returns a new Object that contains only the key/value pairs that exist in obj. 1936 func (obj *object) Diff(other Object) Object { 1937 r := NewObject() 1938 obj.Foreach(func(k, v *Term) { 1939 if other.Get(k) == nil { 1940 r.Insert(k, v) 1941 } 1942 }) 1943 return r 1944 } 1945 1946 // Intersect returns a slice of term triplets that represent the intersection of keys 1947 // between obj and other. For each intersecting key, the values from obj and other are included 1948 // as the last two terms in the triplet (respectively). 1949 func (obj *object) Intersect(other Object) [][3]*Term { 1950 r := [][3]*Term{} 1951 obj.Foreach(func(k, v *Term) { 1952 if v2 := other.Get(k); v2 != nil { 1953 r = append(r, [3]*Term{k, v, v2}) 1954 } 1955 }) 1956 return r 1957 } 1958 1959 // Iter calls the function f for each key-value pair in the object. If f 1960 // returns an error, iteration stops and the error is returned. 1961 func (obj *object) Iter(f func(*Term, *Term) error) error { 1962 for _, node := range obj.keys { 1963 if err := f(node.key, node.value); err != nil { 1964 return err 1965 } 1966 } 1967 return nil 1968 } 1969 1970 // Until calls f for each key-value pair in the object. If f returns 1971 // true, iteration stops and Until returns true. Otherwise, return 1972 // false. 1973 func (obj *object) Until(f func(*Term, *Term) bool) bool { 1974 err := obj.Iter(func(k, v *Term) error { 1975 if f(k, v) { 1976 return errStop 1977 } 1978 return nil 1979 }) 1980 return err != nil 1981 } 1982 1983 // Foreach calls f for each key-value pair in the object. 1984 func (obj *object) Foreach(f func(*Term, *Term)) { 1985 _ = obj.Iter(func(k, v *Term) error { 1986 f(k, v) 1987 return nil 1988 }) // ignore error 1989 } 1990 1991 // Map returns a new Object constructed by mapping each element in the object 1992 // using the function f. 1993 func (obj *object) Map(f func(*Term, *Term) (*Term, *Term, error)) (Object, error) { 1994 cpy := newobject(obj.Len()) 1995 err := obj.Iter(func(k, v *Term) error { 1996 var err error 1997 k, v, err = f(k, v) 1998 if err != nil { 1999 return err 2000 } 2001 cpy.insert(k, v) 2002 return nil 2003 }) 2004 if err != nil { 2005 return nil, err 2006 } 2007 return cpy, nil 2008 } 2009 2010 // Keys returns the keys of obj. 2011 func (obj *object) Keys() []*Term { 2012 keys := make([]*Term, len(obj.keys)) 2013 2014 for i, elem := range obj.keys { 2015 keys[i] = elem.key 2016 } 2017 2018 return keys 2019 } 2020 2021 func (obj *object) Elem(i int) (*Term, *Term) { 2022 return obj.keys[i].key, obj.keys[i].value 2023 } 2024 2025 // MarshalJSON returns JSON encoded bytes representing obj. 2026 func (obj *object) MarshalJSON() ([]byte, error) { 2027 sl := make([][2]*Term, obj.Len()) 2028 for i, node := range obj.keys { 2029 sl[i] = Item(node.key, node.value) 2030 } 2031 return json.Marshal(sl) 2032 } 2033 2034 // Merge returns a new Object containing the non-overlapping keys of obj and other. If there are 2035 // overlapping keys between obj and other, the values of associated with the keys are merged. Only 2036 // objects can be merged with other objects. If the values cannot be merged, the second turn value 2037 // will be false. 2038 func (obj object) Merge(other Object) (Object, bool) { 2039 return obj.MergeWith(other, func(v1, v2 *Term) (*Term, bool) { 2040 obj1, ok1 := v1.Value.(Object) 2041 obj2, ok2 := v2.Value.(Object) 2042 if !ok1 || !ok2 { 2043 return nil, true 2044 } 2045 obj3, ok := obj1.Merge(obj2) 2046 if !ok { 2047 return nil, true 2048 } 2049 return NewTerm(obj3), false 2050 }) 2051 } 2052 2053 // MergeWith returns a new Object containing the merged keys of obj and other. 2054 // If there are overlapping keys between obj and other, the conflictResolver 2055 // is called. The conflictResolver can return a merged value and a boolean 2056 // indicating if the merge has failed and should stop. 2057 func (obj object) MergeWith(other Object, conflictResolver func(v1, v2 *Term) (*Term, bool)) (Object, bool) { 2058 result := NewObject() 2059 stop := obj.Until(func(k, v *Term) bool { 2060 v2 := other.Get(k) 2061 // The key didn't exist in other, keep the original value 2062 if v2 == nil { 2063 result.Insert(k, v) 2064 return false 2065 } 2066 2067 // The key exists in both, resolve the conflict if possible 2068 merged, stop := conflictResolver(v, v2) 2069 if !stop { 2070 result.Insert(k, merged) 2071 } 2072 return stop 2073 }) 2074 2075 if stop { 2076 return nil, false 2077 } 2078 2079 // Copy in any values from other for keys that don't exist in obj 2080 other.Foreach(func(k, v *Term) { 2081 if v2 := obj.Get(k); v2 == nil { 2082 result.Insert(k, v) 2083 } 2084 }) 2085 return result, true 2086 } 2087 2088 // Filter returns a new object from values in obj where the keys are 2089 // found in filter. Array indices for values can be specified as 2090 // number strings. 2091 func (obj *object) Filter(filter Object) (Object, error) { 2092 filtered, err := filterObject(obj, filter) 2093 if err != nil { 2094 return nil, err 2095 } 2096 return filtered.(Object), nil 2097 } 2098 2099 // Len returns the number of elements in the object. 2100 func (obj object) Len() int { 2101 return len(obj.keys) 2102 } 2103 2104 func (obj object) String() string { 2105 var b strings.Builder 2106 b.WriteRune('{') 2107 2108 for i, elem := range obj.keys { 2109 if i > 0 { 2110 b.WriteString(", ") 2111 } 2112 b.WriteString(elem.key.String()) 2113 b.WriteString(": ") 2114 b.WriteString(elem.value.String()) 2115 } 2116 b.WriteRune('}') 2117 return b.String() 2118 } 2119 2120 func (obj *object) get(k *Term) *objectElem { 2121 hash := k.Hash() 2122 2123 // This `equal` utility is duplicated and manually inlined a number of 2124 // time in this file. Inlining it avoids heap allocations, so it makes 2125 // a big performance difference: some operations like lookup become twice 2126 // as slow without it. 2127 var equal func(v Value) bool 2128 2129 switch x := k.Value.(type) { 2130 case Null, Boolean, String, Var: 2131 equal = func(y Value) bool { return x == y } 2132 case Number: 2133 if xi, err := json.Number(x).Int64(); err == nil { 2134 equal = func(y Value) bool { 2135 if y, ok := y.(Number); ok { 2136 if yi, err := json.Number(y).Int64(); err == nil { 2137 return xi == yi 2138 } 2139 } 2140 2141 return false 2142 } 2143 break 2144 } 2145 2146 // We use big.Rat for comparing big numbers. 2147 // It replaces big.Float due to following reason: 2148 // big.Float comes with a default precision of 64, and setting a 2149 // larger precision results in more memory being allocated 2150 // (regardless of the actual number we are parsing with SetString). 2151 // 2152 // Note: If we're so close to zero that big.Float says we are zero, do 2153 // *not* big.Rat).SetString on the original string it'll potentially 2154 // take very long. 2155 var a *big.Rat 2156 fa, ok := new(big.Float).SetString(string(x)) 2157 if !ok { 2158 panic("illegal value") 2159 } 2160 if fa.IsInt() { 2161 if i, _ := fa.Int64(); i == 0 { 2162 a = new(big.Rat).SetInt64(0) 2163 } 2164 } 2165 if a == nil { 2166 a, ok = new(big.Rat).SetString(string(x)) 2167 if !ok { 2168 panic("illegal value") 2169 } 2170 } 2171 2172 equal = func(b Value) bool { 2173 if bNum, ok := b.(Number); ok { 2174 var b *big.Rat 2175 fb, ok := new(big.Float).SetString(string(bNum)) 2176 if !ok { 2177 panic("illegal value") 2178 } 2179 if fb.IsInt() { 2180 if i, _ := fb.Int64(); i == 0 { 2181 b = new(big.Rat).SetInt64(0) 2182 } 2183 } 2184 if b == nil { 2185 b, ok = new(big.Rat).SetString(string(bNum)) 2186 if !ok { 2187 panic("illegal value") 2188 } 2189 } 2190 2191 return a.Cmp(b) == 0 2192 } 2193 2194 return false 2195 } 2196 default: 2197 equal = func(y Value) bool { return Compare(x, y) == 0 } 2198 } 2199 2200 for curr := obj.elems[hash]; curr != nil; curr = curr.next { 2201 if equal(curr.key.Value) { 2202 return curr 2203 } 2204 } 2205 return nil 2206 } 2207 2208 func (obj *object) insert(k, v *Term) { 2209 hash := k.Hash() 2210 head := obj.elems[hash] 2211 // This `equal` utility is duplicated and manually inlined a number of 2212 // time in this file. Inlining it avoids heap allocations, so it makes 2213 // a big performance difference: some operations like lookup become twice 2214 // as slow without it. 2215 var equal func(v Value) bool 2216 2217 switch x := k.Value.(type) { 2218 case Null, Boolean, String, Var: 2219 equal = func(y Value) bool { return x == y } 2220 case Number: 2221 if xi, err := json.Number(x).Int64(); err == nil { 2222 equal = func(y Value) bool { 2223 if y, ok := y.(Number); ok { 2224 if yi, err := json.Number(y).Int64(); err == nil { 2225 return xi == yi 2226 } 2227 } 2228 2229 return false 2230 } 2231 break 2232 } 2233 2234 // We use big.Rat for comparing big numbers. 2235 // It replaces big.Float due to following reason: 2236 // big.Float comes with a default precision of 64, and setting a 2237 // larger precision results in more memory being allocated 2238 // (regardless of the actual number we are parsing with SetString). 2239 // 2240 // Note: If we're so close to zero that big.Float says we are zero, do 2241 // *not* big.Rat).SetString on the original string it'll potentially 2242 // take very long. 2243 var a *big.Rat 2244 fa, ok := new(big.Float).SetString(string(x)) 2245 if !ok { 2246 panic("illegal value") 2247 } 2248 if fa.IsInt() { 2249 if i, _ := fa.Int64(); i == 0 { 2250 a = new(big.Rat).SetInt64(0) 2251 } 2252 } 2253 if a == nil { 2254 a, ok = new(big.Rat).SetString(string(x)) 2255 if !ok { 2256 panic("illegal value") 2257 } 2258 } 2259 2260 equal = func(b Value) bool { 2261 if bNum, ok := b.(Number); ok { 2262 var b *big.Rat 2263 fb, ok := new(big.Float).SetString(string(bNum)) 2264 if !ok { 2265 panic("illegal value") 2266 } 2267 if fb.IsInt() { 2268 if i, _ := fb.Int64(); i == 0 { 2269 b = new(big.Rat).SetInt64(0) 2270 } 2271 } 2272 if b == nil { 2273 b, ok = new(big.Rat).SetString(string(bNum)) 2274 if !ok { 2275 panic("illegal value") 2276 } 2277 } 2278 2279 return a.Cmp(b) == 0 2280 } 2281 2282 return false 2283 } 2284 default: 2285 equal = func(y Value) bool { return Compare(x, y) == 0 } 2286 } 2287 2288 for curr := head; curr != nil; curr = curr.next { 2289 if equal(curr.key.Value) { 2290 // The ground bit of the value may change in 2291 // replace, hence adjust the counter per old 2292 // and new value. 2293 2294 if curr.value.IsGround() { 2295 obj.ground-- 2296 } 2297 if v.IsGround() { 2298 obj.ground++ 2299 } 2300 2301 curr.value = v 2302 return 2303 } 2304 } 2305 elem := &objectElem{ 2306 key: k, 2307 value: v, 2308 next: head, 2309 } 2310 obj.elems[hash] = elem 2311 i := sort.Search(len(obj.keys), func(i int) bool { return Compare(elem.key, obj.keys[i].key) < 0 }) 2312 if i < len(obj.keys) { 2313 // insert at position `i`: 2314 obj.keys = append(obj.keys, nil) // add some space 2315 copy(obj.keys[i+1:], obj.keys[i:]) // move things over 2316 obj.keys[i] = elem // drop it in position 2317 } else { 2318 obj.keys = append(obj.keys, elem) 2319 } 2320 obj.hash += hash + v.Hash() 2321 2322 if k.IsGround() { 2323 obj.ground++ 2324 } 2325 if v.IsGround() { 2326 obj.ground++ 2327 } 2328 } 2329 2330 func filterObject(o Value, filter Value) (Value, error) { 2331 if filter.Compare(Null{}) == 0 { 2332 return o, nil 2333 } 2334 2335 filteredObj, ok := filter.(*object) 2336 if !ok { 2337 return nil, fmt.Errorf("invalid filter value %q, expected an object", filter) 2338 } 2339 2340 switch v := o.(type) { 2341 case String, Number, Boolean, Null: 2342 return o, nil 2343 case *Array: 2344 values := NewArray() 2345 for i := 0; i < v.Len(); i++ { 2346 subFilter := filteredObj.Get(StringTerm(strconv.Itoa(i))) 2347 if subFilter != nil { 2348 filteredValue, err := filterObject(v.Elem(i).Value, subFilter.Value) 2349 if err != nil { 2350 return nil, err 2351 } 2352 values = values.Append(NewTerm(filteredValue)) 2353 } 2354 } 2355 return values, nil 2356 case Set: 2357 values := NewSet() 2358 err := v.Iter(func(t *Term) error { 2359 if filteredObj.Get(t) != nil { 2360 filteredValue, err := filterObject(t.Value, filteredObj.Get(t).Value) 2361 if err != nil { 2362 return err 2363 } 2364 values.Add(NewTerm(filteredValue)) 2365 } 2366 return nil 2367 }) 2368 return values, err 2369 case *object: 2370 values := NewObject() 2371 2372 iterObj := v 2373 other := filteredObj 2374 if v.Len() < filteredObj.Len() { 2375 iterObj = filteredObj 2376 other = v 2377 } 2378 2379 err := iterObj.Iter(func(key *Term, value *Term) error { 2380 if other.Get(key) != nil { 2381 filteredValue, err := filterObject(v.Get(key).Value, filteredObj.Get(key).Value) 2382 if err != nil { 2383 return err 2384 } 2385 values.Insert(key, NewTerm(filteredValue)) 2386 } 2387 return nil 2388 }) 2389 return values, err 2390 default: 2391 return nil, fmt.Errorf("invalid object value type %q", v) 2392 } 2393 } 2394 2395 // ArrayComprehension represents an array comprehension as defined in the language. 2396 type ArrayComprehension struct { 2397 Term *Term `json:"term"` 2398 Body Body `json:"body"` 2399 } 2400 2401 // ArrayComprehensionTerm creates a new Term with an ArrayComprehension value. 2402 func ArrayComprehensionTerm(term *Term, body Body) *Term { 2403 return &Term{ 2404 Value: &ArrayComprehension{ 2405 Term: term, 2406 Body: body, 2407 }, 2408 } 2409 } 2410 2411 // Copy returns a deep copy of ac. 2412 func (ac *ArrayComprehension) Copy() *ArrayComprehension { 2413 cpy := *ac 2414 cpy.Body = ac.Body.Copy() 2415 cpy.Term = ac.Term.Copy() 2416 return &cpy 2417 } 2418 2419 // Equal returns true if ac is equal to other. 2420 func (ac *ArrayComprehension) Equal(other Value) bool { 2421 return Compare(ac, other) == 0 2422 } 2423 2424 // Compare compares ac to other, return <0, 0, or >0 if it is less than, equal to, 2425 // or greater than other. 2426 func (ac *ArrayComprehension) Compare(other Value) int { 2427 return Compare(ac, other) 2428 } 2429 2430 // Find returns the current value or a not found error. 2431 func (ac *ArrayComprehension) Find(path Ref) (Value, error) { 2432 if len(path) == 0 { 2433 return ac, nil 2434 } 2435 return nil, errFindNotFound 2436 } 2437 2438 // Hash returns the hash code of the Value. 2439 func (ac *ArrayComprehension) Hash() int { 2440 return ac.Term.Hash() + ac.Body.Hash() 2441 } 2442 2443 // IsGround returns true if the Term and Body are ground. 2444 func (ac *ArrayComprehension) IsGround() bool { 2445 return ac.Term.IsGround() && ac.Body.IsGround() 2446 } 2447 2448 func (ac *ArrayComprehension) String() string { 2449 return "[" + ac.Term.String() + " | " + ac.Body.String() + "]" 2450 } 2451 2452 // ObjectComprehension represents an object comprehension as defined in the language. 2453 type ObjectComprehension struct { 2454 Key *Term `json:"key"` 2455 Value *Term `json:"value"` 2456 Body Body `json:"body"` 2457 } 2458 2459 // ObjectComprehensionTerm creates a new Term with an ObjectComprehension value. 2460 func ObjectComprehensionTerm(key, value *Term, body Body) *Term { 2461 return &Term{ 2462 Value: &ObjectComprehension{ 2463 Key: key, 2464 Value: value, 2465 Body: body, 2466 }, 2467 } 2468 } 2469 2470 // Copy returns a deep copy of oc. 2471 func (oc *ObjectComprehension) Copy() *ObjectComprehension { 2472 cpy := *oc 2473 cpy.Body = oc.Body.Copy() 2474 cpy.Key = oc.Key.Copy() 2475 cpy.Value = oc.Value.Copy() 2476 return &cpy 2477 } 2478 2479 // Equal returns true if oc is equal to other. 2480 func (oc *ObjectComprehension) Equal(other Value) bool { 2481 return Compare(oc, other) == 0 2482 } 2483 2484 // Compare compares oc to other, return <0, 0, or >0 if it is less than, equal to, 2485 // or greater than other. 2486 func (oc *ObjectComprehension) Compare(other Value) int { 2487 return Compare(oc, other) 2488 } 2489 2490 // Find returns the current value or a not found error. 2491 func (oc *ObjectComprehension) Find(path Ref) (Value, error) { 2492 if len(path) == 0 { 2493 return oc, nil 2494 } 2495 return nil, errFindNotFound 2496 } 2497 2498 // Hash returns the hash code of the Value. 2499 func (oc *ObjectComprehension) Hash() int { 2500 return oc.Key.Hash() + oc.Value.Hash() + oc.Body.Hash() 2501 } 2502 2503 // IsGround returns true if the Key, Value and Body are ground. 2504 func (oc *ObjectComprehension) IsGround() bool { 2505 return oc.Key.IsGround() && oc.Value.IsGround() && oc.Body.IsGround() 2506 } 2507 2508 func (oc *ObjectComprehension) String() string { 2509 return "{" + oc.Key.String() + ": " + oc.Value.String() + " | " + oc.Body.String() + "}" 2510 } 2511 2512 // SetComprehension represents a set comprehension as defined in the language. 2513 type SetComprehension struct { 2514 Term *Term `json:"term"` 2515 Body Body `json:"body"` 2516 } 2517 2518 // SetComprehensionTerm creates a new Term with an SetComprehension value. 2519 func SetComprehensionTerm(term *Term, body Body) *Term { 2520 return &Term{ 2521 Value: &SetComprehension{ 2522 Term: term, 2523 Body: body, 2524 }, 2525 } 2526 } 2527 2528 // Copy returns a deep copy of sc. 2529 func (sc *SetComprehension) Copy() *SetComprehension { 2530 cpy := *sc 2531 cpy.Body = sc.Body.Copy() 2532 cpy.Term = sc.Term.Copy() 2533 return &cpy 2534 } 2535 2536 // Equal returns true if sc is equal to other. 2537 func (sc *SetComprehension) Equal(other Value) bool { 2538 return Compare(sc, other) == 0 2539 } 2540 2541 // Compare compares sc to other, return <0, 0, or >0 if it is less than, equal to, 2542 // or greater than other. 2543 func (sc *SetComprehension) Compare(other Value) int { 2544 return Compare(sc, other) 2545 } 2546 2547 // Find returns the current value or a not found error. 2548 func (sc *SetComprehension) Find(path Ref) (Value, error) { 2549 if len(path) == 0 { 2550 return sc, nil 2551 } 2552 return nil, errFindNotFound 2553 } 2554 2555 // Hash returns the hash code of the Value. 2556 func (sc *SetComprehension) Hash() int { 2557 return sc.Term.Hash() + sc.Body.Hash() 2558 } 2559 2560 // IsGround returns true if the Term and Body are ground. 2561 func (sc *SetComprehension) IsGround() bool { 2562 return sc.Term.IsGround() && sc.Body.IsGround() 2563 } 2564 2565 func (sc *SetComprehension) String() string { 2566 return "{" + sc.Term.String() + " | " + sc.Body.String() + "}" 2567 } 2568 2569 // Call represents as function call in the language. 2570 type Call []*Term 2571 2572 // CallTerm returns a new Term with a Call value defined by terms. The first 2573 // term is the operator and the rest are operands. 2574 func CallTerm(terms ...*Term) *Term { 2575 return NewTerm(Call(terms)) 2576 } 2577 2578 // Copy returns a deep copy of c. 2579 func (c Call) Copy() Call { 2580 return termSliceCopy(c) 2581 } 2582 2583 // Compare compares c to other, return <0, 0, or >0 if it is less than, equal to, 2584 // or greater than other. 2585 func (c Call) Compare(other Value) int { 2586 return Compare(c, other) 2587 } 2588 2589 // Find returns the current value or a not found error. 2590 func (c Call) Find(Ref) (Value, error) { 2591 return nil, errFindNotFound 2592 } 2593 2594 // Hash returns the hash code for the Value. 2595 func (c Call) Hash() int { 2596 return termSliceHash(c) 2597 } 2598 2599 // IsGround returns true if the Value is ground. 2600 func (c Call) IsGround() bool { 2601 return termSliceIsGround(c) 2602 } 2603 2604 // MakeExpr returns an ew Expr from this call. 2605 func (c Call) MakeExpr(output *Term) *Expr { 2606 terms := []*Term(c) 2607 return NewExpr(append(terms, output)) 2608 } 2609 2610 func (c Call) String() string { 2611 args := make([]string, len(c)-1) 2612 for i := 1; i < len(c); i++ { 2613 args[i-1] = c[i].String() 2614 } 2615 return fmt.Sprintf("%v(%v)", c[0], strings.Join(args, ", ")) 2616 } 2617 2618 func termSliceCopy(a []*Term) []*Term { 2619 cpy := make([]*Term, len(a)) 2620 for i := range a { 2621 cpy[i] = a[i].Copy() 2622 } 2623 return cpy 2624 } 2625 2626 func termSliceEqual(a, b []*Term) bool { 2627 if len(a) == len(b) { 2628 for i := range a { 2629 if !a[i].Equal(b[i]) { 2630 return false 2631 } 2632 } 2633 return true 2634 } 2635 return false 2636 } 2637 2638 func termSliceHash(a []*Term) int { 2639 var hash int 2640 for _, v := range a { 2641 hash += v.Value.Hash() 2642 } 2643 return hash 2644 } 2645 2646 func termSliceIsGround(a []*Term) bool { 2647 for _, v := range a { 2648 if !v.IsGround() { 2649 return false 2650 } 2651 } 2652 return true 2653 } 2654 2655 // NOTE(tsandall): The unmarshalling errors in these functions are not 2656 // helpful for callers because they do not identify the source of the 2657 // unmarshalling error. Because OPA doesn't accept JSON describing ASTs 2658 // from callers, this is acceptable (for now). If that changes in the future, 2659 // the error messages should be revisited. The current approach focuses 2660 // on the happy path and treats all errors the same. If better error 2661 // reporting is needed, the error paths will need to be fleshed out. 2662 2663 func unmarshalBody(b []interface{}) (Body, error) { 2664 buf := Body{} 2665 for _, e := range b { 2666 if m, ok := e.(map[string]interface{}); ok { 2667 expr := &Expr{} 2668 if err := unmarshalExpr(expr, m); err == nil { 2669 buf = append(buf, expr) 2670 continue 2671 } 2672 } 2673 goto unmarshal_error 2674 } 2675 return buf, nil 2676 unmarshal_error: 2677 return nil, fmt.Errorf("ast: unable to unmarshal body") 2678 } 2679 2680 func unmarshalExpr(expr *Expr, v map[string]interface{}) error { 2681 if x, ok := v["negated"]; ok { 2682 if b, ok := x.(bool); ok { 2683 expr.Negated = b 2684 } else { 2685 return fmt.Errorf("ast: unable to unmarshal negated field with type: %T (expected true or false)", v["negated"]) 2686 } 2687 } 2688 if err := unmarshalExprIndex(expr, v); err != nil { 2689 return err 2690 } 2691 switch ts := v["terms"].(type) { 2692 case map[string]interface{}: 2693 t, err := unmarshalTerm(ts) 2694 if err != nil { 2695 return err 2696 } 2697 expr.Terms = t 2698 case []interface{}: 2699 terms, err := unmarshalTermSlice(ts) 2700 if err != nil { 2701 return err 2702 } 2703 expr.Terms = terms 2704 default: 2705 return fmt.Errorf(`ast: unable to unmarshal terms field with type: %T (expected {"value": ..., "type": ...} or [{"value": ..., "type": ...}, ...])`, v["terms"]) 2706 } 2707 if x, ok := v["with"]; ok { 2708 if sl, ok := x.([]interface{}); ok { 2709 ws := make([]*With, len(sl)) 2710 for i := range sl { 2711 var err error 2712 ws[i], err = unmarshalWith(sl[i]) 2713 if err != nil { 2714 return err 2715 } 2716 } 2717 expr.With = ws 2718 } 2719 } 2720 return nil 2721 } 2722 2723 func unmarshalExprIndex(expr *Expr, v map[string]interface{}) error { 2724 if x, ok := v["index"]; ok { 2725 if n, ok := x.(json.Number); ok { 2726 i, err := n.Int64() 2727 if err == nil { 2728 expr.Index = int(i) 2729 return nil 2730 } 2731 } 2732 } 2733 return fmt.Errorf("ast: unable to unmarshal index field with type: %T (expected integer)", v["index"]) 2734 } 2735 2736 func unmarshalTerm(m map[string]interface{}) (*Term, error) { 2737 v, err := unmarshalValue(m) 2738 if err != nil { 2739 return nil, err 2740 } 2741 return &Term{Value: v}, nil 2742 } 2743 2744 func unmarshalTermSlice(s []interface{}) ([]*Term, error) { 2745 buf := []*Term{} 2746 for _, x := range s { 2747 if m, ok := x.(map[string]interface{}); ok { 2748 if t, err := unmarshalTerm(m); err == nil { 2749 buf = append(buf, t) 2750 continue 2751 } else { 2752 return nil, err 2753 } 2754 } 2755 return nil, fmt.Errorf("ast: unable to unmarshal term") 2756 } 2757 return buf, nil 2758 } 2759 2760 func unmarshalTermSliceValue(d map[string]interface{}) ([]*Term, error) { 2761 if s, ok := d["value"].([]interface{}); ok { 2762 return unmarshalTermSlice(s) 2763 } 2764 return nil, fmt.Errorf(`ast: unable to unmarshal term (expected {"value": [...], "type": ...} where type is one of: ref, array, or set)`) 2765 } 2766 2767 func unmarshalWith(i interface{}) (*With, error) { 2768 if m, ok := i.(map[string]interface{}); ok { 2769 tgt, _ := m["target"].(map[string]interface{}) 2770 target, err := unmarshalTerm(tgt) 2771 if err == nil { 2772 val, _ := m["value"].(map[string]interface{}) 2773 value, err := unmarshalTerm(val) 2774 if err == nil { 2775 return &With{ 2776 Target: target, 2777 Value: value, 2778 }, nil 2779 } 2780 return nil, err 2781 } 2782 return nil, err 2783 } 2784 return nil, fmt.Errorf(`ast: unable to unmarshal with modifier (expected {"target": {...}, "value": {...}})`) 2785 } 2786 2787 func unmarshalValue(d map[string]interface{}) (Value, error) { 2788 v := d["value"] 2789 switch d["type"] { 2790 case "null": 2791 return Null{}, nil 2792 case "boolean": 2793 if b, ok := v.(bool); ok { 2794 return Boolean(b), nil 2795 } 2796 case "number": 2797 if n, ok := v.(json.Number); ok { 2798 return Number(n), nil 2799 } 2800 case "string": 2801 if s, ok := v.(string); ok { 2802 return String(s), nil 2803 } 2804 case "var": 2805 if s, ok := v.(string); ok { 2806 return Var(s), nil 2807 } 2808 case "ref": 2809 if s, err := unmarshalTermSliceValue(d); err == nil { 2810 return Ref(s), nil 2811 } 2812 case "array": 2813 if s, err := unmarshalTermSliceValue(d); err == nil { 2814 return NewArray(s...), nil 2815 } 2816 case "set": 2817 if s, err := unmarshalTermSliceValue(d); err == nil { 2818 set := NewSet() 2819 for _, x := range s { 2820 set.Add(x) 2821 } 2822 return set, nil 2823 } 2824 case "object": 2825 if s, ok := v.([]interface{}); ok { 2826 buf := NewObject() 2827 for _, x := range s { 2828 if i, ok := x.([]interface{}); ok && len(i) == 2 { 2829 p, err := unmarshalTermSlice(i) 2830 if err == nil { 2831 buf.Insert(p[0], p[1]) 2832 continue 2833 } 2834 } 2835 goto unmarshal_error 2836 } 2837 return buf, nil 2838 } 2839 case "arraycomprehension", "setcomprehension": 2840 if m, ok := v.(map[string]interface{}); ok { 2841 t, ok := m["term"].(map[string]interface{}) 2842 if !ok { 2843 goto unmarshal_error 2844 } 2845 2846 term, err := unmarshalTerm(t) 2847 if err != nil { 2848 goto unmarshal_error 2849 } 2850 2851 b, ok := m["body"].([]interface{}) 2852 if !ok { 2853 goto unmarshal_error 2854 } 2855 2856 body, err := unmarshalBody(b) 2857 if err != nil { 2858 goto unmarshal_error 2859 } 2860 2861 if d["type"] == "arraycomprehension" { 2862 return &ArrayComprehension{Term: term, Body: body}, nil 2863 } 2864 return &SetComprehension{Term: term, Body: body}, nil 2865 } 2866 case "objectcomprehension": 2867 if m, ok := v.(map[string]interface{}); ok { 2868 k, ok := m["key"].(map[string]interface{}) 2869 if !ok { 2870 goto unmarshal_error 2871 } 2872 2873 key, err := unmarshalTerm(k) 2874 if err != nil { 2875 goto unmarshal_error 2876 } 2877 2878 v, ok := m["value"].(map[string]interface{}) 2879 if !ok { 2880 goto unmarshal_error 2881 } 2882 2883 value, err := unmarshalTerm(v) 2884 if err != nil { 2885 goto unmarshal_error 2886 } 2887 2888 b, ok := m["body"].([]interface{}) 2889 if !ok { 2890 goto unmarshal_error 2891 } 2892 2893 body, err := unmarshalBody(b) 2894 if err != nil { 2895 goto unmarshal_error 2896 } 2897 2898 return &ObjectComprehension{Key: key, Value: value, Body: body}, nil 2899 } 2900 case "call": 2901 if s, err := unmarshalTermSliceValue(d); err == nil { 2902 return Call(s), nil 2903 } 2904 } 2905 unmarshal_error: 2906 return nil, fmt.Errorf("ast: unable to unmarshal term") 2907 }