github.com/amarpal/go-tools@v0.0.0-20240422043104-40142f59f616/pattern/match.go (about)

     1  package pattern
     2  
     3  import (
     4  	"fmt"
     5  	"go/ast"
     6  	"go/token"
     7  	"go/types"
     8  	"reflect"
     9  
    10  	"golang.org/x/tools/go/ast/astutil"
    11  )
    12  
    13  var tokensByString = map[string]Token{
    14  	"INT":         Token(token.INT),
    15  	"FLOAT":       Token(token.FLOAT),
    16  	"IMAG":        Token(token.IMAG),
    17  	"CHAR":        Token(token.CHAR),
    18  	"STRING":      Token(token.STRING),
    19  	"+":           Token(token.ADD),
    20  	"-":           Token(token.SUB),
    21  	"*":           Token(token.MUL),
    22  	"/":           Token(token.QUO),
    23  	"%":           Token(token.REM),
    24  	"&":           Token(token.AND),
    25  	"|":           Token(token.OR),
    26  	"^":           Token(token.XOR),
    27  	"<<":          Token(token.SHL),
    28  	">>":          Token(token.SHR),
    29  	"&^":          Token(token.AND_NOT),
    30  	"+=":          Token(token.ADD_ASSIGN),
    31  	"-=":          Token(token.SUB_ASSIGN),
    32  	"*=":          Token(token.MUL_ASSIGN),
    33  	"/=":          Token(token.QUO_ASSIGN),
    34  	"%=":          Token(token.REM_ASSIGN),
    35  	"&=":          Token(token.AND_ASSIGN),
    36  	"|=":          Token(token.OR_ASSIGN),
    37  	"^=":          Token(token.XOR_ASSIGN),
    38  	"<<=":         Token(token.SHL_ASSIGN),
    39  	">>=":         Token(token.SHR_ASSIGN),
    40  	"&^=":         Token(token.AND_NOT_ASSIGN),
    41  	"&&":          Token(token.LAND),
    42  	"||":          Token(token.LOR),
    43  	"<-":          Token(token.ARROW),
    44  	"++":          Token(token.INC),
    45  	"--":          Token(token.DEC),
    46  	"==":          Token(token.EQL),
    47  	"<":           Token(token.LSS),
    48  	">":           Token(token.GTR),
    49  	"=":           Token(token.ASSIGN),
    50  	"!":           Token(token.NOT),
    51  	"!=":          Token(token.NEQ),
    52  	"<=":          Token(token.LEQ),
    53  	">=":          Token(token.GEQ),
    54  	":=":          Token(token.DEFINE),
    55  	"...":         Token(token.ELLIPSIS),
    56  	"IMPORT":      Token(token.IMPORT),
    57  	"VAR":         Token(token.VAR),
    58  	"TYPE":        Token(token.TYPE),
    59  	"CONST":       Token(token.CONST),
    60  	"BREAK":       Token(token.BREAK),
    61  	"CONTINUE":    Token(token.CONTINUE),
    62  	"GOTO":        Token(token.GOTO),
    63  	"FALLTHROUGH": Token(token.FALLTHROUGH),
    64  }
    65  
    66  func maybeToken(node Node) (Node, bool) {
    67  	if node, ok := node.(String); ok {
    68  		if tok, ok := tokensByString[string(node)]; ok {
    69  			return tok, true
    70  		}
    71  		return node, false
    72  	}
    73  	return node, false
    74  }
    75  
    76  func isNil(v interface{}) bool {
    77  	if v == nil {
    78  		return true
    79  	}
    80  	if _, ok := v.(Nil); ok {
    81  		return true
    82  	}
    83  	return false
    84  }
    85  
    86  type matcher interface {
    87  	Match(*Matcher, interface{}) (interface{}, bool)
    88  }
    89  
    90  type State = map[string]any
    91  
    92  type Matcher struct {
    93  	TypesInfo *types.Info
    94  	State     State
    95  
    96  	bindingsMapping []string
    97  
    98  	setBindings []uint64
    99  }
   100  
   101  func (m *Matcher) set(b Binding, value interface{}) {
   102  	m.State[b.Name] = value
   103  	m.setBindings[len(m.setBindings)-1] |= 1 << b.idx
   104  }
   105  
   106  func (m *Matcher) push() {
   107  	m.setBindings = append(m.setBindings, 0)
   108  }
   109  
   110  func (m *Matcher) pop() {
   111  	set := m.setBindings[len(m.setBindings)-1]
   112  	if set != 0 {
   113  		for i := 0; i < len(m.bindingsMapping); i++ {
   114  			if (set & (1 << i)) != 0 {
   115  				key := m.bindingsMapping[i]
   116  				delete(m.State, key)
   117  			}
   118  		}
   119  	}
   120  	m.setBindings = m.setBindings[:len(m.setBindings)-1]
   121  }
   122  
   123  func (m *Matcher) merge() {
   124  	m.setBindings = m.setBindings[:len(m.setBindings)-1]
   125  }
   126  
   127  func (m *Matcher) Match(a Pattern, b ast.Node) bool {
   128  	m.bindingsMapping = a.Bindings
   129  	m.State = State{}
   130  	m.push()
   131  	_, ok := match(m, a.Root, b)
   132  	m.merge()
   133  	if len(m.setBindings) != 0 {
   134  		panic(fmt.Sprintf("%d entries left on the stack, expected none", len(m.setBindings)))
   135  	}
   136  	return ok
   137  }
   138  
   139  func Match(a Pattern, b ast.Node) (*Matcher, bool) {
   140  	m := &Matcher{}
   141  	ret := m.Match(a, b)
   142  	return m, ret
   143  }
   144  
   145  // Match two items, which may be (Node, AST) or (AST, AST)
   146  func match(m *Matcher, l, r interface{}) (interface{}, bool) {
   147  	if _, ok := r.(Node); ok {
   148  		panic("Node mustn't be on right side of match")
   149  	}
   150  
   151  	switch l := l.(type) {
   152  	case *ast.ParenExpr:
   153  		return match(m, l.X, r)
   154  	case *ast.ExprStmt:
   155  		return match(m, l.X, r)
   156  	case *ast.DeclStmt:
   157  		return match(m, l.Decl, r)
   158  	case *ast.LabeledStmt:
   159  		return match(m, l.Stmt, r)
   160  	case *ast.BlockStmt:
   161  		return match(m, l.List, r)
   162  	case *ast.FieldList:
   163  		if l == nil {
   164  			return match(m, nil, r)
   165  		} else {
   166  			return match(m, l.List, r)
   167  		}
   168  	}
   169  
   170  	switch r := r.(type) {
   171  	case *ast.ParenExpr:
   172  		return match(m, l, r.X)
   173  	case *ast.ExprStmt:
   174  		return match(m, l, r.X)
   175  	case *ast.DeclStmt:
   176  		return match(m, l, r.Decl)
   177  	case *ast.LabeledStmt:
   178  		return match(m, l, r.Stmt)
   179  	case *ast.BlockStmt:
   180  		if r == nil {
   181  			return match(m, l, nil)
   182  		}
   183  		return match(m, l, r.List)
   184  	case *ast.FieldList:
   185  		if r == nil {
   186  			return match(m, l, nil)
   187  		}
   188  		return match(m, l, r.List)
   189  	case *ast.BasicLit:
   190  		if r == nil {
   191  			return match(m, l, nil)
   192  		}
   193  	}
   194  
   195  	if l, ok := l.(matcher); ok {
   196  		return l.Match(m, r)
   197  	}
   198  
   199  	if l, ok := l.(Node); ok {
   200  		// Matching of pattern with concrete value
   201  		return matchNodeAST(m, l, r)
   202  	}
   203  
   204  	if l == nil || r == nil {
   205  		return nil, l == r
   206  	}
   207  
   208  	{
   209  		ln, ok1 := l.(ast.Node)
   210  		rn, ok2 := r.(ast.Node)
   211  		if ok1 && ok2 {
   212  			return matchAST(m, ln, rn)
   213  		}
   214  	}
   215  
   216  	{
   217  		obj, ok := l.(types.Object)
   218  		if ok {
   219  			switch r := r.(type) {
   220  			case *ast.Ident:
   221  				return obj, obj == m.TypesInfo.ObjectOf(r)
   222  			case *ast.SelectorExpr:
   223  				return obj, obj == m.TypesInfo.ObjectOf(r.Sel)
   224  			default:
   225  				return obj, false
   226  			}
   227  		}
   228  	}
   229  
   230  	// TODO(dh): the three blocks handling slices can be combined into a single block if we use reflection
   231  
   232  	{
   233  		ln, ok1 := l.([]ast.Expr)
   234  		rn, ok2 := r.([]ast.Expr)
   235  		if ok1 || ok2 {
   236  			if ok1 && !ok2 {
   237  				cast, ok := r.(ast.Expr)
   238  				if !ok {
   239  					return nil, false
   240  				}
   241  				rn = []ast.Expr{cast}
   242  			} else if !ok1 && ok2 {
   243  				cast, ok := l.(ast.Expr)
   244  				if !ok {
   245  					return nil, false
   246  				}
   247  				ln = []ast.Expr{cast}
   248  			}
   249  
   250  			if len(ln) != len(rn) {
   251  				return nil, false
   252  			}
   253  			for i, ll := range ln {
   254  				if _, ok := match(m, ll, rn[i]); !ok {
   255  					return nil, false
   256  				}
   257  			}
   258  			return r, true
   259  		}
   260  	}
   261  
   262  	{
   263  		ln, ok1 := l.([]ast.Stmt)
   264  		rn, ok2 := r.([]ast.Stmt)
   265  		if ok1 || ok2 {
   266  			if ok1 && !ok2 {
   267  				cast, ok := r.(ast.Stmt)
   268  				if !ok {
   269  					return nil, false
   270  				}
   271  				rn = []ast.Stmt{cast}
   272  			} else if !ok1 && ok2 {
   273  				cast, ok := l.(ast.Stmt)
   274  				if !ok {
   275  					return nil, false
   276  				}
   277  				ln = []ast.Stmt{cast}
   278  			}
   279  
   280  			if len(ln) != len(rn) {
   281  				return nil, false
   282  			}
   283  			for i, ll := range ln {
   284  				if _, ok := match(m, ll, rn[i]); !ok {
   285  					return nil, false
   286  				}
   287  			}
   288  			return r, true
   289  		}
   290  	}
   291  
   292  	{
   293  		ln, ok1 := l.([]*ast.Field)
   294  		rn, ok2 := r.([]*ast.Field)
   295  		if ok1 || ok2 {
   296  			if ok1 && !ok2 {
   297  				cast, ok := r.(*ast.Field)
   298  				if !ok {
   299  					return nil, false
   300  				}
   301  				rn = []*ast.Field{cast}
   302  			} else if !ok1 && ok2 {
   303  				cast, ok := l.(*ast.Field)
   304  				if !ok {
   305  					return nil, false
   306  				}
   307  				ln = []*ast.Field{cast}
   308  			}
   309  
   310  			if len(ln) != len(rn) {
   311  				return nil, false
   312  			}
   313  			for i, ll := range ln {
   314  				if _, ok := match(m, ll, rn[i]); !ok {
   315  					return nil, false
   316  				}
   317  			}
   318  			return r, true
   319  		}
   320  	}
   321  
   322  	return nil, false
   323  }
   324  
   325  // Match a Node with an AST node
   326  func matchNodeAST(m *Matcher, a Node, b interface{}) (interface{}, bool) {
   327  	switch b := b.(type) {
   328  	case []ast.Stmt:
   329  		// 'a' is not a List or we'd be using its Match
   330  		// implementation.
   331  
   332  		if len(b) != 1 {
   333  			return nil, false
   334  		}
   335  		return match(m, a, b[0])
   336  	case []ast.Expr:
   337  		// 'a' is not a List or we'd be using its Match
   338  		// implementation.
   339  
   340  		if len(b) != 1 {
   341  			return nil, false
   342  		}
   343  		return match(m, a, b[0])
   344  	case []*ast.Field:
   345  		// 'a' is not a List or we'd be using its Match
   346  		// implementation
   347  		if len(b) != 1 {
   348  			return nil, false
   349  		}
   350  		return match(m, a, b[0])
   351  	case ast.Node:
   352  		ra := reflect.ValueOf(a)
   353  		rb := reflect.ValueOf(b).Elem()
   354  
   355  		if ra.Type().Name() != rb.Type().Name() {
   356  			return nil, false
   357  		}
   358  
   359  		for i := 0; i < ra.NumField(); i++ {
   360  			af := ra.Field(i)
   361  			fieldName := ra.Type().Field(i).Name
   362  			bf := rb.FieldByName(fieldName)
   363  			if (bf == reflect.Value{}) {
   364  				panic(fmt.Sprintf("internal error: could not find field %s in type %t when comparing with %T", fieldName, b, a))
   365  			}
   366  			ai := af.Interface()
   367  			bi := bf.Interface()
   368  			if ai == nil {
   369  				return b, bi == nil
   370  			}
   371  			if _, ok := match(m, ai.(Node), bi); !ok {
   372  				return b, false
   373  			}
   374  		}
   375  		return b, true
   376  	case nil:
   377  		return nil, a == Nil{}
   378  	case string, token.Token:
   379  		// 'a' can't be a String, Token, or Binding or we'd be using their Match implementations.
   380  		return nil, false
   381  	default:
   382  		panic(fmt.Sprintf("unhandled type %T", b))
   383  	}
   384  }
   385  
   386  // Match two AST nodes
   387  func matchAST(m *Matcher, a, b ast.Node) (interface{}, bool) {
   388  	ra := reflect.ValueOf(a)
   389  	rb := reflect.ValueOf(b)
   390  
   391  	if ra.Type() != rb.Type() {
   392  		return nil, false
   393  	}
   394  	if ra.IsNil() || rb.IsNil() {
   395  		return rb, ra.IsNil() == rb.IsNil()
   396  	}
   397  
   398  	ra = ra.Elem()
   399  	rb = rb.Elem()
   400  	for i := 0; i < ra.NumField(); i++ {
   401  		af := ra.Field(i)
   402  		bf := rb.Field(i)
   403  		if af.Type() == rtTokPos || af.Type() == rtObject || af.Type() == rtCommentGroup {
   404  			continue
   405  		}
   406  
   407  		switch af.Kind() {
   408  		case reflect.Slice:
   409  			if af.Len() != bf.Len() {
   410  				return nil, false
   411  			}
   412  			for j := 0; j < af.Len(); j++ {
   413  				if _, ok := match(m, af.Index(j).Interface().(ast.Node), bf.Index(j).Interface().(ast.Node)); !ok {
   414  					return nil, false
   415  				}
   416  			}
   417  		case reflect.String:
   418  			if af.String() != bf.String() {
   419  				return nil, false
   420  			}
   421  		case reflect.Int:
   422  			if af.Int() != bf.Int() {
   423  				return nil, false
   424  			}
   425  		case reflect.Bool:
   426  			if af.Bool() != bf.Bool() {
   427  				return nil, false
   428  			}
   429  		case reflect.Ptr, reflect.Interface:
   430  			if _, ok := match(m, af.Interface(), bf.Interface()); !ok {
   431  				return nil, false
   432  			}
   433  		default:
   434  			panic(fmt.Sprintf("internal error: unhandled kind %s (%T)", af.Kind(), af.Interface()))
   435  		}
   436  	}
   437  	return b, true
   438  }
   439  
   440  func (b Binding) Match(m *Matcher, node interface{}) (interface{}, bool) {
   441  	if isNil(b.Node) {
   442  		v, ok := m.State[b.Name]
   443  		if ok {
   444  			// Recall value
   445  			return match(m, v, node)
   446  		}
   447  		// Matching anything
   448  		b.Node = Any{}
   449  	}
   450  
   451  	// Store value
   452  	if _, ok := m.State[b.Name]; ok {
   453  		panic(fmt.Sprintf("binding already created: %s", b.Name))
   454  	}
   455  	new, ret := match(m, b.Node, node)
   456  	if ret {
   457  		m.set(b, new)
   458  	}
   459  	return new, ret
   460  }
   461  
   462  func (Any) Match(m *Matcher, node interface{}) (interface{}, bool) {
   463  	return node, true
   464  }
   465  
   466  func (l List) Match(m *Matcher, node interface{}) (interface{}, bool) {
   467  	v := reflect.ValueOf(node)
   468  	if v.Kind() == reflect.Slice {
   469  		if isNil(l.Head) {
   470  			return node, v.Len() == 0
   471  		}
   472  		if v.Len() == 0 {
   473  			return nil, false
   474  		}
   475  		// OPT(dh): don't check the entire tail if head didn't match
   476  		_, ok1 := match(m, l.Head, v.Index(0).Interface())
   477  		_, ok2 := match(m, l.Tail, v.Slice(1, v.Len()).Interface())
   478  		return node, ok1 && ok2
   479  	}
   480  	// Our empty list does not equal an untyped Go nil. This way, we can
   481  	// tell apart an if with no else and an if with an empty else.
   482  	return nil, false
   483  }
   484  
   485  func (s String) Match(m *Matcher, node interface{}) (interface{}, bool) {
   486  	switch o := node.(type) {
   487  	case token.Token:
   488  		if tok, ok := maybeToken(s); ok {
   489  			return match(m, tok, node)
   490  		}
   491  		return nil, false
   492  	case string:
   493  		return o, string(s) == o
   494  	case types.TypeAndValue:
   495  		return o, o.Value != nil && o.Value.String() == string(s)
   496  	default:
   497  		return nil, false
   498  	}
   499  }
   500  
   501  func (tok Token) Match(m *Matcher, node interface{}) (interface{}, bool) {
   502  	o, ok := node.(token.Token)
   503  	if !ok {
   504  		return nil, false
   505  	}
   506  	return o, token.Token(tok) == o
   507  }
   508  
   509  func (Nil) Match(m *Matcher, node interface{}) (interface{}, bool) {
   510  	if isNil(node) {
   511  		return nil, true
   512  	}
   513  	v := reflect.ValueOf(node)
   514  	switch v.Kind() {
   515  	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice:
   516  		return nil, v.IsNil()
   517  	default:
   518  		return nil, false
   519  	}
   520  }
   521  
   522  func (builtin Builtin) Match(m *Matcher, node interface{}) (interface{}, bool) {
   523  	r, ok := match(m, Ident(builtin), node)
   524  	if !ok {
   525  		return nil, false
   526  	}
   527  	ident := r.(*ast.Ident)
   528  	obj := m.TypesInfo.ObjectOf(ident)
   529  	if obj != types.Universe.Lookup(ident.Name) {
   530  		return nil, false
   531  	}
   532  	return ident, true
   533  }
   534  
   535  func (obj Object) Match(m *Matcher, node interface{}) (interface{}, bool) {
   536  	r, ok := match(m, Ident(obj), node)
   537  	if !ok {
   538  		return nil, false
   539  	}
   540  	ident := r.(*ast.Ident)
   541  
   542  	id := m.TypesInfo.ObjectOf(ident)
   543  	_, ok = match(m, obj.Name, ident.Name)
   544  	return id, ok
   545  }
   546  
   547  func (fn Symbol) Match(m *Matcher, node interface{}) (interface{}, bool) {
   548  	var name string
   549  	var obj types.Object
   550  
   551  	base := []Node{
   552  		Ident{Any{}},
   553  		SelectorExpr{Any{}, Any{}},
   554  	}
   555  	p := Or{
   556  		Nodes: append(base,
   557  			IndexExpr{Or{Nodes: base}, Any{}},
   558  			IndexListExpr{Or{Nodes: base}, Any{}})}
   559  
   560  	r, ok := match(m, p, node)
   561  	if !ok {
   562  		return nil, false
   563  	}
   564  
   565  	fun := r.(ast.Expr)
   566  	switch idx := fun.(type) {
   567  	case *ast.IndexExpr:
   568  		fun = idx.X
   569  	case *ast.IndexListExpr:
   570  		fun = idx.X
   571  	}
   572  	fun = astutil.Unparen(fun)
   573  
   574  	switch fun := fun.(type) {
   575  	case *ast.Ident:
   576  		obj = m.TypesInfo.ObjectOf(fun)
   577  	case *ast.SelectorExpr:
   578  		obj = m.TypesInfo.ObjectOf(fun.Sel)
   579  	default:
   580  		panic("unreachable")
   581  	}
   582  	switch obj := obj.(type) {
   583  	case *types.Func:
   584  		// OPT(dh): optimize this similar to code.FuncName
   585  		name = obj.FullName()
   586  	case *types.Builtin:
   587  		name = obj.Name()
   588  	case *types.TypeName:
   589  		if obj.Pkg() == nil {
   590  			return nil, false
   591  		}
   592  		if obj.Parent() != obj.Pkg().Scope() {
   593  			return nil, false
   594  		}
   595  		name = types.TypeString(obj.Type(), nil)
   596  	case *types.Const, *types.Var:
   597  		if obj.Pkg() == nil {
   598  			return nil, false
   599  		}
   600  		if obj.Parent() != obj.Pkg().Scope() {
   601  			return nil, false
   602  		}
   603  		name = fmt.Sprintf("%s.%s", obj.Pkg().Path(), obj.Name())
   604  	default:
   605  		return nil, false
   606  	}
   607  
   608  	_, ok = match(m, fn.Name, name)
   609  	return obj, ok
   610  }
   611  
   612  func (or Or) Match(m *Matcher, node interface{}) (interface{}, bool) {
   613  	for _, opt := range or.Nodes {
   614  		m.push()
   615  		if ret, ok := match(m, opt, node); ok {
   616  			m.merge()
   617  			return ret, true
   618  		} else {
   619  			m.pop()
   620  		}
   621  	}
   622  	return nil, false
   623  }
   624  
   625  func (not Not) Match(m *Matcher, node interface{}) (interface{}, bool) {
   626  	_, ok := match(m, not.Node, node)
   627  	if ok {
   628  		return nil, false
   629  	}
   630  	return node, true
   631  }
   632  
   633  var integerLiteralQ = MustParse(`(Or (BasicLit "INT" _) (UnaryExpr (Or "+" "-") (IntegerLiteral _)))`)
   634  
   635  func (lit IntegerLiteral) Match(m *Matcher, node interface{}) (interface{}, bool) {
   636  	matched, ok := match(m, integerLiteralQ.Root, node)
   637  	if !ok {
   638  		return nil, false
   639  	}
   640  	tv, ok := m.TypesInfo.Types[matched.(ast.Expr)]
   641  	if !ok {
   642  		return nil, false
   643  	}
   644  	if tv.Value == nil {
   645  		return nil, false
   646  	}
   647  	_, ok = match(m, lit.Value, tv)
   648  	return matched, ok
   649  }
   650  
   651  func (texpr TrulyConstantExpression) Match(m *Matcher, node interface{}) (interface{}, bool) {
   652  	expr, ok := node.(ast.Expr)
   653  	if !ok {
   654  		return nil, false
   655  	}
   656  	tv, ok := m.TypesInfo.Types[expr]
   657  	if !ok {
   658  		return nil, false
   659  	}
   660  	if tv.Value == nil {
   661  		return nil, false
   662  	}
   663  	truly := true
   664  	ast.Inspect(expr, func(node ast.Node) bool {
   665  		if _, ok := node.(*ast.Ident); ok {
   666  			truly = false
   667  			return false
   668  		}
   669  		return true
   670  	})
   671  	if !truly {
   672  		return nil, false
   673  	}
   674  	_, ok = match(m, texpr.Value, tv)
   675  	return expr, ok
   676  }
   677  
   678  var (
   679  	// Types of fields in go/ast structs that we want to skip
   680  	rtTokPos       = reflect.TypeOf(token.Pos(0))
   681  	rtObject       = reflect.TypeOf((*ast.Object)(nil))
   682  	rtCommentGroup = reflect.TypeOf((*ast.CommentGroup)(nil))
   683  )
   684  
   685  var (
   686  	_ matcher = Binding{}
   687  	_ matcher = Any{}
   688  	_ matcher = List{}
   689  	_ matcher = String("")
   690  	_ matcher = Token(0)
   691  	_ matcher = Nil{}
   692  	_ matcher = Builtin{}
   693  	_ matcher = Object{}
   694  	_ matcher = Symbol{}
   695  	_ matcher = Or{}
   696  	_ matcher = Not{}
   697  	_ matcher = IntegerLiteral{}
   698  	_ matcher = TrulyConstantExpression{}
   699  )