github.com/goshafaq/sonic@v0.0.0-20231026082336-871835fb94c6/ast/visitor_test.go (about) 1 /* 2 * Copyright 2021 ByteDance Inc. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package ast 18 19 import ( 20 "bufio" 21 "encoding/json" 22 "fmt" 23 "io" 24 "os" 25 "sort" 26 "strings" 27 "testing" 28 29 "github.com/stretchr/testify/assert" 30 "github.com/stretchr/testify/require" 31 ) 32 33 type visitorNodeDiffTest struct { 34 t *testing.T 35 str string 36 37 tracer io.Writer 38 39 cursor Node 40 stk visitorNodeStack 41 sp uint8 42 } 43 44 type visitorNodeStack = [256]struct { 45 Node Node 46 Object map[string]Node 47 Array []Node 48 49 ObjectKey string 50 } 51 52 func (self *visitorNodeDiffTest) incrSP() { 53 self.t.Helper() 54 self.sp++ 55 require.NotZero(self.t, self.sp, "stack overflow") 56 } 57 58 func (self *visitorNodeDiffTest) debugStack() string { 59 var buf strings.Builder 60 buf.WriteString("[") 61 for i := uint8(0); i < self.sp; i++ { 62 if i != 0 { 63 buf.WriteString(", ") 64 } 65 if self.stk[i].Array != nil { 66 buf.WriteString("Array") 67 } else if self.stk[i].Object != nil { 68 buf.WriteString("Object") 69 } else { 70 fmt.Fprintf(&buf, "Key(%q)", self.stk[i].ObjectKey) 71 } 72 } 73 buf.WriteString("]") 74 return buf.String() 75 } 76 77 func (self *visitorNodeDiffTest) requireType(got int) { 78 self.t.Helper() 79 want := self.cursor.Type() 80 require.EqualValues(self.t, want, got) 81 } 82 83 func (self *visitorNodeDiffTest) toArrayIndex(array Node, i int) { 84 // set cursor to next Value if existed 85 self.t.Helper() 86 n, err := array.Len() 87 require.NoError(self.t, err) 88 if i < n { 89 self.cursor = *array.Index(i) 90 require.NoError(self.t, self.cursor.Check()) 91 } 92 } 93 94 func (self *visitorNodeDiffTest) onValueEnd() { 95 if self.tracer != nil { 96 fmt.Fprintf(self.tracer, "OnValueEnd: %s\n", self.debugStack()) 97 } 98 // cursor should point to the Value now 99 self.t.Helper() 100 if self.sp == 0 { 101 if self.tracer != nil { 102 fmt.Fprintf(self.tracer, "EOF\n\n") 103 } 104 return 105 } 106 // [..., Array, sp] 107 if array := self.stk[self.sp-1].Array; array != nil { 108 array = append(array, self.cursor) 109 self.stk[self.sp-1].Array = array 110 self.toArrayIndex(self.stk[self.sp-1].Node, len(array)) 111 return 112 } 113 // [..., Object, ObjectKey, sp] 114 require.GreaterOrEqual(self.t, self.sp, uint8(2)) 115 require.NotNil(self.t, self.stk[self.sp-2].Object) 116 require.Nil(self.t, self.stk[self.sp-1].Object) 117 require.Nil(self.t, self.stk[self.sp-1].Array) 118 self.stk[self.sp-2].Object[self.stk[self.sp-1].ObjectKey] = self.cursor 119 self.cursor = self.stk[self.sp-2].Node // reset cursor to Object 120 self.sp-- // pop ObjectKey 121 } 122 123 func (self *visitorNodeDiffTest) OnNull() error { 124 if self.tracer != nil { 125 fmt.Fprintf(self.tracer, "OnNull\n") 126 } 127 self.requireType(V_NULL) 128 self.onValueEnd() 129 return nil 130 } 131 132 func (self *visitorNodeDiffTest) OnBool(v bool) error { 133 if self.tracer != nil { 134 fmt.Fprintf(self.tracer, "OnBool: %t\n", v) 135 } 136 if v { 137 self.requireType(V_TRUE) 138 } else { 139 self.requireType(V_FALSE) 140 } 141 self.onValueEnd() 142 return nil 143 } 144 145 func (self *visitorNodeDiffTest) OnString(v string) error { 146 if self.tracer != nil { 147 fmt.Fprintf(self.tracer, "OnString: %q\n", v) 148 } 149 self.requireType(V_STRING) 150 want, err := self.cursor.StrictString() 151 require.NoError(self.t, err) 152 require.EqualValues(self.t, want, v) 153 self.onValueEnd() 154 return nil 155 } 156 157 func (self *visitorNodeDiffTest) OnInt64(v int64, n json.Number) error { 158 if self.tracer != nil { 159 fmt.Fprintf(self.tracer, "OnInt64: %d (%q)\n", v, n) 160 } 161 self.requireType(V_NUMBER) 162 want, err := self.cursor.StrictInt64() 163 require.NoError(self.t, err) 164 require.EqualValues(self.t, want, v) 165 nv, err := n.Int64() 166 require.NoError(self.t, err) 167 require.EqualValues(self.t, want, nv) 168 self.onValueEnd() 169 return nil 170 } 171 172 func (self *visitorNodeDiffTest) OnFloat64(v float64, n json.Number) error { 173 if self.tracer != nil { 174 fmt.Fprintf(self.tracer, "OnFloat64: %f (%q)\n", v, n) 175 } 176 self.requireType(V_NUMBER) 177 want, err := self.cursor.StrictFloat64() 178 require.NoError(self.t, err) 179 require.EqualValues(self.t, want, v) 180 nv, err := n.Float64() 181 require.NoError(self.t, err) 182 require.EqualValues(self.t, want, nv) 183 self.onValueEnd() 184 return nil 185 } 186 187 func (self *visitorNodeDiffTest) OnObjectBegin(capacity int) error { 188 if self.tracer != nil { 189 fmt.Fprintf(self.tracer, "OnObjectBegin: %d\n", capacity) 190 } 191 self.requireType(V_OBJECT) 192 self.stk[self.sp].Node = self.cursor 193 self.stk[self.sp].Object = make(map[string]Node, capacity) 194 self.incrSP() 195 return nil 196 } 197 198 func (self *visitorNodeDiffTest) OnObjectKey(key string) error { 199 if self.tracer != nil { 200 fmt.Fprintf(self.tracer, "OnObjectKey: %q %s\n", key, self.debugStack()) 201 } 202 require.NotNil(self.t, self.stk[self.sp-1].Object) 203 node := self.stk[self.sp-1].Node 204 self.stk[self.sp].ObjectKey = key 205 self.incrSP() 206 self.cursor = *node.Get(key) // set cursor to Value 207 require.NoError(self.t, self.cursor.Check()) 208 return nil 209 } 210 211 func (self *visitorNodeDiffTest) OnObjectEnd() error { 212 if self.tracer != nil { 213 fmt.Fprintf(self.tracer, "OnObjectEnd\n") 214 } 215 object := self.stk[self.sp-1].Object 216 require.NotNil(self.t, object) 217 218 node := self.stk[self.sp-1].Node 219 ps, err := node.unsafeMap() 220 var pairs = make([]Pair, ps.Len()) 221 ps.ToSlice(pairs) 222 require.NoError(self.t, err) 223 224 keysGot := make([]string, 0, len(object)) 225 for key := range object { 226 keysGot = append(keysGot, key) 227 } 228 keysWant := make([]string, 0, len(pairs)) 229 for _, pair := range pairs { 230 keysWant = append(keysWant, pair.Key) 231 } 232 sort.Strings(keysGot) 233 sort.Strings(keysWant) 234 require.EqualValues(self.t, keysWant, keysGot) 235 236 for _, pair := range pairs { 237 typeGot := object[pair.Key].Type() 238 typeWant := pair.Value.Type() 239 require.EqualValues(self.t, typeWant, typeGot) 240 } 241 242 // pop Object 243 self.sp-- 244 self.stk[self.sp].Node = Node{} 245 self.stk[self.sp].Object = nil 246 247 self.cursor = node // set cursor to this Object 248 self.onValueEnd() 249 return nil 250 } 251 252 func (self *visitorNodeDiffTest) OnArrayBegin(capacity int) error { 253 if self.tracer != nil { 254 fmt.Fprintf(self.tracer, "OnArrayBegin: %d\n", capacity) 255 } 256 self.requireType(V_ARRAY) 257 self.stk[self.sp].Node = self.cursor 258 self.stk[self.sp].Array = make([]Node, 0, capacity) 259 self.incrSP() 260 self.toArrayIndex(self.stk[self.sp-1].Node, 0) 261 return nil 262 } 263 264 func (self *visitorNodeDiffTest) OnArrayEnd() error { 265 if self.tracer != nil { 266 fmt.Fprintf(self.tracer, "OnArrayEnd\n") 267 } 268 array := self.stk[self.sp-1].Array 269 require.NotNil(self.t, array) 270 271 node := self.stk[self.sp-1].Node 272 vs, err := node.unsafeArray() 273 require.NoError(self.t, err) 274 var values = make([]Node, vs.Len()) 275 vs.ToSlice(values) 276 277 require.EqualValues(self.t, len(values), len(array)) 278 279 for i, n := 0, len(values); i < n; i++ { 280 typeGot := array[i].Type() 281 typeWant := values[i].Type() 282 require.EqualValues(self.t, typeWant, typeGot) 283 } 284 285 // pop Array 286 self.sp-- 287 self.stk[self.sp].Node = Node{} 288 self.stk[self.sp].Array = nil 289 290 self.cursor = node // set cursor to this Array 291 self.onValueEnd() 292 return nil 293 } 294 295 func (self *visitorNodeDiffTest) Run(t *testing.T, str string, 296 tracer io.Writer) { 297 self.t = t 298 self.str = str 299 self.tracer = tracer 300 301 self.t.Helper() 302 303 self.cursor = NewRaw(self.str) 304 require.NoError(self.t, self.cursor.LoadAll()) 305 306 self.stk = visitorNodeStack{} 307 self.sp = 0 308 309 require.NoError(self.t, Preorder(self.str, self, nil)) 310 } 311 312 func TestVisitor_NodeDiff(t *testing.T) { 313 var suite visitorNodeDiffTest 314 315 newTracer := func(t *testing.T) io.Writer { 316 const EnableTracer = false 317 if !EnableTracer { 318 return nil 319 } 320 basename := strings.ReplaceAll(t.Name(), "/", "_") 321 fp, err := os.Create(fmt.Sprintf("../output/%s.log", basename)) 322 require.NoError(t, err) 323 writer := bufio.NewWriter(fp) 324 t.Cleanup(func() { 325 _ = writer.Flush() 326 _ = fp.Close() 327 }) 328 return writer 329 } 330 331 t.Run("default", func(t *testing.T) { 332 suite.Run(t, _TwitterJson, newTracer(t)) 333 }) 334 t.Run("issue_case01", func(t *testing.T) { 335 suite.Run(t, `[1193.6419677734375]`, newTracer(t)) 336 }) 337 } 338 339 type visitorUserNode interface { 340 UserNode() 341 } 342 343 type ( 344 visitorUserNull struct{} 345 visitorUserBool struct{ Value bool } 346 visitorUserInt64 struct{ Value int64 } 347 visitorUserFloat64 struct{ Value float64 } 348 visitorUserString struct{ Value string } 349 visitorUserObject struct{ Value map[string]visitorUserNode } 350 visitorUserArray struct{ Value []visitorUserNode } 351 ) 352 353 func (*visitorUserNull) UserNode() {} 354 func (*visitorUserBool) UserNode() {} 355 func (*visitorUserInt64) UserNode() {} 356 func (*visitorUserFloat64) UserNode() {} 357 func (*visitorUserString) UserNode() {} 358 func (*visitorUserObject) UserNode() {} 359 func (*visitorUserArray) UserNode() {} 360 361 func compareUserNode(tb testing.TB, lhs, rhs visitorUserNode) bool { 362 switch lhs := lhs.(type) { 363 case *visitorUserNull: 364 _, ok := rhs.(*visitorUserNull) 365 return assert.True(tb, ok) 366 case *visitorUserBool: 367 rhs, ok := rhs.(*visitorUserBool) 368 return assert.True(tb, ok) && assert.Equal(tb, lhs.Value, rhs.Value) 369 case *visitorUserInt64: 370 rhs, ok := rhs.(*visitorUserInt64) 371 return assert.True(tb, ok) && assert.Equal(tb, lhs.Value, rhs.Value) 372 case *visitorUserFloat64: 373 rhs, ok := rhs.(*visitorUserFloat64) 374 return assert.True(tb, ok) && assert.Equal(tb, lhs.Value, rhs.Value) 375 case *visitorUserString: 376 rhs, ok := rhs.(*visitorUserString) 377 return assert.True(tb, ok) && assert.Equal(tb, lhs.Value, rhs.Value) 378 case *visitorUserObject: 379 rhs, ok := rhs.(*visitorUserObject) 380 if !(assert.True(tb, ok) && assert.Equal(tb, len(lhs.Value), len(rhs.Value))) { 381 return false 382 } 383 for key, lhs := range lhs.Value { 384 rhs, ok := rhs.Value[key] 385 if !(assert.True(tb, ok) && assert.True(tb, compareUserNode(tb, lhs, rhs))) { 386 return false 387 } 388 } 389 return true 390 case *visitorUserArray: 391 rhs, ok := rhs.(*visitorUserArray) 392 if !(assert.True(tb, ok) && assert.Equal(tb, len(lhs.Value), len(rhs.Value))) { 393 return false 394 } 395 for i, n := 0, len(lhs.Value); i < n; i++ { 396 if !assert.True(tb, compareUserNode(tb, lhs.Value[i], rhs.Value[i])) { 397 return false 398 } 399 } 400 return true 401 default: 402 tb.Fatalf("unexpected type of UserNode: %T", lhs) 403 return false 404 } 405 } 406 407 type visitorUserNodeDecoder interface { 408 Reset() 409 Decode(str string) (visitorUserNode, error) 410 } 411 412 var _ visitorUserNodeDecoder = (*visitorUserNodeASTDecoder)(nil) 413 414 type visitorUserNodeASTDecoder struct{} 415 416 func (self *visitorUserNodeASTDecoder) Reset() {} 417 418 func (self *visitorUserNodeASTDecoder) Decode(str string) (visitorUserNode, error) { 419 root := NewRaw(str) 420 if err := root.LoadAll(); err != nil { 421 return nil, err 422 } 423 return self.decodeValue(&root) 424 } 425 426 func (self *visitorUserNodeASTDecoder) decodeValue(root *Node) (visitorUserNode, error) { 427 switch typ := root.Type(); typ { 428 // embed (*Node).Check 429 case V_NONE: 430 return nil, ErrNotExist 431 case V_ERROR: 432 return nil, root 433 434 case V_NULL: 435 return &visitorUserNull{}, nil 436 case V_TRUE: 437 return &visitorUserBool{Value: true}, nil 438 case V_FALSE: 439 return &visitorUserBool{Value: false}, nil 440 441 case V_STRING: 442 value, err := root.StrictString() 443 if err != nil { 444 return nil, err 445 } 446 return &visitorUserString{Value: value}, nil 447 448 case V_NUMBER: 449 value, err := root.StrictNumber() 450 if err != nil { 451 return nil, err 452 } 453 i64, ierr := value.Int64() 454 if ierr == nil { 455 return &visitorUserInt64{Value: i64}, nil 456 } 457 f64, ferr := value.Float64() 458 if ferr == nil { 459 return &visitorUserFloat64{Value: f64}, nil 460 } 461 return nil, fmt.Errorf("invalid number: %v, ierr: %v, ferr: %v", 462 value, ierr, ferr) 463 464 case V_ARRAY: 465 nodes, err := root.unsafeArray() 466 if err != nil { 467 return nil, err 468 } 469 values := make([]visitorUserNode, nodes.Len()) 470 for i := 0; i < nodes.Len(); i++ { 471 n := nodes.At(i) 472 value, err := self.decodeValue(n) 473 if err != nil { 474 return nil, err 475 } 476 values[i] = value 477 } 478 return &visitorUserArray{Value: values}, nil 479 480 case V_OBJECT: 481 pairs, err := root.unsafeMap() 482 if err != nil { 483 return nil, err 484 } 485 values := make(map[string]visitorUserNode, pairs.Len()) 486 for i := 0; i < pairs.Len(); i++ { 487 value, err := self.decodeValue(&pairs.At(i).Value) 488 if err != nil { 489 return nil, err 490 } 491 values[pairs.At(i).Key] = value 492 } 493 return &visitorUserObject{Value: values}, nil 494 495 case V_ANY: 496 fallthrough 497 default: 498 return nil, fmt.Errorf("unexpected Node type: %v", typ) 499 } 500 } 501 502 var _ visitorUserNodeDecoder = (*visitorUserNodeVisitorDecoder)(nil) 503 504 type visitorUserNodeVisitorDecoder struct { 505 stk visitorUserNodeStack 506 sp uint8 507 } 508 509 type visitorUserNodeStack = [256]struct { 510 val visitorUserNode 511 obj map[string]visitorUserNode 512 arr []visitorUserNode 513 key string 514 } 515 516 func (self *visitorUserNodeVisitorDecoder) Reset() { 517 self.stk = visitorUserNodeStack{} 518 self.sp = 0 519 } 520 521 func (self *visitorUserNodeVisitorDecoder) Decode(str string) (visitorUserNode, error) { 522 if err := Preorder(str, self, nil); err != nil { 523 return nil, err 524 } 525 return self.result() 526 } 527 528 func (self *visitorUserNodeVisitorDecoder) result() (visitorUserNode, error) { 529 if self.sp != 1 { 530 return nil, fmt.Errorf("incorrect sp: %d", self.sp) 531 } 532 return self.stk[0].val, nil 533 } 534 535 func (self *visitorUserNodeVisitorDecoder) incrSP() error { 536 self.sp++ 537 if self.sp == 0 { 538 return fmt.Errorf("reached max depth: %d", len(self.stk)) 539 } 540 return nil 541 } 542 543 func (self *visitorUserNodeVisitorDecoder) OnNull() error { 544 self.stk[self.sp].val = &visitorUserNull{} 545 if err := self.incrSP(); err != nil { 546 return err 547 } 548 return self.onValueEnd() 549 } 550 551 func (self *visitorUserNodeVisitorDecoder) OnBool(v bool) error { 552 self.stk[self.sp].val = &visitorUserBool{Value: v} 553 if err := self.incrSP(); err != nil { 554 return err 555 } 556 return self.onValueEnd() 557 } 558 559 func (self *visitorUserNodeVisitorDecoder) OnString(v string) error { 560 self.stk[self.sp].val = &visitorUserString{Value: v} 561 if err := self.incrSP(); err != nil { 562 return err 563 } 564 return self.onValueEnd() 565 } 566 567 func (self *visitorUserNodeVisitorDecoder) OnInt64(v int64, n json.Number) error { 568 self.stk[self.sp].val = &visitorUserInt64{Value: v} 569 if err := self.incrSP(); err != nil { 570 return err 571 } 572 return self.onValueEnd() 573 } 574 575 func (self *visitorUserNodeVisitorDecoder) OnFloat64(v float64, n json.Number) error { 576 self.stk[self.sp].val = &visitorUserFloat64{Value: v} 577 if err := self.incrSP(); err != nil { 578 return err 579 } 580 return self.onValueEnd() 581 } 582 583 func (self *visitorUserNodeVisitorDecoder) OnObjectBegin(capacity int) error { 584 self.stk[self.sp].obj = make(map[string]visitorUserNode, capacity) 585 return self.incrSP() 586 } 587 588 func (self *visitorUserNodeVisitorDecoder) OnObjectKey(key string) error { 589 self.stk[self.sp].key = key 590 return self.incrSP() 591 } 592 593 func (self *visitorUserNodeVisitorDecoder) OnObjectEnd() error { 594 self.stk[self.sp-1].val = &visitorUserObject{Value: self.stk[self.sp-1].obj} 595 self.stk[self.sp-1].obj = nil 596 return self.onValueEnd() 597 } 598 599 func (self *visitorUserNodeVisitorDecoder) OnArrayBegin(capacity int) error { 600 self.stk[self.sp].arr = make([]visitorUserNode, 0, capacity) 601 return self.incrSP() 602 } 603 604 func (self *visitorUserNodeVisitorDecoder) OnArrayEnd() error { 605 self.stk[self.sp-1].val = &visitorUserArray{Value: self.stk[self.sp-1].arr} 606 self.stk[self.sp-1].arr = nil 607 return self.onValueEnd() 608 } 609 610 func (self *visitorUserNodeVisitorDecoder) onValueEnd() error { 611 if self.sp == 1 { 612 return nil 613 } 614 // [..., Array, Value, sp] 615 if self.stk[self.sp-2].arr != nil { 616 self.stk[self.sp-2].arr = append(self.stk[self.sp-2].arr, self.stk[self.sp-1].val) 617 self.sp-- 618 return nil 619 } 620 // [..., Object, ObjectKey, Value, sp] 621 self.stk[self.sp-3].obj[self.stk[self.sp-2].key] = self.stk[self.sp-1].val 622 self.sp -= 2 623 return nil 624 } 625 626 func testUserNodeDiff(t *testing.T, d1, d2 visitorUserNodeDecoder, str string) { 627 t.Helper() 628 d1.Reset() 629 n1, err := d1.Decode(_TwitterJson) 630 require.NoError(t, err) 631 632 d2.Reset() 633 n2, err := d2.Decode(_TwitterJson) 634 require.NoError(t, err) 635 636 require.True(t, compareUserNode(t, n1, n2)) 637 } 638 639 func TestVisitor_UserNodeDiff(t *testing.T) { 640 var d1 visitorUserNodeASTDecoder 641 var d2 visitorUserNodeVisitorDecoder 642 643 t.Run("default", func(t *testing.T) { 644 testUserNodeDiff(t, &d1, &d2, _TwitterJson) 645 }) 646 t.Run("issue_case01", func(t *testing.T) { 647 testUserNodeDiff(t, &d1, &d2, `[1193.6419677734375]`) 648 }) 649 } 650 651 func BenchmarkVisitor_UserNode(b *testing.B) { 652 const str = _TwitterJson 653 b.Run("AST", func(b *testing.B) { 654 var d visitorUserNodeASTDecoder 655 b.ResetTimer() 656 for k := 0; k < b.N; k++ { 657 d.Reset() 658 _, err := d.Decode(str) 659 require.NoError(b, err) 660 b.SetBytes(int64(len(str))) 661 } 662 }) 663 b.Run("Visitor", func(b *testing.B) { 664 var d visitorUserNodeVisitorDecoder 665 b.ResetTimer() 666 for k := 0; k < b.N; k++ { 667 d.Reset() 668 _, err := d.Decode(str) 669 require.NoError(b, err) 670 b.SetBytes(int64(len(str))) 671 } 672 }) 673 }