github.com/amarpal/go-tools@v0.0.0-20240422043104-40142f59f616/pattern/parser.go (about) 1 package pattern 2 3 import ( 4 "errors" 5 "fmt" 6 "go/ast" 7 "go/token" 8 "reflect" 9 ) 10 11 type Pattern struct { 12 Root Node 13 // Relevant contains instances of ast.Node that could potentially 14 // initiate a successful match of the pattern. 15 Relevant map[reflect.Type]struct{} 16 17 // Mapping from binding index to binding name 18 Bindings []string 19 } 20 21 func MustParse(s string) Pattern { 22 p := &Parser{AllowTypeInfo: true} 23 pat, err := p.Parse(s) 24 if err != nil { 25 panic(err) 26 } 27 return pat 28 } 29 30 func roots(node Node, m map[reflect.Type]struct{}) { 31 switch node := node.(type) { 32 case Or: 33 for _, el := range node.Nodes { 34 roots(el, m) 35 } 36 case Not: 37 roots(node.Node, m) 38 case Binding: 39 roots(node.Node, m) 40 case Nil, nil: 41 // this branch is reached via bindings 42 for _, T := range allTypes { 43 m[T] = struct{}{} 44 } 45 default: 46 Ts, ok := nodeToASTTypes[reflect.TypeOf(node)] 47 if !ok { 48 panic(fmt.Sprintf("internal error: unhandled type %T", node)) 49 } 50 for _, T := range Ts { 51 m[T] = struct{}{} 52 } 53 } 54 } 55 56 var allTypes = []reflect.Type{ 57 reflect.TypeOf((*ast.RangeStmt)(nil)), 58 reflect.TypeOf((*ast.AssignStmt)(nil)), 59 reflect.TypeOf((*ast.IndexExpr)(nil)), 60 reflect.TypeOf((*ast.Ident)(nil)), 61 reflect.TypeOf((*ast.ValueSpec)(nil)), 62 reflect.TypeOf((*ast.GenDecl)(nil)), 63 reflect.TypeOf((*ast.BinaryExpr)(nil)), 64 reflect.TypeOf((*ast.ForStmt)(nil)), 65 reflect.TypeOf((*ast.ArrayType)(nil)), 66 reflect.TypeOf((*ast.DeferStmt)(nil)), 67 reflect.TypeOf((*ast.MapType)(nil)), 68 reflect.TypeOf((*ast.ReturnStmt)(nil)), 69 reflect.TypeOf((*ast.SliceExpr)(nil)), 70 reflect.TypeOf((*ast.StarExpr)(nil)), 71 reflect.TypeOf((*ast.UnaryExpr)(nil)), 72 reflect.TypeOf((*ast.SendStmt)(nil)), 73 reflect.TypeOf((*ast.SelectStmt)(nil)), 74 reflect.TypeOf((*ast.ImportSpec)(nil)), 75 reflect.TypeOf((*ast.IfStmt)(nil)), 76 reflect.TypeOf((*ast.GoStmt)(nil)), 77 reflect.TypeOf((*ast.Field)(nil)), 78 reflect.TypeOf((*ast.SelectorExpr)(nil)), 79 reflect.TypeOf((*ast.StructType)(nil)), 80 reflect.TypeOf((*ast.KeyValueExpr)(nil)), 81 reflect.TypeOf((*ast.FuncType)(nil)), 82 reflect.TypeOf((*ast.FuncLit)(nil)), 83 reflect.TypeOf((*ast.FuncDecl)(nil)), 84 reflect.TypeOf((*ast.ChanType)(nil)), 85 reflect.TypeOf((*ast.CallExpr)(nil)), 86 reflect.TypeOf((*ast.CaseClause)(nil)), 87 reflect.TypeOf((*ast.CommClause)(nil)), 88 reflect.TypeOf((*ast.CompositeLit)(nil)), 89 reflect.TypeOf((*ast.EmptyStmt)(nil)), 90 reflect.TypeOf((*ast.SwitchStmt)(nil)), 91 reflect.TypeOf((*ast.TypeSwitchStmt)(nil)), 92 reflect.TypeOf((*ast.TypeAssertExpr)(nil)), 93 reflect.TypeOf((*ast.TypeSpec)(nil)), 94 reflect.TypeOf((*ast.InterfaceType)(nil)), 95 reflect.TypeOf((*ast.BranchStmt)(nil)), 96 reflect.TypeOf((*ast.IncDecStmt)(nil)), 97 reflect.TypeOf((*ast.BasicLit)(nil)), 98 } 99 100 var nodeToASTTypes = map[reflect.Type][]reflect.Type{ 101 reflect.TypeOf(String("")): nil, 102 reflect.TypeOf(Token(0)): nil, 103 reflect.TypeOf(List{}): {reflect.TypeOf((*ast.BlockStmt)(nil)), reflect.TypeOf((*ast.FieldList)(nil))}, 104 reflect.TypeOf(Builtin{}): {reflect.TypeOf((*ast.Ident)(nil))}, 105 reflect.TypeOf(Object{}): {reflect.TypeOf((*ast.Ident)(nil))}, 106 reflect.TypeOf(Symbol{}): {reflect.TypeOf((*ast.Ident)(nil)), reflect.TypeOf((*ast.SelectorExpr)(nil))}, 107 reflect.TypeOf(Any{}): allTypes, 108 reflect.TypeOf(RangeStmt{}): {reflect.TypeOf((*ast.RangeStmt)(nil))}, 109 reflect.TypeOf(AssignStmt{}): {reflect.TypeOf((*ast.AssignStmt)(nil))}, 110 reflect.TypeOf(IndexExpr{}): {reflect.TypeOf((*ast.IndexExpr)(nil))}, 111 reflect.TypeOf(Ident{}): {reflect.TypeOf((*ast.Ident)(nil))}, 112 reflect.TypeOf(ValueSpec{}): {reflect.TypeOf((*ast.ValueSpec)(nil))}, 113 reflect.TypeOf(GenDecl{}): {reflect.TypeOf((*ast.GenDecl)(nil))}, 114 reflect.TypeOf(BinaryExpr{}): {reflect.TypeOf((*ast.BinaryExpr)(nil))}, 115 reflect.TypeOf(ForStmt{}): {reflect.TypeOf((*ast.ForStmt)(nil))}, 116 reflect.TypeOf(ArrayType{}): {reflect.TypeOf((*ast.ArrayType)(nil))}, 117 reflect.TypeOf(DeferStmt{}): {reflect.TypeOf((*ast.DeferStmt)(nil))}, 118 reflect.TypeOf(MapType{}): {reflect.TypeOf((*ast.MapType)(nil))}, 119 reflect.TypeOf(ReturnStmt{}): {reflect.TypeOf((*ast.ReturnStmt)(nil))}, 120 reflect.TypeOf(SliceExpr{}): {reflect.TypeOf((*ast.SliceExpr)(nil))}, 121 reflect.TypeOf(StarExpr{}): {reflect.TypeOf((*ast.StarExpr)(nil))}, 122 reflect.TypeOf(UnaryExpr{}): {reflect.TypeOf((*ast.UnaryExpr)(nil))}, 123 reflect.TypeOf(SendStmt{}): {reflect.TypeOf((*ast.SendStmt)(nil))}, 124 reflect.TypeOf(SelectStmt{}): {reflect.TypeOf((*ast.SelectStmt)(nil))}, 125 reflect.TypeOf(ImportSpec{}): {reflect.TypeOf((*ast.ImportSpec)(nil))}, 126 reflect.TypeOf(IfStmt{}): {reflect.TypeOf((*ast.IfStmt)(nil))}, 127 reflect.TypeOf(GoStmt{}): {reflect.TypeOf((*ast.GoStmt)(nil))}, 128 reflect.TypeOf(Field{}): {reflect.TypeOf((*ast.Field)(nil))}, 129 reflect.TypeOf(SelectorExpr{}): {reflect.TypeOf((*ast.SelectorExpr)(nil))}, 130 reflect.TypeOf(StructType{}): {reflect.TypeOf((*ast.StructType)(nil))}, 131 reflect.TypeOf(KeyValueExpr{}): {reflect.TypeOf((*ast.KeyValueExpr)(nil))}, 132 reflect.TypeOf(FuncType{}): {reflect.TypeOf((*ast.FuncType)(nil))}, 133 reflect.TypeOf(FuncLit{}): {reflect.TypeOf((*ast.FuncLit)(nil))}, 134 reflect.TypeOf(FuncDecl{}): {reflect.TypeOf((*ast.FuncDecl)(nil))}, 135 reflect.TypeOf(ChanType{}): {reflect.TypeOf((*ast.ChanType)(nil))}, 136 reflect.TypeOf(CallExpr{}): {reflect.TypeOf((*ast.CallExpr)(nil))}, 137 reflect.TypeOf(CaseClause{}): {reflect.TypeOf((*ast.CaseClause)(nil))}, 138 reflect.TypeOf(CommClause{}): {reflect.TypeOf((*ast.CommClause)(nil))}, 139 reflect.TypeOf(CompositeLit{}): {reflect.TypeOf((*ast.CompositeLit)(nil))}, 140 reflect.TypeOf(EmptyStmt{}): {reflect.TypeOf((*ast.EmptyStmt)(nil))}, 141 reflect.TypeOf(SwitchStmt{}): {reflect.TypeOf((*ast.SwitchStmt)(nil))}, 142 reflect.TypeOf(TypeSwitchStmt{}): {reflect.TypeOf((*ast.TypeSwitchStmt)(nil))}, 143 reflect.TypeOf(TypeAssertExpr{}): {reflect.TypeOf((*ast.TypeAssertExpr)(nil))}, 144 reflect.TypeOf(TypeSpec{}): {reflect.TypeOf((*ast.TypeSpec)(nil))}, 145 reflect.TypeOf(InterfaceType{}): {reflect.TypeOf((*ast.InterfaceType)(nil))}, 146 reflect.TypeOf(BranchStmt{}): {reflect.TypeOf((*ast.BranchStmt)(nil))}, 147 reflect.TypeOf(IncDecStmt{}): {reflect.TypeOf((*ast.IncDecStmt)(nil))}, 148 reflect.TypeOf(BasicLit{}): {reflect.TypeOf((*ast.BasicLit)(nil))}, 149 reflect.TypeOf(IntegerLiteral{}): {reflect.TypeOf((*ast.BasicLit)(nil)), reflect.TypeOf((*ast.UnaryExpr)(nil))}, 150 reflect.TypeOf(TrulyConstantExpression{}): allTypes, // this is an over-approximation, which is fine 151 } 152 153 var requiresTypeInfo = map[string]bool{ 154 "Symbol": true, 155 "Builtin": true, 156 "Object": true, 157 "IntegerLiteral": true, 158 "TrulyConstantExpression": true, 159 } 160 161 type Parser struct { 162 // Allow nodes that rely on type information 163 AllowTypeInfo bool 164 165 lex *lexer 166 cur item 167 last *item 168 items chan item 169 170 bindings map[string]int 171 } 172 173 func (p *Parser) bindingIndex(name string) int { 174 if p.bindings == nil { 175 p.bindings = map[string]int{} 176 } 177 if idx, ok := p.bindings[name]; ok { 178 return idx 179 } 180 idx := len(p.bindings) 181 p.bindings[name] = idx 182 return idx 183 } 184 185 func (p *Parser) Parse(s string) (Pattern, error) { 186 p.cur = item{} 187 p.last = nil 188 p.items = nil 189 190 fset := token.NewFileSet() 191 p.lex = &lexer{ 192 f: fset.AddFile("<input>", -1, len(s)), 193 input: s, 194 items: make(chan item), 195 } 196 go p.lex.run() 197 p.items = p.lex.items 198 root, err := p.node() 199 if err != nil { 200 // drain lexer if parsing failed 201 for range p.lex.items { 202 } 203 return Pattern{}, err 204 } 205 if item := <-p.lex.items; item.typ != itemEOF { 206 return Pattern{}, fmt.Errorf("unexpected token %s after end of pattern", item.typ) 207 } 208 209 if len(p.bindings) > 64 { 210 return Pattern{}, errors.New("encountered more than 64 bindings") 211 } 212 213 bindings := make([]string, len(p.bindings)) 214 for name, idx := range p.bindings { 215 bindings[idx] = name 216 } 217 218 relevant := map[reflect.Type]struct{}{} 219 roots(root, relevant) 220 return Pattern{ 221 Root: root, 222 Relevant: relevant, 223 Bindings: bindings, 224 }, nil 225 } 226 227 func (p *Parser) next() item { 228 if p.last != nil { 229 n := *p.last 230 p.last = nil 231 return n 232 } 233 var ok bool 234 p.cur, ok = <-p.items 235 if !ok { 236 p.cur = item{typ: eof} 237 } 238 return p.cur 239 } 240 241 func (p *Parser) rewind() { 242 p.last = &p.cur 243 } 244 245 func (p *Parser) peek() item { 246 n := p.next() 247 p.rewind() 248 return n 249 } 250 251 func (p *Parser) accept(typ itemType) (item, bool) { 252 n := p.next() 253 if n.typ == typ { 254 return n, true 255 } 256 p.rewind() 257 return item{}, false 258 } 259 260 func (p *Parser) unexpectedToken(valid string) error { 261 if p.cur.typ == itemError { 262 return fmt.Errorf("error lexing input: %s", p.cur.val) 263 } 264 var got string 265 switch p.cur.typ { 266 case itemTypeName, itemVariable, itemString: 267 got = p.cur.val 268 default: 269 got = "'" + p.cur.typ.String() + "'" 270 } 271 272 pos := p.lex.f.Position(token.Pos(p.cur.pos)) 273 return fmt.Errorf("%s: expected %s, found %s", pos, valid, got) 274 } 275 276 func (p *Parser) node() (Node, error) { 277 if _, ok := p.accept(itemLeftParen); !ok { 278 return nil, p.unexpectedToken("'('") 279 } 280 typ, ok := p.accept(itemTypeName) 281 if !ok { 282 return nil, p.unexpectedToken("Node type") 283 } 284 285 var objs []Node 286 for { 287 if _, ok := p.accept(itemRightParen); ok { 288 break 289 } else { 290 p.rewind() 291 obj, err := p.object() 292 if err != nil { 293 return nil, err 294 } 295 objs = append(objs, obj) 296 } 297 } 298 299 node, err := p.populateNode(typ.val, objs) 300 if err != nil { 301 return nil, err 302 } 303 if node, ok := node.(Binding); ok { 304 node.idx = p.bindingIndex(node.Name) 305 } 306 return node, nil 307 } 308 309 func populateNode(typ string, objs []Node, allowTypeInfo bool) (Node, error) { 310 T, ok := structNodes[typ] 311 if !ok { 312 return nil, fmt.Errorf("unknown node %s", typ) 313 } 314 315 if !allowTypeInfo && requiresTypeInfo[typ] { 316 return nil, fmt.Errorf("Node %s requires type information", typ) 317 } 318 319 pv := reflect.New(T) 320 v := pv.Elem() 321 322 if v.NumField() == 1 { 323 f := v.Field(0) 324 if f.Type().Kind() == reflect.Slice { 325 // Variadic node 326 f.Set(reflect.AppendSlice(f, reflect.ValueOf(objs))) 327 return v.Interface().(Node), nil 328 } 329 } 330 331 n := -1 332 for i := 0; i < T.NumField(); i++ { 333 if !T.Field(i).IsExported() { 334 break 335 } 336 n = i 337 } 338 339 if len(objs) != n+1 { 340 return nil, fmt.Errorf("tried to initialize node %s with %d values, expected %d", typ, len(objs), n+1) 341 } 342 343 for i := 0; i < v.NumField(); i++ { 344 if !T.Field(i).IsExported() { 345 break 346 } 347 f := v.Field(i) 348 if f.Kind() == reflect.String { 349 if obj, ok := objs[i].(String); ok { 350 f.Set(reflect.ValueOf(string(obj))) 351 } else { 352 return nil, fmt.Errorf("first argument of (Binding name node) must be string, but got %s", objs[i]) 353 } 354 } else { 355 f.Set(reflect.ValueOf(objs[i])) 356 } 357 } 358 return v.Interface().(Node), nil 359 } 360 361 func (p *Parser) populateNode(typ string, objs []Node) (Node, error) { 362 return populateNode(typ, objs, p.AllowTypeInfo) 363 } 364 365 var structNodes = map[string]reflect.Type{ 366 "Any": reflect.TypeOf(Any{}), 367 "Ellipsis": reflect.TypeOf(Ellipsis{}), 368 "List": reflect.TypeOf(List{}), 369 "Binding": reflect.TypeOf(Binding{}), 370 "RangeStmt": reflect.TypeOf(RangeStmt{}), 371 "AssignStmt": reflect.TypeOf(AssignStmt{}), 372 "IndexExpr": reflect.TypeOf(IndexExpr{}), 373 "Ident": reflect.TypeOf(Ident{}), 374 "Builtin": reflect.TypeOf(Builtin{}), 375 "ValueSpec": reflect.TypeOf(ValueSpec{}), 376 "GenDecl": reflect.TypeOf(GenDecl{}), 377 "BinaryExpr": reflect.TypeOf(BinaryExpr{}), 378 "ForStmt": reflect.TypeOf(ForStmt{}), 379 "ArrayType": reflect.TypeOf(ArrayType{}), 380 "DeferStmt": reflect.TypeOf(DeferStmt{}), 381 "MapType": reflect.TypeOf(MapType{}), 382 "ReturnStmt": reflect.TypeOf(ReturnStmt{}), 383 "SliceExpr": reflect.TypeOf(SliceExpr{}), 384 "StarExpr": reflect.TypeOf(StarExpr{}), 385 "UnaryExpr": reflect.TypeOf(UnaryExpr{}), 386 "SendStmt": reflect.TypeOf(SendStmt{}), 387 "SelectStmt": reflect.TypeOf(SelectStmt{}), 388 "ImportSpec": reflect.TypeOf(ImportSpec{}), 389 "IfStmt": reflect.TypeOf(IfStmt{}), 390 "GoStmt": reflect.TypeOf(GoStmt{}), 391 "Field": reflect.TypeOf(Field{}), 392 "SelectorExpr": reflect.TypeOf(SelectorExpr{}), 393 "StructType": reflect.TypeOf(StructType{}), 394 "KeyValueExpr": reflect.TypeOf(KeyValueExpr{}), 395 "FuncType": reflect.TypeOf(FuncType{}), 396 "FuncLit": reflect.TypeOf(FuncLit{}), 397 "FuncDecl": reflect.TypeOf(FuncDecl{}), 398 "ChanType": reflect.TypeOf(ChanType{}), 399 "CallExpr": reflect.TypeOf(CallExpr{}), 400 "CaseClause": reflect.TypeOf(CaseClause{}), 401 "CommClause": reflect.TypeOf(CommClause{}), 402 "CompositeLit": reflect.TypeOf(CompositeLit{}), 403 "EmptyStmt": reflect.TypeOf(EmptyStmt{}), 404 "SwitchStmt": reflect.TypeOf(SwitchStmt{}), 405 "TypeSwitchStmt": reflect.TypeOf(TypeSwitchStmt{}), 406 "TypeAssertExpr": reflect.TypeOf(TypeAssertExpr{}), 407 "TypeSpec": reflect.TypeOf(TypeSpec{}), 408 "InterfaceType": reflect.TypeOf(InterfaceType{}), 409 "BranchStmt": reflect.TypeOf(BranchStmt{}), 410 "IncDecStmt": reflect.TypeOf(IncDecStmt{}), 411 "BasicLit": reflect.TypeOf(BasicLit{}), 412 "Object": reflect.TypeOf(Object{}), 413 "Symbol": reflect.TypeOf(Symbol{}), 414 "Or": reflect.TypeOf(Or{}), 415 "Not": reflect.TypeOf(Not{}), 416 "IntegerLiteral": reflect.TypeOf(IntegerLiteral{}), 417 "TrulyConstantExpression": reflect.TypeOf(TrulyConstantExpression{}), 418 } 419 420 func (p *Parser) object() (Node, error) { 421 n := p.next() 422 switch n.typ { 423 case itemLeftParen: 424 p.rewind() 425 node, err := p.node() 426 if err != nil { 427 return node, err 428 } 429 if p.peek().typ == itemColon { 430 p.next() 431 tail, err := p.object() 432 if err != nil { 433 return node, err 434 } 435 return List{Head: node, Tail: tail}, nil 436 } 437 return node, nil 438 case itemLeftBracket: 439 p.rewind() 440 return p.array() 441 case itemVariable: 442 v := n 443 if v.val == "nil" { 444 return Nil{}, nil 445 } 446 var b Binding 447 if _, ok := p.accept(itemAt); ok { 448 o, err := p.node() 449 if err != nil { 450 return nil, err 451 } 452 b = Binding{ 453 Name: v.val, 454 Node: o, 455 idx: p.bindingIndex(v.val), 456 } 457 } else { 458 p.rewind() 459 b = Binding{ 460 Name: v.val, 461 idx: p.bindingIndex(v.val), 462 } 463 } 464 if p.peek().typ == itemColon { 465 p.next() 466 tail, err := p.object() 467 if err != nil { 468 return b, err 469 } 470 return List{Head: b, Tail: tail}, nil 471 } 472 return b, nil 473 case itemBlank: 474 if p.peek().typ == itemColon { 475 p.next() 476 tail, err := p.object() 477 if err != nil { 478 return Any{}, err 479 } 480 return List{Head: Any{}, Tail: tail}, nil 481 } 482 return Any{}, nil 483 case itemString: 484 return String(n.val), nil 485 default: 486 return nil, p.unexpectedToken("object") 487 } 488 } 489 490 func (p *Parser) array() (Node, error) { 491 if _, ok := p.accept(itemLeftBracket); !ok { 492 return nil, p.unexpectedToken("'['") 493 } 494 495 var objs []Node 496 for { 497 if _, ok := p.accept(itemRightBracket); ok { 498 break 499 } else { 500 p.rewind() 501 obj, err := p.object() 502 if err != nil { 503 return nil, err 504 } 505 objs = append(objs, obj) 506 } 507 } 508 509 tail := List{} 510 for i := len(objs) - 1; i >= 0; i-- { 511 l := List{ 512 Head: objs[i], 513 Tail: tail, 514 } 515 tail = l 516 } 517 return tail, nil 518 } 519 520 /* 521 Node ::= itemLeftParen itemTypeName Object* itemRightParen 522 Object ::= Node | Array | Binding | itemVariable | itemBlank | itemString 523 Array := itemLeftBracket Object* itemRightBracket 524 Array := Object itemColon Object 525 Binding ::= itemVariable itemAt Node 526 */