github.com/shawnclovie/gopher-lua@v0.0.0-20200520092726-90b44ec0e2f2/pm/pm.go (about) 1 // Lua pattern match functions for Go 2 package pm 3 4 import ( 5 "fmt" 6 ) 7 8 const EOS = -1 9 const _UNKNOWN = -2 10 11 /* Error {{{ */ 12 13 type Error struct { 14 Pos int 15 Message string 16 } 17 18 func newError(pos int, message string, args ...interface{}) *Error { 19 if len(args) == 0 { 20 return &Error{pos, message} 21 } 22 return &Error{pos, fmt.Sprintf(message, args...)} 23 } 24 25 func (e *Error) Error() string { 26 switch e.Pos { 27 case EOS: 28 return fmt.Sprintf("%s at EOS", e.Message) 29 case _UNKNOWN: 30 return fmt.Sprintf("%s", e.Message) 31 default: 32 return fmt.Sprintf("%s at %d", e.Message, e.Pos) 33 } 34 } 35 36 /* }}} */ 37 38 /* MatchData {{{ */ 39 40 type MatchData struct { 41 // captured positions 42 // layout 43 // xxxx xxxx xxxx xxx0 : caputured positions 44 // xxxx xxxx xxxx xxx1 : position captured positions 45 captures []uint32 46 } 47 48 func newMatchState() *MatchData { return &MatchData{[]uint32{}} } 49 50 func (st *MatchData) addPosCapture(s, pos int) { 51 for s+1 >= len(st.captures) { 52 st.captures = append(st.captures, 0) 53 } 54 st.captures[s] = (uint32(pos) << 1) | 1 55 st.captures[s+1] = (uint32(pos) << 1) | 1 56 } 57 58 func (st *MatchData) setCapture(s, pos int) uint32 { 59 for s >= len(st.captures) { 60 st.captures = append(st.captures, 0) 61 } 62 v := st.captures[s] 63 st.captures[s] = (uint32(pos) << 1) 64 return v 65 } 66 67 func (st *MatchData) restoreCapture(s int, pos uint32) { st.captures[s] = pos } 68 69 func (st *MatchData) CaptureLength() int { return len(st.captures) } 70 71 func (st *MatchData) IsPosCapture(idx int) bool { return (st.captures[idx] & 1) == 1 } 72 73 func (st *MatchData) Capture(idx int) int { return int(st.captures[idx] >> 1) } 74 75 /* }}} */ 76 77 /* scanner {{{ */ 78 79 type scannerState struct { 80 Pos int 81 started bool 82 } 83 84 type scanner struct { 85 src []byte 86 State scannerState 87 saved scannerState 88 } 89 90 func newScanner(src []byte) *scanner { 91 return &scanner{ 92 src: src, 93 State: scannerState{ 94 Pos: 0, 95 started: false, 96 }, 97 saved: scannerState{}, 98 } 99 } 100 101 func (sc *scanner) Length() int { return len(sc.src) } 102 103 func (sc *scanner) Next() int { 104 if !sc.State.started { 105 sc.State.started = true 106 if len(sc.src) == 0 { 107 sc.State.Pos = EOS 108 } 109 } else { 110 sc.State.Pos = sc.NextPos() 111 } 112 if sc.State.Pos == EOS { 113 return EOS 114 } 115 return int(sc.src[sc.State.Pos]) 116 } 117 118 func (sc *scanner) CurrentPos() int { 119 return sc.State.Pos 120 } 121 122 func (sc *scanner) NextPos() int { 123 if sc.State.Pos == EOS || sc.State.Pos >= len(sc.src)-1 { 124 return EOS 125 } 126 if !sc.State.started { 127 return 0 128 } else { 129 return sc.State.Pos + 1 130 } 131 } 132 133 func (sc *scanner) Peek() int { 134 cureof := sc.State.Pos == EOS 135 ch := sc.Next() 136 if !cureof { 137 if sc.State.Pos == EOS { 138 sc.State.Pos = len(sc.src) - 1 139 } else { 140 sc.State.Pos-- 141 if sc.State.Pos < 0 { 142 sc.State.Pos = 0 143 sc.State.started = false 144 } 145 } 146 } 147 return ch 148 } 149 150 func (sc *scanner) Save() { sc.saved = sc.State } 151 152 func (sc *scanner) Restore() { sc.State = sc.saved } 153 154 /* }}} */ 155 156 /* bytecode {{{ */ 157 158 type opCode int 159 160 const ( 161 opChar opCode = iota 162 opMatch 163 opTailMatch 164 opJmp 165 opSplit 166 opSave 167 opPSave 168 opBrace 169 opNumber 170 ) 171 172 type inst struct { 173 OpCode opCode 174 Class class 175 Operand1 int 176 Operand2 int 177 } 178 179 /* }}} */ 180 181 /* classes {{{ */ 182 183 type class interface { 184 Matches(ch int) bool 185 } 186 187 type dotClass struct{} 188 189 func (pn *dotClass) Matches(ch int) bool { return true } 190 191 type charClass struct { 192 Ch int 193 } 194 195 func (pn *charClass) Matches(ch int) bool { return pn.Ch == ch } 196 197 type singleClass struct { 198 Class int 199 } 200 201 func (pn *singleClass) Matches(ch int) bool { 202 ret := false 203 switch pn.Class { 204 case 'a', 'A': 205 ret = 'A' <= ch && ch <= 'Z' || 'a' <= ch && ch <= 'z' 206 case 'c', 'C': 207 ret = (0x00 <= ch && ch <= 0x1F) || ch == 0x7F 208 case 'd', 'D': 209 ret = '0' <= ch && ch <= '9' 210 case 'l', 'L': 211 ret = 'a' <= ch && ch <= 'z' 212 case 'p', 'P': 213 ret = (0x21 <= ch && ch <= 0x2f) || (0x30 <= ch && ch <= 0x40) || (0x5b <= ch && ch <= 0x60) || (0x7b <= ch && ch <= 0x7e) 214 case 's', 'S': 215 switch ch { 216 case ' ', '\f', '\n', '\r', '\t', '\v': 217 ret = true 218 } 219 case 'u', 'U': 220 ret = 'A' <= ch && ch <= 'Z' 221 case 'w', 'W': 222 ret = '0' <= ch && ch <= '9' || 'A' <= ch && ch <= 'Z' || 'a' <= ch && ch <= 'z' 223 case 'x', 'X': 224 ret = '0' <= ch && ch <= '9' || 'a' <= ch && ch <= 'f' || 'A' <= ch && ch <= 'F' 225 case 'z', 'Z': 226 ret = ch == 0 227 default: 228 return ch == pn.Class 229 } 230 if 'A' <= pn.Class && pn.Class <= 'Z' { 231 return !ret 232 } 233 return ret 234 } 235 236 type setClass struct { 237 IsNot bool 238 Classes []class 239 } 240 241 func (pn *setClass) Matches(ch int) bool { 242 for _, class := range pn.Classes { 243 if class.Matches(ch) { 244 return !pn.IsNot 245 } 246 } 247 return pn.IsNot 248 } 249 250 type rangeClass struct { 251 Begin class 252 End class 253 } 254 255 func (pn *rangeClass) Matches(ch int) bool { 256 switch begin := pn.Begin.(type) { 257 case *charClass: 258 end, ok := pn.End.(*charClass) 259 if !ok { 260 return false 261 } 262 return begin.Ch <= ch && ch <= end.Ch 263 } 264 return false 265 } 266 267 // }}} 268 269 // patterns {{{ 270 271 type pattern interface{} 272 273 type singlePattern struct { 274 Class class 275 } 276 277 type seqPattern struct { 278 MustHead bool 279 MustTail bool 280 Patterns []pattern 281 } 282 283 type repeatPattern struct { 284 Type int 285 Class class 286 } 287 288 type posCapPattern struct{} 289 290 type capPattern struct { 291 Pattern pattern 292 } 293 294 type numberPattern struct { 295 N int 296 } 297 298 type bracePattern struct { 299 Begin int 300 End int 301 } 302 303 // }}} 304 305 /* parse {{{ */ 306 307 func parseClass(sc *scanner, allowset bool) class { 308 ch := sc.Next() 309 switch ch { 310 case '%': 311 return &singleClass{sc.Next()} 312 case '.': 313 if allowset { 314 return &dotClass{} 315 } 316 return &charClass{ch} 317 case '[': 318 if allowset { 319 return parseClassSet(sc) 320 } 321 return &charClass{ch} 322 //case '^' '$', '(', ')', ']', '*', '+', '-', '?': 323 // panic(newError(sc.CurrentPos(), "invalid %c", ch)) 324 case EOS: 325 panic(newError(sc.CurrentPos(), "unexpected EOS")) 326 default: 327 return &charClass{ch} 328 } 329 } 330 331 func parseClassSet(sc *scanner) class { 332 set := &setClass{false, []class{}} 333 if sc.Peek() == '^' { 334 set.IsNot = true 335 sc.Next() 336 } 337 isrange := false 338 for { 339 ch := sc.Peek() 340 switch ch { 341 // case '[': 342 // panic(newError(sc.CurrentPos(), "'[' can not be nested")) 343 case EOS: 344 panic(newError(sc.CurrentPos(), "unexpected EOS")) 345 case ']': 346 if len(set.Classes) > 0 { 347 sc.Next() 348 goto exit 349 } 350 fallthrough 351 case '-': 352 if len(set.Classes) > 0 { 353 sc.Next() 354 isrange = true 355 continue 356 } 357 fallthrough 358 default: 359 set.Classes = append(set.Classes, parseClass(sc, false)) 360 } 361 if isrange { 362 begin := set.Classes[len(set.Classes)-2] 363 end := set.Classes[len(set.Classes)-1] 364 set.Classes = set.Classes[0 : len(set.Classes)-2] 365 set.Classes = append(set.Classes, &rangeClass{begin, end}) 366 isrange = false 367 } 368 } 369 exit: 370 if isrange { 371 set.Classes = append(set.Classes, &charClass{'-'}) 372 } 373 374 return set 375 } 376 377 func parsePattern(sc *scanner, toplevel bool) *seqPattern { 378 pat := &seqPattern{} 379 if toplevel { 380 if sc.Peek() == '^' { 381 sc.Next() 382 pat.MustHead = true 383 } 384 } 385 for { 386 ch := sc.Peek() 387 switch ch { 388 case '%': 389 sc.Save() 390 sc.Next() 391 switch sc.Peek() { 392 case '0': 393 panic(newError(sc.CurrentPos(), "invalid capture index")) 394 case '1', '2', '3', '4', '5', '6', '7', '8', '9': 395 pat.Patterns = append(pat.Patterns, &numberPattern{sc.Next() - 48}) 396 case 'b': 397 sc.Next() 398 pat.Patterns = append(pat.Patterns, &bracePattern{sc.Next(), sc.Next()}) 399 default: 400 sc.Restore() 401 pat.Patterns = append(pat.Patterns, &singlePattern{parseClass(sc, true)}) 402 } 403 case '.', '[', ']': 404 pat.Patterns = append(pat.Patterns, &singlePattern{parseClass(sc, true)}) 405 //case ']': 406 // panic(newError(sc.CurrentPos(), "invalid ']'")) 407 case ')': 408 if toplevel { 409 panic(newError(sc.CurrentPos(), "invalid ')'")) 410 } 411 return pat 412 case '(': 413 sc.Next() 414 if sc.Peek() == ')' { 415 sc.Next() 416 pat.Patterns = append(pat.Patterns, &posCapPattern{}) 417 } else { 418 ret := &capPattern{parsePattern(sc, false)} 419 if sc.Peek() != ')' { 420 panic(newError(sc.CurrentPos(), "unfinished capture")) 421 } 422 sc.Next() 423 pat.Patterns = append(pat.Patterns, ret) 424 } 425 case '*', '+', '-', '?': 426 sc.Next() 427 if len(pat.Patterns) > 0 { 428 spat, ok := pat.Patterns[len(pat.Patterns)-1].(*singlePattern) 429 if ok { 430 pat.Patterns = pat.Patterns[0 : len(pat.Patterns)-1] 431 pat.Patterns = append(pat.Patterns, &repeatPattern{ch, spat.Class}) 432 continue 433 } 434 } 435 pat.Patterns = append(pat.Patterns, &singlePattern{&charClass{ch}}) 436 case '$': 437 if toplevel && (sc.NextPos() == sc.Length()-1 || sc.NextPos() == EOS) { 438 pat.MustTail = true 439 } else { 440 pat.Patterns = append(pat.Patterns, &singlePattern{&charClass{ch}}) 441 } 442 sc.Next() 443 case EOS: 444 sc.Next() 445 goto exit 446 default: 447 sc.Next() 448 pat.Patterns = append(pat.Patterns, &singlePattern{&charClass{ch}}) 449 } 450 } 451 exit: 452 return pat 453 } 454 455 type iptr struct { 456 insts []inst 457 capture int 458 } 459 460 func compilePattern(p pattern, ps ...*iptr) []inst { 461 var ptr *iptr 462 toplevel := false 463 if len(ps) == 0 { 464 toplevel = true 465 ptr = &iptr{[]inst{inst{opSave, nil, 0, -1}}, 2} 466 } else { 467 ptr = ps[0] 468 } 469 switch pat := p.(type) { 470 case *singlePattern: 471 ptr.insts = append(ptr.insts, inst{opChar, pat.Class, -1, -1}) 472 case *seqPattern: 473 for _, cp := range pat.Patterns { 474 compilePattern(cp, ptr) 475 } 476 case *repeatPattern: 477 idx := len(ptr.insts) 478 switch pat.Type { 479 case '*': 480 ptr.insts = append(ptr.insts, 481 inst{opSplit, nil, idx + 1, idx + 3}, 482 inst{opChar, pat.Class, -1, -1}, 483 inst{opJmp, nil, idx, -1}) 484 case '+': 485 ptr.insts = append(ptr.insts, 486 inst{opChar, pat.Class, -1, -1}, 487 inst{opSplit, nil, idx, idx + 2}) 488 case '-': 489 ptr.insts = append(ptr.insts, 490 inst{opSplit, nil, idx + 3, idx + 1}, 491 inst{opChar, pat.Class, -1, -1}, 492 inst{opJmp, nil, idx, -1}) 493 case '?': 494 ptr.insts = append(ptr.insts, 495 inst{opSplit, nil, idx + 1, idx + 2}, 496 inst{opChar, pat.Class, -1, -1}) 497 } 498 case *posCapPattern: 499 ptr.insts = append(ptr.insts, inst{opPSave, nil, ptr.capture, -1}) 500 ptr.capture += 2 501 case *capPattern: 502 c0, c1 := ptr.capture, ptr.capture+1 503 ptr.capture += 2 504 ptr.insts = append(ptr.insts, inst{opSave, nil, c0, -1}) 505 compilePattern(pat.Pattern, ptr) 506 ptr.insts = append(ptr.insts, inst{opSave, nil, c1, -1}) 507 case *bracePattern: 508 ptr.insts = append(ptr.insts, inst{opBrace, nil, pat.Begin, pat.End}) 509 case *numberPattern: 510 ptr.insts = append(ptr.insts, inst{opNumber, nil, pat.N, -1}) 511 } 512 if toplevel { 513 if p.(*seqPattern).MustTail { 514 ptr.insts = append(ptr.insts, inst{opSave, nil, 1, -1}, inst{opTailMatch, nil, -1, -1}) 515 } 516 ptr.insts = append(ptr.insts, inst{opSave, nil, 1, -1}, inst{opMatch, nil, -1, -1}) 517 } 518 return ptr.insts 519 } 520 521 /* }}} parse */ 522 523 /* VM {{{ */ 524 525 // Simple recursive virtual machine based on the 526 // "Regular Expression Matching: the Virtual Machine Approach" (https://swtch.com/~rsc/regexp/regexp2.html) 527 func recursiveVM(src []byte, insts []inst, pc, sp int, ms ...*MatchData) (bool, int, *MatchData) { 528 var m *MatchData 529 if len(ms) == 0 { 530 m = newMatchState() 531 } else { 532 m = ms[0] 533 } 534 redo: 535 inst := insts[pc] 536 switch inst.OpCode { 537 case opChar: 538 if sp >= len(src) || !inst.Class.Matches(int(src[sp])) { 539 return false, sp, m 540 } 541 pc++ 542 sp++ 543 goto redo 544 case opMatch: 545 return true, sp, m 546 case opTailMatch: 547 return sp >= len(src), sp, m 548 case opJmp: 549 pc = inst.Operand1 550 goto redo 551 case opSplit: 552 if ok, nsp, _ := recursiveVM(src, insts, inst.Operand1, sp, m); ok { 553 return true, nsp, m 554 } 555 pc = inst.Operand2 556 goto redo 557 case opSave: 558 s := m.setCapture(inst.Operand1, sp) 559 if ok, nsp, _ := recursiveVM(src, insts, pc+1, sp, m); ok { 560 return true, nsp, m 561 } 562 m.restoreCapture(inst.Operand1, s) 563 return false, sp, m 564 case opPSave: 565 m.addPosCapture(inst.Operand1, sp+1) 566 pc++ 567 goto redo 568 case opBrace: 569 if sp >= len(src) || int(src[sp]) != inst.Operand1 { 570 return false, sp, m 571 } 572 count := 1 573 for sp = sp + 1; sp < len(src); sp++ { 574 if int(src[sp]) == inst.Operand2 { 575 count-- 576 } 577 if count == 0 { 578 pc++ 579 sp++ 580 goto redo 581 } 582 if int(src[sp]) == inst.Operand1 { 583 count++ 584 } 585 } 586 return false, sp, m 587 case opNumber: 588 idx := inst.Operand1 * 2 589 if idx >= m.CaptureLength()-1 { 590 panic(newError(_UNKNOWN, "invalid capture index")) 591 } 592 capture := src[m.Capture(idx):m.Capture(idx+1)] 593 for i := 0; i < len(capture); i++ { 594 if i+sp >= len(src) || capture[i] != src[i+sp] { 595 return false, sp, m 596 } 597 } 598 pc++ 599 sp += len(capture) 600 goto redo 601 } 602 panic("should not reach here") 603 } 604 605 /* }}} */ 606 607 /* API {{{ */ 608 609 func Find(p string, src []byte, offset, limit int) (matches []*MatchData, err error) { 610 defer func() { 611 if v := recover(); v != nil { 612 if perr, ok := v.(*Error); ok { 613 err = perr 614 } else { 615 panic(v) 616 } 617 } 618 }() 619 pat := parsePattern(newScanner([]byte(p)), true) 620 insts := compilePattern(pat) 621 matches = []*MatchData{} 622 for sp := offset; sp <= len(src); { 623 ok, nsp, ms := recursiveVM(src, insts, 0, sp) 624 sp++ 625 if ok { 626 if sp < nsp { 627 sp = nsp 628 } 629 matches = append(matches, ms) 630 } 631 if len(matches) == limit || pat.MustHead { 632 break 633 } 634 } 635 return 636 } 637 638 /* }}} */