github.com/amarpal/go-tools@v0.0.0-20240422043104-40142f59f616/pattern/match.go (about) 1 package pattern 2 3 import ( 4 "fmt" 5 "go/ast" 6 "go/token" 7 "go/types" 8 "reflect" 9 10 "golang.org/x/tools/go/ast/astutil" 11 ) 12 13 var tokensByString = map[string]Token{ 14 "INT": Token(token.INT), 15 "FLOAT": Token(token.FLOAT), 16 "IMAG": Token(token.IMAG), 17 "CHAR": Token(token.CHAR), 18 "STRING": Token(token.STRING), 19 "+": Token(token.ADD), 20 "-": Token(token.SUB), 21 "*": Token(token.MUL), 22 "/": Token(token.QUO), 23 "%": Token(token.REM), 24 "&": Token(token.AND), 25 "|": Token(token.OR), 26 "^": Token(token.XOR), 27 "<<": Token(token.SHL), 28 ">>": Token(token.SHR), 29 "&^": Token(token.AND_NOT), 30 "+=": Token(token.ADD_ASSIGN), 31 "-=": Token(token.SUB_ASSIGN), 32 "*=": Token(token.MUL_ASSIGN), 33 "/=": Token(token.QUO_ASSIGN), 34 "%=": Token(token.REM_ASSIGN), 35 "&=": Token(token.AND_ASSIGN), 36 "|=": Token(token.OR_ASSIGN), 37 "^=": Token(token.XOR_ASSIGN), 38 "<<=": Token(token.SHL_ASSIGN), 39 ">>=": Token(token.SHR_ASSIGN), 40 "&^=": Token(token.AND_NOT_ASSIGN), 41 "&&": Token(token.LAND), 42 "||": Token(token.LOR), 43 "<-": Token(token.ARROW), 44 "++": Token(token.INC), 45 "--": Token(token.DEC), 46 "==": Token(token.EQL), 47 "<": Token(token.LSS), 48 ">": Token(token.GTR), 49 "=": Token(token.ASSIGN), 50 "!": Token(token.NOT), 51 "!=": Token(token.NEQ), 52 "<=": Token(token.LEQ), 53 ">=": Token(token.GEQ), 54 ":=": Token(token.DEFINE), 55 "...": Token(token.ELLIPSIS), 56 "IMPORT": Token(token.IMPORT), 57 "VAR": Token(token.VAR), 58 "TYPE": Token(token.TYPE), 59 "CONST": Token(token.CONST), 60 "BREAK": Token(token.BREAK), 61 "CONTINUE": Token(token.CONTINUE), 62 "GOTO": Token(token.GOTO), 63 "FALLTHROUGH": Token(token.FALLTHROUGH), 64 } 65 66 func maybeToken(node Node) (Node, bool) { 67 if node, ok := node.(String); ok { 68 if tok, ok := tokensByString[string(node)]; ok { 69 return tok, true 70 } 71 return node, false 72 } 73 return node, false 74 } 75 76 func isNil(v interface{}) bool { 77 if v == nil { 78 return true 79 } 80 if _, ok := v.(Nil); ok { 81 return true 82 } 83 return false 84 } 85 86 type matcher interface { 87 Match(*Matcher, interface{}) (interface{}, bool) 88 } 89 90 type State = map[string]any 91 92 type Matcher struct { 93 TypesInfo *types.Info 94 State State 95 96 bindingsMapping []string 97 98 setBindings []uint64 99 } 100 101 func (m *Matcher) set(b Binding, value interface{}) { 102 m.State[b.Name] = value 103 m.setBindings[len(m.setBindings)-1] |= 1 << b.idx 104 } 105 106 func (m *Matcher) push() { 107 m.setBindings = append(m.setBindings, 0) 108 } 109 110 func (m *Matcher) pop() { 111 set := m.setBindings[len(m.setBindings)-1] 112 if set != 0 { 113 for i := 0; i < len(m.bindingsMapping); i++ { 114 if (set & (1 << i)) != 0 { 115 key := m.bindingsMapping[i] 116 delete(m.State, key) 117 } 118 } 119 } 120 m.setBindings = m.setBindings[:len(m.setBindings)-1] 121 } 122 123 func (m *Matcher) merge() { 124 m.setBindings = m.setBindings[:len(m.setBindings)-1] 125 } 126 127 func (m *Matcher) Match(a Pattern, b ast.Node) bool { 128 m.bindingsMapping = a.Bindings 129 m.State = State{} 130 m.push() 131 _, ok := match(m, a.Root, b) 132 m.merge() 133 if len(m.setBindings) != 0 { 134 panic(fmt.Sprintf("%d entries left on the stack, expected none", len(m.setBindings))) 135 } 136 return ok 137 } 138 139 func Match(a Pattern, b ast.Node) (*Matcher, bool) { 140 m := &Matcher{} 141 ret := m.Match(a, b) 142 return m, ret 143 } 144 145 // Match two items, which may be (Node, AST) or (AST, AST) 146 func match(m *Matcher, l, r interface{}) (interface{}, bool) { 147 if _, ok := r.(Node); ok { 148 panic("Node mustn't be on right side of match") 149 } 150 151 switch l := l.(type) { 152 case *ast.ParenExpr: 153 return match(m, l.X, r) 154 case *ast.ExprStmt: 155 return match(m, l.X, r) 156 case *ast.DeclStmt: 157 return match(m, l.Decl, r) 158 case *ast.LabeledStmt: 159 return match(m, l.Stmt, r) 160 case *ast.BlockStmt: 161 return match(m, l.List, r) 162 case *ast.FieldList: 163 if l == nil { 164 return match(m, nil, r) 165 } else { 166 return match(m, l.List, r) 167 } 168 } 169 170 switch r := r.(type) { 171 case *ast.ParenExpr: 172 return match(m, l, r.X) 173 case *ast.ExprStmt: 174 return match(m, l, r.X) 175 case *ast.DeclStmt: 176 return match(m, l, r.Decl) 177 case *ast.LabeledStmt: 178 return match(m, l, r.Stmt) 179 case *ast.BlockStmt: 180 if r == nil { 181 return match(m, l, nil) 182 } 183 return match(m, l, r.List) 184 case *ast.FieldList: 185 if r == nil { 186 return match(m, l, nil) 187 } 188 return match(m, l, r.List) 189 case *ast.BasicLit: 190 if r == nil { 191 return match(m, l, nil) 192 } 193 } 194 195 if l, ok := l.(matcher); ok { 196 return l.Match(m, r) 197 } 198 199 if l, ok := l.(Node); ok { 200 // Matching of pattern with concrete value 201 return matchNodeAST(m, l, r) 202 } 203 204 if l == nil || r == nil { 205 return nil, l == r 206 } 207 208 { 209 ln, ok1 := l.(ast.Node) 210 rn, ok2 := r.(ast.Node) 211 if ok1 && ok2 { 212 return matchAST(m, ln, rn) 213 } 214 } 215 216 { 217 obj, ok := l.(types.Object) 218 if ok { 219 switch r := r.(type) { 220 case *ast.Ident: 221 return obj, obj == m.TypesInfo.ObjectOf(r) 222 case *ast.SelectorExpr: 223 return obj, obj == m.TypesInfo.ObjectOf(r.Sel) 224 default: 225 return obj, false 226 } 227 } 228 } 229 230 // TODO(dh): the three blocks handling slices can be combined into a single block if we use reflection 231 232 { 233 ln, ok1 := l.([]ast.Expr) 234 rn, ok2 := r.([]ast.Expr) 235 if ok1 || ok2 { 236 if ok1 && !ok2 { 237 cast, ok := r.(ast.Expr) 238 if !ok { 239 return nil, false 240 } 241 rn = []ast.Expr{cast} 242 } else if !ok1 && ok2 { 243 cast, ok := l.(ast.Expr) 244 if !ok { 245 return nil, false 246 } 247 ln = []ast.Expr{cast} 248 } 249 250 if len(ln) != len(rn) { 251 return nil, false 252 } 253 for i, ll := range ln { 254 if _, ok := match(m, ll, rn[i]); !ok { 255 return nil, false 256 } 257 } 258 return r, true 259 } 260 } 261 262 { 263 ln, ok1 := l.([]ast.Stmt) 264 rn, ok2 := r.([]ast.Stmt) 265 if ok1 || ok2 { 266 if ok1 && !ok2 { 267 cast, ok := r.(ast.Stmt) 268 if !ok { 269 return nil, false 270 } 271 rn = []ast.Stmt{cast} 272 } else if !ok1 && ok2 { 273 cast, ok := l.(ast.Stmt) 274 if !ok { 275 return nil, false 276 } 277 ln = []ast.Stmt{cast} 278 } 279 280 if len(ln) != len(rn) { 281 return nil, false 282 } 283 for i, ll := range ln { 284 if _, ok := match(m, ll, rn[i]); !ok { 285 return nil, false 286 } 287 } 288 return r, true 289 } 290 } 291 292 { 293 ln, ok1 := l.([]*ast.Field) 294 rn, ok2 := r.([]*ast.Field) 295 if ok1 || ok2 { 296 if ok1 && !ok2 { 297 cast, ok := r.(*ast.Field) 298 if !ok { 299 return nil, false 300 } 301 rn = []*ast.Field{cast} 302 } else if !ok1 && ok2 { 303 cast, ok := l.(*ast.Field) 304 if !ok { 305 return nil, false 306 } 307 ln = []*ast.Field{cast} 308 } 309 310 if len(ln) != len(rn) { 311 return nil, false 312 } 313 for i, ll := range ln { 314 if _, ok := match(m, ll, rn[i]); !ok { 315 return nil, false 316 } 317 } 318 return r, true 319 } 320 } 321 322 return nil, false 323 } 324 325 // Match a Node with an AST node 326 func matchNodeAST(m *Matcher, a Node, b interface{}) (interface{}, bool) { 327 switch b := b.(type) { 328 case []ast.Stmt: 329 // 'a' is not a List or we'd be using its Match 330 // implementation. 331 332 if len(b) != 1 { 333 return nil, false 334 } 335 return match(m, a, b[0]) 336 case []ast.Expr: 337 // 'a' is not a List or we'd be using its Match 338 // implementation. 339 340 if len(b) != 1 { 341 return nil, false 342 } 343 return match(m, a, b[0]) 344 case []*ast.Field: 345 // 'a' is not a List or we'd be using its Match 346 // implementation 347 if len(b) != 1 { 348 return nil, false 349 } 350 return match(m, a, b[0]) 351 case ast.Node: 352 ra := reflect.ValueOf(a) 353 rb := reflect.ValueOf(b).Elem() 354 355 if ra.Type().Name() != rb.Type().Name() { 356 return nil, false 357 } 358 359 for i := 0; i < ra.NumField(); i++ { 360 af := ra.Field(i) 361 fieldName := ra.Type().Field(i).Name 362 bf := rb.FieldByName(fieldName) 363 if (bf == reflect.Value{}) { 364 panic(fmt.Sprintf("internal error: could not find field %s in type %t when comparing with %T", fieldName, b, a)) 365 } 366 ai := af.Interface() 367 bi := bf.Interface() 368 if ai == nil { 369 return b, bi == nil 370 } 371 if _, ok := match(m, ai.(Node), bi); !ok { 372 return b, false 373 } 374 } 375 return b, true 376 case nil: 377 return nil, a == Nil{} 378 case string, token.Token: 379 // 'a' can't be a String, Token, or Binding or we'd be using their Match implementations. 380 return nil, false 381 default: 382 panic(fmt.Sprintf("unhandled type %T", b)) 383 } 384 } 385 386 // Match two AST nodes 387 func matchAST(m *Matcher, a, b ast.Node) (interface{}, bool) { 388 ra := reflect.ValueOf(a) 389 rb := reflect.ValueOf(b) 390 391 if ra.Type() != rb.Type() { 392 return nil, false 393 } 394 if ra.IsNil() || rb.IsNil() { 395 return rb, ra.IsNil() == rb.IsNil() 396 } 397 398 ra = ra.Elem() 399 rb = rb.Elem() 400 for i := 0; i < ra.NumField(); i++ { 401 af := ra.Field(i) 402 bf := rb.Field(i) 403 if af.Type() == rtTokPos || af.Type() == rtObject || af.Type() == rtCommentGroup { 404 continue 405 } 406 407 switch af.Kind() { 408 case reflect.Slice: 409 if af.Len() != bf.Len() { 410 return nil, false 411 } 412 for j := 0; j < af.Len(); j++ { 413 if _, ok := match(m, af.Index(j).Interface().(ast.Node), bf.Index(j).Interface().(ast.Node)); !ok { 414 return nil, false 415 } 416 } 417 case reflect.String: 418 if af.String() != bf.String() { 419 return nil, false 420 } 421 case reflect.Int: 422 if af.Int() != bf.Int() { 423 return nil, false 424 } 425 case reflect.Bool: 426 if af.Bool() != bf.Bool() { 427 return nil, false 428 } 429 case reflect.Ptr, reflect.Interface: 430 if _, ok := match(m, af.Interface(), bf.Interface()); !ok { 431 return nil, false 432 } 433 default: 434 panic(fmt.Sprintf("internal error: unhandled kind %s (%T)", af.Kind(), af.Interface())) 435 } 436 } 437 return b, true 438 } 439 440 func (b Binding) Match(m *Matcher, node interface{}) (interface{}, bool) { 441 if isNil(b.Node) { 442 v, ok := m.State[b.Name] 443 if ok { 444 // Recall value 445 return match(m, v, node) 446 } 447 // Matching anything 448 b.Node = Any{} 449 } 450 451 // Store value 452 if _, ok := m.State[b.Name]; ok { 453 panic(fmt.Sprintf("binding already created: %s", b.Name)) 454 } 455 new, ret := match(m, b.Node, node) 456 if ret { 457 m.set(b, new) 458 } 459 return new, ret 460 } 461 462 func (Any) Match(m *Matcher, node interface{}) (interface{}, bool) { 463 return node, true 464 } 465 466 func (l List) Match(m *Matcher, node interface{}) (interface{}, bool) { 467 v := reflect.ValueOf(node) 468 if v.Kind() == reflect.Slice { 469 if isNil(l.Head) { 470 return node, v.Len() == 0 471 } 472 if v.Len() == 0 { 473 return nil, false 474 } 475 // OPT(dh): don't check the entire tail if head didn't match 476 _, ok1 := match(m, l.Head, v.Index(0).Interface()) 477 _, ok2 := match(m, l.Tail, v.Slice(1, v.Len()).Interface()) 478 return node, ok1 && ok2 479 } 480 // Our empty list does not equal an untyped Go nil. This way, we can 481 // tell apart an if with no else and an if with an empty else. 482 return nil, false 483 } 484 485 func (s String) Match(m *Matcher, node interface{}) (interface{}, bool) { 486 switch o := node.(type) { 487 case token.Token: 488 if tok, ok := maybeToken(s); ok { 489 return match(m, tok, node) 490 } 491 return nil, false 492 case string: 493 return o, string(s) == o 494 case types.TypeAndValue: 495 return o, o.Value != nil && o.Value.String() == string(s) 496 default: 497 return nil, false 498 } 499 } 500 501 func (tok Token) Match(m *Matcher, node interface{}) (interface{}, bool) { 502 o, ok := node.(token.Token) 503 if !ok { 504 return nil, false 505 } 506 return o, token.Token(tok) == o 507 } 508 509 func (Nil) Match(m *Matcher, node interface{}) (interface{}, bool) { 510 if isNil(node) { 511 return nil, true 512 } 513 v := reflect.ValueOf(node) 514 switch v.Kind() { 515 case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice: 516 return nil, v.IsNil() 517 default: 518 return nil, false 519 } 520 } 521 522 func (builtin Builtin) Match(m *Matcher, node interface{}) (interface{}, bool) { 523 r, ok := match(m, Ident(builtin), node) 524 if !ok { 525 return nil, false 526 } 527 ident := r.(*ast.Ident) 528 obj := m.TypesInfo.ObjectOf(ident) 529 if obj != types.Universe.Lookup(ident.Name) { 530 return nil, false 531 } 532 return ident, true 533 } 534 535 func (obj Object) Match(m *Matcher, node interface{}) (interface{}, bool) { 536 r, ok := match(m, Ident(obj), node) 537 if !ok { 538 return nil, false 539 } 540 ident := r.(*ast.Ident) 541 542 id := m.TypesInfo.ObjectOf(ident) 543 _, ok = match(m, obj.Name, ident.Name) 544 return id, ok 545 } 546 547 func (fn Symbol) Match(m *Matcher, node interface{}) (interface{}, bool) { 548 var name string 549 var obj types.Object 550 551 base := []Node{ 552 Ident{Any{}}, 553 SelectorExpr{Any{}, Any{}}, 554 } 555 p := Or{ 556 Nodes: append(base, 557 IndexExpr{Or{Nodes: base}, Any{}}, 558 IndexListExpr{Or{Nodes: base}, Any{}})} 559 560 r, ok := match(m, p, node) 561 if !ok { 562 return nil, false 563 } 564 565 fun := r.(ast.Expr) 566 switch idx := fun.(type) { 567 case *ast.IndexExpr: 568 fun = idx.X 569 case *ast.IndexListExpr: 570 fun = idx.X 571 } 572 fun = astutil.Unparen(fun) 573 574 switch fun := fun.(type) { 575 case *ast.Ident: 576 obj = m.TypesInfo.ObjectOf(fun) 577 case *ast.SelectorExpr: 578 obj = m.TypesInfo.ObjectOf(fun.Sel) 579 default: 580 panic("unreachable") 581 } 582 switch obj := obj.(type) { 583 case *types.Func: 584 // OPT(dh): optimize this similar to code.FuncName 585 name = obj.FullName() 586 case *types.Builtin: 587 name = obj.Name() 588 case *types.TypeName: 589 if obj.Pkg() == nil { 590 return nil, false 591 } 592 if obj.Parent() != obj.Pkg().Scope() { 593 return nil, false 594 } 595 name = types.TypeString(obj.Type(), nil) 596 case *types.Const, *types.Var: 597 if obj.Pkg() == nil { 598 return nil, false 599 } 600 if obj.Parent() != obj.Pkg().Scope() { 601 return nil, false 602 } 603 name = fmt.Sprintf("%s.%s", obj.Pkg().Path(), obj.Name()) 604 default: 605 return nil, false 606 } 607 608 _, ok = match(m, fn.Name, name) 609 return obj, ok 610 } 611 612 func (or Or) Match(m *Matcher, node interface{}) (interface{}, bool) { 613 for _, opt := range or.Nodes { 614 m.push() 615 if ret, ok := match(m, opt, node); ok { 616 m.merge() 617 return ret, true 618 } else { 619 m.pop() 620 } 621 } 622 return nil, false 623 } 624 625 func (not Not) Match(m *Matcher, node interface{}) (interface{}, bool) { 626 _, ok := match(m, not.Node, node) 627 if ok { 628 return nil, false 629 } 630 return node, true 631 } 632 633 var integerLiteralQ = MustParse(`(Or (BasicLit "INT" _) (UnaryExpr (Or "+" "-") (IntegerLiteral _)))`) 634 635 func (lit IntegerLiteral) Match(m *Matcher, node interface{}) (interface{}, bool) { 636 matched, ok := match(m, integerLiteralQ.Root, node) 637 if !ok { 638 return nil, false 639 } 640 tv, ok := m.TypesInfo.Types[matched.(ast.Expr)] 641 if !ok { 642 return nil, false 643 } 644 if tv.Value == nil { 645 return nil, false 646 } 647 _, ok = match(m, lit.Value, tv) 648 return matched, ok 649 } 650 651 func (texpr TrulyConstantExpression) Match(m *Matcher, node interface{}) (interface{}, bool) { 652 expr, ok := node.(ast.Expr) 653 if !ok { 654 return nil, false 655 } 656 tv, ok := m.TypesInfo.Types[expr] 657 if !ok { 658 return nil, false 659 } 660 if tv.Value == nil { 661 return nil, false 662 } 663 truly := true 664 ast.Inspect(expr, func(node ast.Node) bool { 665 if _, ok := node.(*ast.Ident); ok { 666 truly = false 667 return false 668 } 669 return true 670 }) 671 if !truly { 672 return nil, false 673 } 674 _, ok = match(m, texpr.Value, tv) 675 return expr, ok 676 } 677 678 var ( 679 // Types of fields in go/ast structs that we want to skip 680 rtTokPos = reflect.TypeOf(token.Pos(0)) 681 rtObject = reflect.TypeOf((*ast.Object)(nil)) 682 rtCommentGroup = reflect.TypeOf((*ast.CommentGroup)(nil)) 683 ) 684 685 var ( 686 _ matcher = Binding{} 687 _ matcher = Any{} 688 _ matcher = List{} 689 _ matcher = String("") 690 _ matcher = Token(0) 691 _ matcher = Nil{} 692 _ matcher = Builtin{} 693 _ matcher = Object{} 694 _ matcher = Symbol{} 695 _ matcher = Or{} 696 _ matcher = Not{} 697 _ matcher = IntegerLiteral{} 698 _ matcher = TrulyConstantExpression{} 699 )