// Package pm implements Lua pattern matching for Go.
//
// Supported syntax: the escaped classes %a %c %d %l %p %s %u %w %x %z (the
// upper-case letter selects the complement), '.', bracket sets [...] and
// [^...] with x-y ranges, the quantifiers * + - ?, the anchors ^ and $,
// captures (...), position captures (), back-references %1-%9 and balanced
// matches %bxy. The frontier pattern %f is not implemented: "%f" matches a
// literal 'f'.
package pm

import (
	"bytes"
	"fmt"
)

const (
	// EOS is the position reported once the end of the pattern has been
	// reached; the pattern scanner also returns it instead of a byte.
	EOS = -1
	// _UNKNOWN is the Error.Pos of errors that have no pattern position.
	_UNKNOWN = -2
)

/* Error {{{ */

// Error describes an invalid pattern. Pos is the 0-based byte offset in the
// pattern at which the problem was detected, EOS if it was detected at the
// end of the pattern, or _UNKNOWN when no position applies.
type Error struct {
	Pos     int
	Message string
}

// newError builds an *Error. message is used verbatim when no args are
// given, and as a fmt format string otherwise.
func newError(pos int, message string, args ...interface{}) *Error {
	if len(args) == 0 {
		return &Error{pos, message}
	}
	return &Error{pos, fmt.Sprintf(message, args...)}
}

// Error implements the error interface.
func (e *Error) Error() string {
	switch e.Pos {
	case EOS:
		return fmt.Sprintf("%s at EOS", e.Message)
	case _UNKNOWN:
		return e.Message
	default:
		return fmt.Sprintf("%s at %d", e.Message, e.Pos)
	}
}

/* }}} */

/* MatchData {{{ */

// MatchData holds the capture slots recorded by one successful match.
//
// Each slot stores a position shifted left by one bit; the low bit is 1 for
// a position capture "()" and 0 otherwise. Slots 0 and 1 hold the start and
// end of the whole match, slots 2k and 2k+1 the start and end of capture k.
type MatchData struct {
	captures []uint32
}

func newMatchState() *MatchData { return &MatchData{[]uint32{}} }

// addPosCapture stores pos as a position capture in slots s and s+1,
// growing the slot slice as needed.
func (st *MatchData) addPosCapture(s, pos int) {
	for s+1 >= len(st.captures) {
		st.captures = append(st.captures, 0)
	}
	st.captures[s] = (uint32(pos) << 1) | 1
	st.captures[s+1] = (uint32(pos) << 1) | 1
}

// setCapture stores pos in slot s, growing the slot slice as needed, and
// returns the slot's previous raw value so the write can be undone.
func (st *MatchData) setCapture(s, pos int) uint32 {
	for s >= len(st.captures) {
		st.captures = append(st.captures, 0)
	}
	v := st.captures[s]
	st.captures[s] = uint32(pos) << 1
	return v
}

// restoreCapture writes back a raw value previously returned by setCapture.
func (st *MatchData) restoreCapture(s int, pos uint32) { st.captures[s] = pos }

// CaptureLength returns the number of capture slots.
func (st *MatchData) CaptureLength() int { return len(st.captures) }

// IsPosCapture reports whether slot idx holds a position capture "()".
func (st *MatchData) IsPosCapture(idx int) bool { return (st.captures[idx] & 1) == 1 }

// Capture returns the position stored in slot idx. Ordinary slots hold
// 0-based byte offsets into the subject; position captures hold the offset
// plus one.
func (st *MatchData) Capture(idx int) int { return int(st.captures[idx] >> 1) }

/* }}} */

/* scanner {{{ */

// scannerState is the part of a scanner that Save and Restore snapshot.
type scannerState struct {
	Pos     int
	started bool
}

// scanner reads a pattern one byte at a time. Next consumes and returns the
// next byte (EOS at the end); Peek returns it without consuming it.
type scanner struct {
	src   []byte
	State scannerState
	saved scannerState
}

func newScanner(src []byte) *scanner {
	return &scanner{src: src}
}

// Length returns the length of the pattern in bytes.
func (sc *scanner) Length() int { return len(sc.src) }

// Next advances to the next byte and returns it, or EOS once the pattern is
// exhausted.
func (sc *scanner) Next() int {
	if !sc.State.started {
		sc.State.started = true
		if len(sc.src) == 0 {
			sc.State.Pos = EOS
		}
	} else {
		sc.State.Pos = sc.NextPos()
	}
	if sc.State.Pos == EOS {
		return EOS
	}
	return int(sc.src[sc.State.Pos])
}

// CurrentPos returns the scanner's current position: the offset of the byte
// most recently returned by Next (0 before the first call), or EOS.
func (sc *scanner) CurrentPos() int {
	return sc.State.Pos
}

// NextPos returns the position following the current one, or EOS when the
// current position is already EOS or is at (or past) the last byte of the
// pattern. Before the first Next it returns 0 for patterns longer than one
// byte.
func (sc *scanner) NextPos() int {
	if sc.State.Pos == EOS || sc.State.Pos >= len(sc.src)-1 {
		return EOS
	}
	if !sc.State.started {
		return 0
	}
	return sc.State.Pos + 1
}

// Peek returns the byte Next would return, without consuming it.
func (sc *scanner) Peek() int {
	cureof := sc.State.Pos == EOS
	ch := sc.Next()
	if !cureof {
		// Step back over whatever Next just consumed.
		if sc.State.Pos == EOS {
			sc.State.Pos = len(sc.src) - 1
		} else {
			sc.State.Pos--
			if sc.State.Pos < 0 {
				sc.State.Pos = 0
				sc.State.started = false
			}
		}
	}
	return ch
}

// Save snapshots the current scanner position.
func (sc *scanner) Save() { sc.saved = sc.State }

// Restore rewinds the scanner to the last Save.
func (sc *scanner) Restore() { sc.State = sc.saved }

/* }}} */

/* bytecode {{{ */

type opCode int

// Opcodes executed by recursiveVM:
//
//	opChar      match one byte against inst.Class and advance
//	opMatch     succeed
//	opTailMatch succeed only at the end of the subject
//	opJmp       continue at Operand1
//	opSplit     try Operand1; if that fails continue at Operand2
//	opSave      store the position in capture slot Operand1
//	opPSave     store a position capture in slots Operand1 and Operand1+1
//	opBrace     %b: balanced span opened by Operand1 and closed by Operand2
//	opNumber    %N: back-reference to capture Operand1
const (
	opChar opCode = iota
	opMatch
	opTailMatch
	opJmp
	opSplit
	opSave
	opPSave
	opBrace
	opNumber
)

// inst is one VM instruction; unused operands are -1 and Class is nil for
// every opcode except opChar.
type inst struct {
	OpCode   opCode
	Class    class
	Operand1 int
	Operand2 int
}

/* }}} */

/* classes {{{ */

// class is a set of bytes that one pattern position can match. The VM calls
// Matches with a subject byte value.
type class interface {
	Matches(ch int) bool
}

// dotClass is '.', which matches any byte.
type dotClass struct{}

func (pn *dotClass) Matches(ch int) bool { return true }

// charClass matches exactly the byte Ch.
type charClass struct {
	Ch int
}

func (pn *charClass) Matches(ch int) bool { return pn.Ch == ch }

// singleClass is an escape such as %a. The lower-case letters a c d l p s u
// w x z select ASCII character classes and their upper-case forms the
// complement; any other escaped byte matches itself.
type singleClass struct {
	Class int
}

func (pn *singleClass) Matches(ch int) bool {
	ret := false
	switch pn.Class {
	case 'a', 'A':
		ret = 'A' <= ch && ch <= 'Z' || 'a' <= ch && ch <= 'z'
	case 'c', 'C':
		ret = (0x00 <= ch && ch <= 0x1F) || ch == 0x7F
	case 'd', 'D':
		ret = '0' <= ch && ch <= '9'
	case 'l', 'L':
		ret = 'a' <= ch && ch <= 'z'
	case 'p', 'P':
		ret = (0x21 <= ch && ch <= 0x2f) || (0x3a <= ch && ch <= 0x40) || (0x5b <= ch && ch <= 0x60) || (0x7b <= ch && ch <= 0x7e)
	case 's', 'S':
		switch ch {
		case ' ', '\f', '\n', '\r', '\t', '\v':
			ret = true
		}
	case 'u', 'U':
		ret = 'A' <= ch && ch <= 'Z'
	case 'w', 'W':
		ret = '0' <= ch && ch <= '9' || 'A' <= ch && ch <= 'Z' || 'a' <= ch && ch <= 'z'
	case 'x', 'X':
		ret = '0' <= ch && ch <= '9' || 'a' <= ch && ch <= 'f' || 'A' <= ch && ch <= 'F'
	case 'z', 'Z':
		ret = ch == 0
	default:
		// Not a class letter: the escaped byte matches itself.
		return ch == pn.Class
	}
	if 'A' <= pn.Class && pn.Class <= 'Z' {
		return !ret
	}
	return ret
}

// setClass is a bracket set [...]; IsNot inverts it for [^...].
type setClass struct {
	IsNot   bool
	Classes []class
}

func (pn *setClass) Matches(ch int) bool {
	for _, c := range pn.Classes {
		if c.Matches(ch) {
			return !pn.IsNot
		}
	}
	return pn.IsNot
}

// rangeClass is an inclusive byte range x-y inside a bracket set. It only
// matches when both endpoints are plain characters.
type rangeClass struct {
	Begin class
	End   class
}

func (pn *rangeClass) Matches(ch int) bool {
	begin, ok := pn.Begin.(*charClass)
	if !ok {
		return false
	}
	end, ok := pn.End.(*charClass)
	return ok && begin.Ch <= ch && ch <= end.Ch
}

// }}}

// patterns {{{

// pattern is a node of the parsed pattern tree: one of *singlePattern,
// *seqPattern, *repeatPattern, *posCapPattern, *capPattern, *numberPattern
// or *bracePattern.
type pattern interface{}

// singlePattern matches one byte from Class.
type singlePattern struct {
	Class class
}

// seqPattern is a sequence of patterns. MustHead and MustTail record the '^'
// and '$' anchors; parsePattern only sets them on the top-level sequence.
type seqPattern struct {
	MustHead bool
	MustTail bool
	Patterns []pattern
}

// repeatPattern applies the quantifier Type ('*', '+', '-' or '?') to Class.
type repeatPattern struct {
	Type  int
	Class class
}

// posCapPattern is a position capture "()".
type posCapPattern struct{}

// capPattern is a capture "(...)" around Pattern.
type capPattern struct {
	Pattern pattern
}

// numberPattern is a back-reference %N to capture N.
type numberPattern struct {
	N int
}

// bracePattern is %bxy: a balanced span opened by Begin and closed by End.
type bracePattern struct {
	Begin int
	End   int
}

// }}}

/* parse {{{ */

// parseClass consumes one class from sc: an escape %x or a single byte.
// When allowset is true, '.' and '[' are parsed as the any-byte class and a
// bracket set; otherwise (inside a bracket set) they are literals.
func parseClass(sc *scanner, allowset bool) class {
	ch := sc.Next()
	switch ch {
	case '%':
		esc := sc.Next()
		if esc == EOS {
			// Built directly rather than via newError: the message contains
			// '%', which vet's printf check would misread as a verb.
			panic(&Error{sc.CurrentPos(), "malformed pattern (ends with '%')"})
		}
		return &singleClass{esc}
	case '.':
		if allowset {
			return &dotClass{}
		}
		return &charClass{ch}
	case '[':
		if allowset {
			return parseClassSet(sc)
		}
		return &charClass{ch}
	case EOS:
		panic(newError(sc.CurrentPos(), "unexpected EOS"))
	default:
		return &charClass{ch}
	}
}

// parseClassSet parses the body of a bracket set whose opening '[' has
// already been consumed. A leading '^' negates the set; a ']' or '-' in first
// position is a literal; x-y between two plain characters forms a range. A
// '-' that follows an escape such as %a, or precedes the closing ']', is a
// literal.
func parseClassSet(sc *scanner) class {
	set := &setClass{}
	if sc.Peek() == '^' {
		set.IsNot = true
		sc.Next()
	}
	isrange := false
	for {
		switch sc.Peek() {
		case EOS:
			panic(newError(sc.CurrentPos(), "unexpected EOS"))
		case ']':
			if len(set.Classes) > 0 {
				sc.Next()
				if isrange {
					// Trailing '-', as in "[a-]".
					set.Classes = append(set.Classes, &charClass{'-'})
				}
				return set
			}
			// ']' as the first member is a literal.
			fallthrough
		case '-':
			if n := len(set.Classes); n > 0 {
				sc.Next()
				if _, plain := set.Classes[n-1].(*charClass); plain {
					isrange = true
				} else {
					// Ranges cannot start at an escaped class: "[%a-z]" is
					// %a, '-' and 'z'.
					set.Classes = append(set.Classes, &charClass{'-'})
				}
				continue
			}
			// '-' as the first member is a literal.
			fallthrough
		default:
			set.Classes = append(set.Classes, parseClass(sc, false))
		}
		if isrange {
			n := len(set.Classes)
			begin, end := set.Classes[n-2], set.Classes[n-1]
			set.Classes = append(set.Classes[:n-2], &rangeClass{begin, end})
			isrange = false
		}
	}
}

// parsePattern parses a sequence of patterns up to EOS or, for nested calls
// (toplevel false), up to but not including the closing ')'. Only the
// top-level call recognizes the '^' anchor, and '$' anchors only when it is
// the last byte of a top-level pattern; elsewhere '$' is a literal.
func parsePattern(sc *scanner, toplevel bool) *seqPattern {
	pat := &seqPattern{}
	if toplevel && sc.Peek() == '^' {
		sc.Next()
		pat.MustHead = true
	}
	for {
		ch := sc.Peek()
		switch ch {
		case '%':
			sc.Save()
			sc.Next()
			switch sc.Peek() {
			case '0':
				panic(newError(sc.CurrentPos(), "invalid capture index"))
			case '1', '2', '3', '4', '5', '6', '7', '8', '9':
				pat.Patterns = append(pat.Patterns, &numberPattern{sc.Next() - '0'})
			case 'b':
				sc.Next()
				begin := sc.Next()
				end := sc.Next()
				if begin == EOS || end == EOS {
					// Built directly: the message contains '%' (see parseClass).
					panic(&Error{sc.CurrentPos(), "missing arguments to '%b'"})
				}
				pat.Patterns = append(pat.Patterns, &bracePattern{begin, end})
			default:
				// Not a capture reference or %b: re-read "%x" as a class.
				sc.Restore()
				pat.Patterns = append(pat.Patterns, &singlePattern{parseClass(sc, true)})
			}
		case '.', '[', ']':
			pat.Patterns = append(pat.Patterns, &singlePattern{parseClass(sc, true)})
		case ')':
			if toplevel {
				panic(newError(sc.CurrentPos(), "invalid ')'"))
			}
			return pat
		case '(':
			sc.Next()
			if sc.Peek() == ')' {
				sc.Next()
				pat.Patterns = append(pat.Patterns, &posCapPattern{})
			} else {
				ret := &capPattern{parsePattern(sc, false)}
				if sc.Peek() != ')' {
					panic(newError(sc.CurrentPos(), "unfinished capture"))
				}
				sc.Next()
				pat.Patterns = append(pat.Patterns, ret)
			}
		case '*', '+', '-', '?':
			sc.Next()
			// A quantifier applies only to a preceding single class;
			// otherwise it is a literal byte.
			if n := len(pat.Patterns); n > 0 {
				if spat, ok := pat.Patterns[n-1].(*singlePattern); ok {
					pat.Patterns[n-1] = &repeatPattern{ch, spat.Class}
					continue
				}
			}
			pat.Patterns = append(pat.Patterns, &singlePattern{&charClass{ch}})
		case '$':
			if toplevel && (sc.NextPos() == sc.Length()-1 || sc.NextPos() == EOS) {
				pat.MustTail = true
			} else {
				pat.Patterns = append(pat.Patterns, &singlePattern{&charClass{ch}})
			}
			sc.Next()
		case EOS:
			sc.Next()
			return pat
		default:
			sc.Next()
			pat.Patterns = append(pat.Patterns, &singlePattern{&charClass{ch}})
		}
	}
}

// iptr accumulates instructions and the next free capture slot while a
// pattern tree is being compiled.
type iptr struct {
	insts   []inst
	capture int
}

// compilePattern appends the bytecode for p to ps[0] and returns the
// instructions accumulated so far. Called without ps it compiles a complete
// top-level pattern (p must then be a *seqPattern). The resulting program
// records the whole-match bounds in slots 0 and 1 and gives each capture the
// next free pair of slots starting at 2. It ends in opMatch, or in
// opTailMatch when the pattern is anchored with '$'.
func compilePattern(p pattern, ps ...*iptr) []inst {
	var ptr *iptr
	toplevel := false
	if len(ps) == 0 {
		toplevel = true
		ptr = &iptr{[]inst{{opSave, nil, 0, -1}}, 2}
	} else {
		ptr = ps[0]
	}
	switch pat := p.(type) {
	case *singlePattern:
		ptr.insts = append(ptr.insts, inst{opChar, pat.Class, -1, -1})
	case *seqPattern:
		for _, cp := range pat.Patterns {
			compilePattern(cp, ptr)
		}
	case *repeatPattern:
		idx := len(ptr.insts)
		switch pat.Type {
		case '*': // greedy: prefer another repetition
			ptr.insts = append(ptr.insts,
				inst{opSplit, nil, idx + 1, idx + 3},
				inst{opChar, pat.Class, -1, -1},
				inst{opJmp, nil, idx, -1})
		case '+': // one, then greedy repetitions
			ptr.insts = append(ptr.insts,
				inst{opChar, pat.Class, -1, -1},
				inst{opSplit, nil, idx, idx + 2})
		case '-': // lazy: prefer leaving the loop
			ptr.insts = append(ptr.insts,
				inst{opSplit, nil, idx + 3, idx + 1},
				inst{opChar, pat.Class, -1, -1},
				inst{opJmp, nil, idx, -1})
		case '?': // optional, preferring to match
			ptr.insts = append(ptr.insts,
				inst{opSplit, nil, idx + 1, idx + 2},
				inst{opChar, pat.Class, -1, -1})
		}
	case *posCapPattern:
		ptr.insts = append(ptr.insts, inst{opPSave, nil, ptr.capture, -1})
		ptr.capture += 2
	case *capPattern:
		c0, c1 := ptr.capture, ptr.capture+1
		ptr.capture += 2
		ptr.insts = append(ptr.insts, inst{opSave, nil, c0, -1})
		compilePattern(pat.Pattern, ptr)
		ptr.insts = append(ptr.insts, inst{opSave, nil, c1, -1})
	case *bracePattern:
		ptr.insts = append(ptr.insts, inst{opBrace, nil, pat.Begin, pat.End})
	case *numberPattern:
		ptr.insts = append(ptr.insts, inst{opNumber, nil, pat.N, -1})
	}
	if toplevel {
		final := opMatch
		if p.(*seqPattern).MustTail {
			final = opTailMatch
		}
		ptr.insts = append(ptr.insts, inst{opSave, nil, 1, -1}, inst{final, nil, -1, -1})
	}
	return ptr.insts
}

/* }}} parse */

/* VM {{{ */

// matchBalance implements %bxy: it reports whether src[sp:] starts with b,
// followed by text in which b and e are balanced, closed by a matching e. On
// success it returns the position just past the closing e.
func matchBalance(src []byte, sp, b, e int) (int, bool) {
	if sp >= len(src) || int(src[sp]) != b {
		return sp, false
	}
	depth := 1
	for i := sp + 1; i < len(src); i++ {
		c := int(src[i])
		// Check the closer first so that %bxx (b == e) closes immediately.
		if c == e {
			depth--
			if depth == 0 {
				return i + 1, true
			}
		} else if c == b {
			depth++
		}
	}
	return len(src), false
}

// Simple recursive virtual machine based on the
// "Regular Expression Matching: the Virtual Machine Approach" (https://swtch.com/~rsc/regexp/regexp2.html)
//
// recursiveVM runs insts from pc against src starting at byte sp,
// backtracking by recursing at opSplit and opSave. It returns whether a match
// was found, the end position (meaningful only on success), and the capture
// data: ms[0] if given, otherwise a fresh MatchData.
func recursiveVM(src []byte, insts []inst, pc, sp int, ms ...*MatchData) (bool, int, *MatchData) {
	var m *MatchData
	if len(ms) == 0 {
		m = newMatchState()
	} else {
		m = ms[0]
	}
	for {
		ins := insts[pc]
		switch ins.OpCode {
		case opChar:
			if sp >= len(src) || !ins.Class.Matches(int(src[sp])) {
				return false, sp, m
			}
			pc++
			sp++
		case opMatch:
			return true, sp, m
		case opTailMatch:
			return sp >= len(src), sp, m
		case opJmp:
			pc = ins.Operand1
		case opSplit:
			// Try the preferred branch first; fall back to Operand2 only if
			// it cannot complete a match.
			if ok, nsp, _ := recursiveVM(src, insts, ins.Operand1, sp, m); ok {
				return true, nsp, m
			}
			pc = ins.Operand2
		case opSave:
			// Record the slot and run the rest of the program; undo the
			// write on failure so an alternative branch sees the old value.
			old := m.setCapture(ins.Operand1, sp)
			if ok, nsp, _ := recursiveVM(src, insts, pc+1, sp, m); ok {
				return true, nsp, m
			}
			m.restoreCapture(ins.Operand1, old)
			return false, sp, m
		case opPSave:
			// Position captures are stored as sp+1, i.e. 1-based.
			m.addPosCapture(ins.Operand1, sp+1)
			pc++
		case opBrace:
			end, ok := matchBalance(src, sp, ins.Operand1, ins.Operand2)
			if !ok {
				return false, end, m
			}
			pc++
			sp = end
		case opNumber:
			idx := ins.Operand1 * 2
			if idx >= m.CaptureLength()-1 {
				panic(newError(_UNKNOWN, "invalid capture index"))
			}
			if m.IsPosCapture(idx) {
				// A position capture has no text to compare against, so a
				// back-reference to it never matches (as in Lua).
				return false, sp, m
			}
			start, end := m.Capture(idx), m.Capture(idx+1)
			if end < start {
				// The referenced capture is still open, e.g. "(a(b)%1)";
				// slicing here would panic with a runtime error.
				panic(newError(_UNKNOWN, "invalid capture index"))
			}
			capture := src[start:end]
			if !bytes.HasPrefix(src[sp:], capture) {
				return false, sp, m
			}
			pc++
			sp += len(capture)
		default:
			panic("should not reach here")
		}
	}
}

/* }}} */

/* API {{{ */

// Find parses the Lua pattern p and searches src for matches, starting at
// byte offset (a negative offset is treated as 0).
//
// Matching is attempted at successive start positions. After a match, the
// search resumes at the end of that match, or one byte further for an empty
// match. At most limit matches are returned: a negative limit means no limit
// and a limit of 0 returns no matches. A pattern anchored with '^' is tried
// only at offset.
//
// Invalid patterns are reported as *Error, in which case matches is nil.
func Find(p string, src []byte, offset, limit int) (matches []*MatchData, err error) {
	defer func() {
		if v := recover(); v != nil {
			perr, ok := v.(*Error)
			if !ok {
				panic(v)
			}
			matches, err = nil, perr
		}
	}()
	pat := parsePattern(newScanner([]byte(p)), true)
	insts := compilePattern(pat)
	matches = []*MatchData{}
	if limit == 0 {
		return matches, nil
	}
	if offset < 0 {
		offset = 0
	}
	for sp := offset; sp <= len(src); {
		ok, nsp, md := recursiveVM(src, insts, 0, sp)
		sp++
		if ok {
			if sp < nsp {
				sp = nsp
			}
			matches = append(matches, md)
		}
		if len(matches) == limit || pat.MustHead {
			break
		}
	}
	return matches, nil
}

/* }}} */