honnef.co/go/tools@v0.5.0-0.dev.0.20240520180541-dcae280a5e87/pattern/match.go (about) 1 package pattern 2 3 import ( 4 "fmt" 5 "go/ast" 6 "go/token" 7 "go/types" 8 "reflect" 9 10 "golang.org/x/tools/go/ast/astutil" 11 ) 12 13 var tokensByString = map[string]Token{ 14 "INT": Token(token.INT), 15 "FLOAT": Token(token.FLOAT), 16 "IMAG": Token(token.IMAG), 17 "CHAR": Token(token.CHAR), 18 "STRING": Token(token.STRING), 19 "+": Token(token.ADD), 20 "-": Token(token.SUB), 21 "*": Token(token.MUL), 22 "/": Token(token.QUO), 23 "%": Token(token.REM), 24 "&": Token(token.AND), 25 "|": Token(token.OR), 26 "^": Token(token.XOR), 27 "<<": Token(token.SHL), 28 ">>": Token(token.SHR), 29 "&^": Token(token.AND_NOT), 30 "+=": Token(token.ADD_ASSIGN), 31 "-=": Token(token.SUB_ASSIGN), 32 "*=": Token(token.MUL_ASSIGN), 33 "/=": Token(token.QUO_ASSIGN), 34 "%=": Token(token.REM_ASSIGN), 35 "&=": Token(token.AND_ASSIGN), 36 "|=": Token(token.OR_ASSIGN), 37 "^=": Token(token.XOR_ASSIGN), 38 "<<=": Token(token.SHL_ASSIGN), 39 ">>=": Token(token.SHR_ASSIGN), 40 "&^=": Token(token.AND_NOT_ASSIGN), 41 "&&": Token(token.LAND), 42 "||": Token(token.LOR), 43 "<-": Token(token.ARROW), 44 "++": Token(token.INC), 45 "--": Token(token.DEC), 46 "==": Token(token.EQL), 47 "<": Token(token.LSS), 48 ">": Token(token.GTR), 49 "=": Token(token.ASSIGN), 50 "!": Token(token.NOT), 51 "!=": Token(token.NEQ), 52 "<=": Token(token.LEQ), 53 ">=": Token(token.GEQ), 54 ":=": Token(token.DEFINE), 55 "...": Token(token.ELLIPSIS), 56 "IMPORT": Token(token.IMPORT), 57 "VAR": Token(token.VAR), 58 "TYPE": Token(token.TYPE), 59 "CONST": Token(token.CONST), 60 "BREAK": Token(token.BREAK), 61 "CONTINUE": Token(token.CONTINUE), 62 "GOTO": Token(token.GOTO), 63 "FALLTHROUGH": Token(token.FALLTHROUGH), 64 } 65 66 func maybeToken(node Node) (Node, bool) { 67 if node, ok := node.(String); ok { 68 if tok, ok := tokensByString[string(node)]; ok { 69 return tok, true 70 } 71 return node, false 72 } 73 return node, false 74 } 75 76 func isNil(v interface{}) bool { 77 if v == nil { 78 return true 79 } 80 if _, ok := v.(Nil); ok { 81 return true 82 } 83 return false 84 } 85 86 type matcher interface { 87 Match(*Matcher, interface{}) (interface{}, bool) 88 } 89 90 type State = map[string]any 91 92 type Matcher struct { 93 TypesInfo *types.Info 94 State State 95 96 bindingsMapping []string 97 98 setBindings []uint64 99 } 100 101 func (m *Matcher) set(b Binding, value interface{}) { 102 m.State[b.Name] = value 103 m.setBindings[len(m.setBindings)-1] |= 1 << b.idx 104 } 105 106 func (m *Matcher) push() { 107 m.setBindings = append(m.setBindings, 0) 108 } 109 110 func (m *Matcher) pop() { 111 set := m.setBindings[len(m.setBindings)-1] 112 if set != 0 { 113 for i := 0; i < len(m.bindingsMapping); i++ { 114 if (set & (1 << i)) != 0 { 115 key := m.bindingsMapping[i] 116 delete(m.State, key) 117 } 118 } 119 } 120 m.setBindings = m.setBindings[:len(m.setBindings)-1] 121 } 122 123 func (m *Matcher) merge() { 124 m.setBindings = m.setBindings[:len(m.setBindings)-1] 125 } 126 127 func (m *Matcher) Match(a Pattern, b ast.Node) bool { 128 m.bindingsMapping = a.Bindings 129 m.State = State{} 130 m.push() 131 _, ok := match(m, a.Root, b) 132 m.merge() 133 if len(m.setBindings) != 0 { 134 panic(fmt.Sprintf("%d entries left on the stack, expected none", len(m.setBindings))) 135 } 136 return ok 137 } 138 139 func Match(a Pattern, b ast.Node) (*Matcher, bool) { 140 m := &Matcher{} 141 ret := m.Match(a, b) 142 return m, ret 143 } 144 145 // Match two items, which may be (Node, AST) or (AST, AST) 146 func match(m *Matcher, l, r interface{}) (interface{}, bool) { 147 if _, ok := r.(Node); ok { 148 panic("Node mustn't be on right side of match") 149 } 150 151 switch l := l.(type) { 152 case *ast.ParenExpr: 153 return match(m, l.X, r) 154 case *ast.ExprStmt: 155 return match(m, l.X, r) 156 case *ast.DeclStmt: 157 return match(m, l.Decl, r) 158 case *ast.LabeledStmt: 159 return match(m, l.Stmt, r) 160 case *ast.BlockStmt: 161 return match(m, l.List, r) 162 case *ast.FieldList: 163 if l == nil { 164 return match(m, nil, r) 165 } else { 166 return match(m, l.List, r) 167 } 168 } 169 170 switch r := r.(type) { 171 case *ast.ParenExpr: 172 return match(m, l, r.X) 173 case *ast.ExprStmt: 174 return match(m, l, r.X) 175 case *ast.DeclStmt: 176 return match(m, l, r.Decl) 177 case *ast.LabeledStmt: 178 return match(m, l, r.Stmt) 179 case *ast.BlockStmt: 180 if r == nil { 181 return match(m, l, nil) 182 } 183 return match(m, l, r.List) 184 case *ast.FieldList: 185 if r == nil { 186 return match(m, l, nil) 187 } 188 return match(m, l, r.List) 189 case *ast.BasicLit: 190 if r == nil { 191 return match(m, l, nil) 192 } 193 } 194 195 if l, ok := l.(matcher); ok { 196 return l.Match(m, r) 197 } 198 199 if l, ok := l.(Node); ok { 200 // Matching of pattern with concrete value 201 return matchNodeAST(m, l, r) 202 } 203 204 if l == nil || r == nil { 205 return nil, l == r 206 } 207 208 { 209 ln, ok1 := l.(ast.Node) 210 rn, ok2 := r.(ast.Node) 211 if ok1 && ok2 { 212 return matchAST(m, ln, rn) 213 } 214 } 215 216 { 217 obj, ok := l.(types.Object) 218 if ok { 219 switch r := r.(type) { 220 case *ast.Ident: 221 return obj, obj == m.TypesInfo.ObjectOf(r) 222 case *ast.SelectorExpr: 223 return obj, obj == m.TypesInfo.ObjectOf(r.Sel) 224 default: 225 return obj, false 226 } 227 } 228 } 229 230 // TODO(dh): the three blocks handling slices can be combined into a single block if we use reflection 231 232 { 233 ln, ok1 := l.([]ast.Expr) 234 rn, ok2 := r.([]ast.Expr) 235 if ok1 || ok2 { 236 if ok1 && !ok2 { 237 cast, ok := r.(ast.Expr) 238 if !ok { 239 return nil, false 240 } 241 rn = []ast.Expr{cast} 242 } else if !ok1 && ok2 { 243 cast, ok := l.(ast.Expr) 244 if !ok { 245 return nil, false 246 } 247 ln = []ast.Expr{cast} 248 } 249 250 if len(ln) != len(rn) { 251 return nil, false 252 } 253 for i, ll := range ln { 254 if _, ok := match(m, ll, rn[i]); !ok { 255 return nil, false 256 } 257 } 258 return r, true 259 } 260 } 261 262 { 263 ln, ok1 := l.([]ast.Stmt) 264 rn, ok2 := r.([]ast.Stmt) 265 if ok1 || ok2 { 266 if ok1 && !ok2 { 267 cast, ok := r.(ast.Stmt) 268 if !ok { 269 return nil, false 270 } 271 rn = []ast.Stmt{cast} 272 } else if !ok1 && ok2 { 273 cast, ok := l.(ast.Stmt) 274 if !ok { 275 return nil, false 276 } 277 ln = []ast.Stmt{cast} 278 } 279 280 if len(ln) != len(rn) { 281 return nil, false 282 } 283 for i, ll := range ln { 284 if _, ok := match(m, ll, rn[i]); !ok { 285 return nil, false 286 } 287 } 288 return r, true 289 } 290 } 291 292 { 293 ln, ok1 := l.([]*ast.Field) 294 rn, ok2 := r.([]*ast.Field) 295 if ok1 || ok2 { 296 if ok1 && !ok2 { 297 cast, ok := r.(*ast.Field) 298 if !ok { 299 return nil, false 300 } 301 rn = []*ast.Field{cast} 302 } else if !ok1 && ok2 { 303 cast, ok := l.(*ast.Field) 304 if !ok { 305 return nil, false 306 } 307 ln = []*ast.Field{cast} 308 } 309 310 if len(ln) != len(rn) { 311 return nil, false 312 } 313 for i, ll := range ln { 314 if _, ok := match(m, ll, rn[i]); !ok { 315 return nil, false 316 } 317 } 318 return r, true 319 } 320 } 321 322 return nil, false 323 } 324 325 // Match a Node with an AST node 326 func matchNodeAST(m *Matcher, a Node, b interface{}) (interface{}, bool) { 327 switch b := b.(type) { 328 case []ast.Stmt: 329 // 'a' is not a List or we'd be using its Match 330 // implementation. 331 332 if len(b) != 1 { 333 return nil, false 334 } 335 return match(m, a, b[0]) 336 case []ast.Expr: 337 // 'a' is not a List or we'd be using its Match 338 // implementation. 339 340 if len(b) != 1 { 341 return nil, false 342 } 343 return match(m, a, b[0]) 344 case []*ast.Field: 345 // 'a' is not a List or we'd be using its Match 346 // implementation 347 if len(b) != 1 { 348 return nil, false 349 } 350 return match(m, a, b[0]) 351 case ast.Node: 352 ra := reflect.ValueOf(a) 353 rb := reflect.ValueOf(b).Elem() 354 355 if ra.Type().Name() != rb.Type().Name() { 356 return nil, false 357 } 358 359 for i := 0; i < ra.NumField(); i++ { 360 af := ra.Field(i) 361 fieldName := ra.Type().Field(i).Name 362 bf := rb.FieldByName(fieldName) 363 if (bf == reflect.Value{}) { 364 panic(fmt.Sprintf("internal error: could not find field %s in type %t when comparing with %T", fieldName, b, a)) 365 } 366 ai := af.Interface() 367 bi := bf.Interface() 368 if ai == nil { 369 return b, bi == nil 370 } 371 if _, ok := match(m, ai.(Node), bi); !ok { 372 return b, false 373 } 374 } 375 return b, true 376 case nil: 377 return nil, a == Nil{} 378 case string, token.Token: 379 // 'a' can't be a String, Token, or Binding or we'd be using their Match implementations. 380 return nil, false 381 default: 382 panic(fmt.Sprintf("unhandled type %T", b)) 383 } 384 } 385 386 // Match two AST nodes 387 func matchAST(m *Matcher, a, b ast.Node) (interface{}, bool) { 388 ra := reflect.ValueOf(a) 389 rb := reflect.ValueOf(b) 390 391 if ra.Type() != rb.Type() { 392 return nil, false 393 } 394 if ra.IsNil() || rb.IsNil() { 395 return rb, ra.IsNil() == rb.IsNil() 396 } 397 398 ra = ra.Elem() 399 rb = rb.Elem() 400 for i := 0; i < ra.NumField(); i++ { 401 af := ra.Field(i) 402 bf := rb.Field(i) 403 if af.Type() == rtTokPos || af.Type() == rtObject || af.Type() == rtCommentGroup { 404 continue 405 } 406 407 switch af.Kind() { 408 case reflect.Slice: 409 if af.Len() != bf.Len() { 410 return nil, false 411 } 412 for j := 0; j < af.Len(); j++ { 413 if _, ok := match(m, af.Index(j).Interface().(ast.Node), bf.Index(j).Interface().(ast.Node)); !ok { 414 return nil, false 415 } 416 } 417 case reflect.String: 418 if af.String() != bf.String() { 419 return nil, false 420 } 421 case reflect.Int: 422 if af.Int() != bf.Int() { 423 return nil, false 424 } 425 case reflect.Bool: 426 if af.Bool() != bf.Bool() { 427 return nil, false 428 } 429 case reflect.Ptr, reflect.Interface: 430 if _, ok := match(m, af.Interface(), bf.Interface()); !ok { 431 return nil, false 432 } 433 default: 434 panic(fmt.Sprintf("internal error: unhandled kind %s (%T)", af.Kind(), af.Interface())) 435 } 436 } 437 return b, true 438 } 439 440 func (b Binding) Match(m *Matcher, node interface{}) (interface{}, bool) { 441 if isNil(b.Node) { 442 v, ok := m.State[b.Name] 443 if ok { 444 // Recall value 445 return match(m, v, node) 446 } 447 // Matching anything 448 b.Node = Any{} 449 } 450 451 // Store value 452 if _, ok := m.State[b.Name]; ok { 453 panic(fmt.Sprintf("binding already created: %s", b.Name)) 454 } 455 new, ret := match(m, b.Node, node) 456 if ret { 457 m.set(b, new) 458 } 459 return new, ret 460 } 461 462 func (Any) Match(m *Matcher, node interface{}) (interface{}, bool) { 463 return node, true 464 } 465 466 func (l List) Match(m *Matcher, node interface{}) (interface{}, bool) { 467 v := reflect.ValueOf(node) 468 if v.Kind() == reflect.Slice { 469 if isNil(l.Head) { 470 return node, v.Len() == 0 471 } 472 if v.Len() == 0 { 473 return nil, false 474 } 475 // OPT(dh): don't check the entire tail if head didn't match 476 _, ok1 := match(m, l.Head, v.Index(0).Interface()) 477 _, ok2 := match(m, l.Tail, v.Slice(1, v.Len()).Interface()) 478 return node, ok1 && ok2 479 } 480 // Our empty list does not equal an untyped Go nil. This way, we can 481 // tell apart an if with no else and an if with an empty else. 482 return nil, false 483 } 484 485 func (s String) Match(m *Matcher, node interface{}) (interface{}, bool) { 486 switch o := node.(type) { 487 case token.Token: 488 if tok, ok := maybeToken(s); ok { 489 return match(m, tok, node) 490 } 491 return nil, false 492 case string: 493 return o, string(s) == o 494 case types.TypeAndValue: 495 return o, o.Value != nil && o.Value.String() == string(s) 496 default: 497 return nil, false 498 } 499 } 500 501 func (tok Token) Match(m *Matcher, node interface{}) (interface{}, bool) { 502 o, ok := node.(token.Token) 503 if !ok { 504 return nil, false 505 } 506 return o, token.Token(tok) == o 507 } 508 509 func (Nil) Match(m *Matcher, node interface{}) (interface{}, bool) { 510 if isNil(node) { 511 return nil, true 512 } 513 v := reflect.ValueOf(node) 514 switch v.Kind() { 515 case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice: 516 return nil, v.IsNil() 517 default: 518 return nil, false 519 } 520 } 521 522 func (builtin Builtin) Match(m *Matcher, node interface{}) (interface{}, bool) { 523 r, ok := match(m, Ident(builtin), node) 524 if !ok { 525 return nil, false 526 } 527 ident := r.(*ast.Ident) 528 obj := m.TypesInfo.ObjectOf(ident) 529 if obj != types.Universe.Lookup(ident.Name) { 530 return nil, false 531 } 532 return ident, true 533 } 534 535 func (obj Object) Match(m *Matcher, node interface{}) (interface{}, bool) { 536 r, ok := match(m, Ident(obj), node) 537 if !ok { 538 return nil, false 539 } 540 ident := r.(*ast.Ident) 541 542 id := m.TypesInfo.ObjectOf(ident) 543 _, ok = match(m, obj.Name, ident.Name) 544 return id, ok 545 } 546 547 func (fn Symbol) Match(m *Matcher, node interface{}) (interface{}, bool) { 548 var name string 549 var obj types.Object 550 551 base := []Node{ 552 Ident{Any{}}, 553 SelectorExpr{Any{}, Any{}}, 554 } 555 p := Or{ 556 Nodes: append(base, 557 IndexExpr{Or{Nodes: base}, Any{}}, 558 IndexListExpr{Or{Nodes: base}, Any{}})} 559 560 r, ok := match(m, p, node) 561 if !ok { 562 return nil, false 563 } 564 565 fun := r.(ast.Expr) 566 switch idx := fun.(type) { 567 case *ast.IndexExpr: 568 fun = idx.X 569 case *ast.IndexListExpr: 570 fun = idx.X 571 } 572 fun = astutil.Unparen(fun) 573 574 switch fun := fun.(type) { 575 case *ast.Ident: 576 obj = m.TypesInfo.ObjectOf(fun) 577 case *ast.SelectorExpr: 578 obj = m.TypesInfo.ObjectOf(fun.Sel) 579 default: 580 panic("unreachable") 581 } 582 switch obj := obj.(type) { 583 case *types.Func: 584 // OPT(dh): optimize this similar to code.FuncName 585 name = obj.FullName() 586 case *types.Builtin: 587 name = obj.Name() 588 case *types.TypeName: 589 origObj := obj 590 for { 591 if obj.Parent() != obj.Pkg().Scope() { 592 return nil, false 593 } 594 name = types.TypeString(obj.Type(), nil) 595 _, ok = match(m, fn.Name, name) 596 if ok || !obj.IsAlias() { 597 return origObj, ok 598 } else { 599 // FIXME(dh): we should peel away one layer of alias at a time; this is blocked on 600 // github.com/golang/go/issues/66559 601 switch typ := types.Unalias(obj.Type()).(type) { 602 case interface{ Obj() *types.TypeName }: 603 obj = typ.Obj() 604 case *types.Basic: 605 return match(m, fn.Name, typ.Name()) 606 default: 607 return nil, false 608 } 609 } 610 } 611 case *types.Const, *types.Var: 612 if obj.Pkg() == nil { 613 return nil, false 614 } 615 if obj.Parent() != obj.Pkg().Scope() { 616 return nil, false 617 } 618 name = fmt.Sprintf("%s.%s", obj.Pkg().Path(), obj.Name()) 619 default: 620 return nil, false 621 } 622 623 _, ok = match(m, fn.Name, name) 624 return obj, ok 625 } 626 627 func (or Or) Match(m *Matcher, node interface{}) (interface{}, bool) { 628 for _, opt := range or.Nodes { 629 m.push() 630 if ret, ok := match(m, opt, node); ok { 631 m.merge() 632 return ret, true 633 } else { 634 m.pop() 635 } 636 } 637 return nil, false 638 } 639 640 func (not Not) Match(m *Matcher, node interface{}) (interface{}, bool) { 641 _, ok := match(m, not.Node, node) 642 if ok { 643 return nil, false 644 } 645 return node, true 646 } 647 648 var integerLiteralQ = MustParse(`(Or (BasicLit "INT" _) (UnaryExpr (Or "+" "-") (IntegerLiteral _)))`) 649 650 func (lit IntegerLiteral) Match(m *Matcher, node interface{}) (interface{}, bool) { 651 matched, ok := match(m, integerLiteralQ.Root, node) 652 if !ok { 653 return nil, false 654 } 655 tv, ok := m.TypesInfo.Types[matched.(ast.Expr)] 656 if !ok { 657 return nil, false 658 } 659 if tv.Value == nil { 660 return nil, false 661 } 662 _, ok = match(m, lit.Value, tv) 663 return matched, ok 664 } 665 666 func (texpr TrulyConstantExpression) Match(m *Matcher, node interface{}) (interface{}, bool) { 667 expr, ok := node.(ast.Expr) 668 if !ok { 669 return nil, false 670 } 671 tv, ok := m.TypesInfo.Types[expr] 672 if !ok { 673 return nil, false 674 } 675 if tv.Value == nil { 676 return nil, false 677 } 678 truly := true 679 ast.Inspect(expr, func(node ast.Node) bool { 680 if _, ok := node.(*ast.Ident); ok { 681 truly = false 682 return false 683 } 684 return true 685 }) 686 if !truly { 687 return nil, false 688 } 689 _, ok = match(m, texpr.Value, tv) 690 return expr, ok 691 } 692 693 var ( 694 // Types of fields in go/ast structs that we want to skip 695 rtTokPos = reflect.TypeOf(token.Pos(0)) 696 rtObject = reflect.TypeOf((*ast.Object)(nil)) 697 rtCommentGroup = reflect.TypeOf((*ast.CommentGroup)(nil)) 698 ) 699 700 var ( 701 _ matcher = Binding{} 702 _ matcher = Any{} 703 _ matcher = List{} 704 _ matcher = String("") 705 _ matcher = Token(0) 706 _ matcher = Nil{} 707 _ matcher = Builtin{} 708 _ matcher = Object{} 709 _ matcher = Symbol{} 710 _ matcher = Or{} 711 _ matcher = Not{} 712 _ matcher = IntegerLiteral{} 713 _ matcher = TrulyConstantExpression{} 714 )