github.com/jslyzt/glua@v0.0.0-20210819023911-4030c8e0234a/pm/pm.go (about) 1 package pm 2 3 import ( 4 "fmt" 5 ) 6 7 // 常量定义 8 const ( 9 EOS = -1 10 UNKNOWN = -2 11 ) 12 13 // Error 错误 14 type Error struct { 15 Pos int 16 Message string 17 } 18 19 func newError(pos int, message string, args ...interface{}) *Error { 20 if len(args) == 0 { 21 return &Error{pos, message} 22 } 23 return &Error{pos, fmt.Sprintf(message, args...)} 24 } 25 26 func (e *Error) Error() string { 27 switch e.Pos { 28 case EOS: 29 return fmt.Sprintf("%s at EOS", e.Message) 30 case UNKNOWN: 31 return e.Message 32 default: 33 return fmt.Sprintf("%s at %d", e.Message, e.Pos) 34 } 35 } 36 37 // MatchData 匹配数据 38 type MatchData struct { 39 // captured positions 40 // layout 41 // xxxx xxxx xxxx xxx0 : caputured positions 42 // xxxx xxxx xxxx xxx1 : position captured positions 43 captures []uint32 44 } 45 46 func newMatchState() *MatchData { return &MatchData{[]uint32{}} } 47 48 func (st *MatchData) addPosCapture(s, pos int) { 49 for s+1 >= len(st.captures) { 50 st.captures = append(st.captures, 0) 51 } 52 st.captures[s] = (uint32(pos) << 1) | 1 53 st.captures[s+1] = (uint32(pos) << 1) | 1 54 } 55 56 func (st *MatchData) setCapture(s, pos int) uint32 { 57 for s >= len(st.captures) { 58 st.captures = append(st.captures, 0) 59 } 60 v := st.captures[s] 61 st.captures[s] = (uint32(pos) << 1) 62 return v 63 } 64 65 func (st *MatchData) restoreCapture(s int, pos uint32) { 66 st.captures[s] = pos 67 } 68 69 // CaptureLength 长度 70 func (st *MatchData) CaptureLength() int { 71 return len(st.captures) 72 } 73 74 // IsPosCapture is pos capture 75 func (st *MatchData) IsPosCapture(idx int) bool { 76 return (st.captures[idx] & 1) == 1 77 } 78 79 // Capture capture 80 func (st *MatchData) Capture(idx int) int { 81 return int(st.captures[idx] >> 1) 82 } 83 84 type scannerState struct { 85 Pos int 86 started bool 87 } 88 89 type scanner struct { 90 src []byte 91 State scannerState 92 saved scannerState 93 } 94 95 func newScanner(src []byte) *scanner { 96 return &scanner{ 97 src: src, 98 State: scannerState{ 99 Pos: 0, 100 started: false, 101 }, 102 saved: scannerState{}, 103 } 104 } 105 106 func (sc *scanner) Length() int { return len(sc.src) } 107 108 func (sc *scanner) Next() int { 109 if !sc.State.started { 110 sc.State.started = true 111 if len(sc.src) == 0 { 112 sc.State.Pos = EOS 113 } 114 } else { 115 sc.State.Pos = sc.NextPos() 116 } 117 if sc.State.Pos == EOS { 118 return EOS 119 } 120 return int(sc.src[sc.State.Pos]) 121 } 122 123 func (sc *scanner) CurrentPos() int { 124 return sc.State.Pos 125 } 126 127 func (sc *scanner) NextPos() int { 128 if sc.State.Pos == EOS || sc.State.Pos >= len(sc.src)-1 { 129 return EOS 130 } 131 if !sc.State.started { 132 return 0 133 } 134 return sc.State.Pos + 1 135 } 136 137 func (sc *scanner) Peek() int { 138 cureof := sc.State.Pos == EOS 139 ch := sc.Next() 140 if !cureof { 141 if sc.State.Pos == EOS { 142 sc.State.Pos = len(sc.src) - 1 143 } else { 144 sc.State.Pos-- 145 if sc.State.Pos < 0 { 146 sc.State.Pos = 0 147 sc.State.started = false 148 } 149 } 150 } 151 return ch 152 } 153 154 func (sc *scanner) Save() { sc.saved = sc.State } 155 156 func (sc *scanner) Restore() { sc.State = sc.saved } 157 158 ////////////////////////////////////////////////////////////////////////////////////////////////// 159 type opCode int 160 161 const ( 162 opChar opCode = iota 163 opMatch 164 opTailMatch 165 opJmp 166 opSplit 167 opSave 168 opPSave 169 opBrace 170 opNumber 171 ) 172 173 type inst struct { 174 OpCode opCode 175 Class class 176 Operand1 int 177 Operand2 int 178 } 179 180 type class interface { 181 Matches(ch int) bool 182 } 183 184 ////////////////////////////////////////////////////////////////////////////////////////////////// 185 type dotClass struct{} 186 187 // Matches 匹配 188 func (pn *dotClass) Matches(ch int) bool { 189 return true 190 } 191 192 ////////////////////////////////////////////////////////////////////////////////////////////////// 193 type charClass struct { 194 Ch int 195 } 196 197 // Matches 匹配 198 func (pn *charClass) Matches(ch int) bool { 199 return pn.Ch == ch 200 } 201 202 ////////////////////////////////////////////////////////////////////////////////////////////////// 203 type singleClass struct { 204 Class int 205 } 206 207 // Matches 匹配 208 func (pn *singleClass) Matches(ch int) bool { 209 ret := false 210 switch pn.Class { 211 case 'a', 'A': 212 ret = 'A' <= ch && ch <= 'Z' || 'a' <= ch && ch <= 'z' 213 case 'c', 'C': 214 ret = (0x00 <= ch && ch <= 0x1F) || ch == 0x7F 215 case 'd', 'D': 216 ret = '0' <= ch && ch <= '9' 217 case 'l', 'L': 218 ret = 'a' <= ch && ch <= 'z' 219 case 'p', 'P': 220 ret = (0x21 <= ch && ch <= 0x2f) || (0x30 <= ch && ch <= 0x40) || (0x5b <= ch && ch <= 0x60) || (0x7b <= ch && ch <= 0x7e) 221 case 's', 'S': 222 switch ch { 223 case ' ', '\f', '\n', '\r', '\t', '\v': 224 ret = true 225 } 226 case 'u', 'U': 227 ret = 'A' <= ch && ch <= 'Z' 228 case 'w', 'W': 229 ret = '0' <= ch && ch <= '9' || 'A' <= ch && ch <= 'Z' || 'a' <= ch && ch <= 'z' 230 case 'x', 'X': 231 ret = '0' <= ch && ch <= '9' || 'a' <= ch && ch <= 'f' || 'A' <= ch && ch <= 'F' 232 case 'z', 'Z': 233 ret = ch == 0 234 default: 235 return ch == pn.Class 236 } 237 if 'A' <= pn.Class && pn.Class <= 'Z' { 238 return !ret 239 } 240 return ret 241 } 242 243 ////////////////////////////////////////////////////////////////////////////////////////////////// 244 type setClass struct { 245 IsNot bool 246 Classes []class 247 } 248 249 // Matches 匹配 250 func (pn *setClass) Matches(ch int) bool { 251 for _, class := range pn.Classes { 252 if class.Matches(ch) { 253 return !pn.IsNot 254 } 255 } 256 return pn.IsNot 257 } 258 259 ////////////////////////////////////////////////////////////////////////////////////////////////// 260 type rangeClass struct { 261 Begin class 262 End class 263 } 264 265 // Matches 匹配 266 func (pn *rangeClass) Matches(ch int) bool { 267 switch begin := pn.Begin.(type) { 268 case *charClass: 269 end, ok := pn.End.(*charClass) 270 if !ok { 271 return false 272 } 273 return begin.Ch <= ch && ch <= end.Ch 274 } 275 return false 276 } 277 278 ////////////////////////////////////////////////////////////////////////////////////////////////// 279 type pattern interface{} 280 281 type singlePattern struct { 282 Class class 283 } 284 285 type seqPattern struct { 286 MustHead bool 287 MustTail bool 288 Patterns []pattern 289 } 290 291 type repeatPattern struct { 292 Type int 293 Class class 294 } 295 296 type posCapPattern struct{} 297 298 type capPattern struct { 299 Pattern pattern 300 } 301 302 type numberPattern struct { 303 N int 304 } 305 306 type bracePattern struct { 307 Begin int 308 End int 309 } 310 311 func parseClass(sc *scanner, allowset bool) class { 312 ch := sc.Next() 313 switch ch { 314 case '%': 315 return &singleClass{sc.Next()} 316 case '.': 317 if allowset { 318 return &dotClass{} 319 } 320 return &charClass{ch} 321 case '[': 322 if allowset { 323 return parseClassSet(sc) 324 } 325 return &charClass{ch} 326 //case '^' '$', '(', ')', ']', '*', '+', '-', '?': 327 // panic(newError(sc.CurrentPos(), "invalid %c", ch)) 328 case EOS: 329 panic(newError(sc.CurrentPos(), "unexpected EOS")) 330 default: 331 return &charClass{ch} 332 } 333 } 334 335 func parseClassSet(sc *scanner) class { 336 set := &setClass{false, []class{}} 337 if sc.Peek() == '^' { 338 set.IsNot = true 339 sc.Next() 340 } 341 isrange := false 342 for { 343 ch := sc.Peek() 344 switch ch { 345 // case '[': 346 // panic(newError(sc.CurrentPos(), "'[' can not be nested")) 347 case EOS: 348 panic(newError(sc.CurrentPos(), "unexpected EOS")) 349 case ']': 350 if len(set.Classes) > 0 { 351 sc.Next() 352 goto exit 353 } 354 fallthrough 355 case '-': 356 if len(set.Classes) > 0 { 357 sc.Next() 358 isrange = true 359 continue 360 } 361 fallthrough 362 default: 363 set.Classes = append(set.Classes, parseClass(sc, false)) 364 } 365 if isrange { 366 begin := set.Classes[len(set.Classes)-2] 367 end := set.Classes[len(set.Classes)-1] 368 set.Classes = set.Classes[0 : len(set.Classes)-2] 369 set.Classes = append(set.Classes, &rangeClass{begin, end}) 370 isrange = false 371 } 372 } 373 exit: 374 if isrange { 375 set.Classes = append(set.Classes, &charClass{'-'}) 376 } 377 378 return set 379 } 380 381 func parsePattern(sc *scanner, toplevel bool) *seqPattern { 382 pat := &seqPattern{} 383 if toplevel { 384 if sc.Peek() == '^' { 385 sc.Next() 386 pat.MustHead = true 387 } 388 } 389 for { 390 ch := sc.Peek() 391 switch ch { 392 case '%': 393 sc.Save() 394 sc.Next() 395 switch sc.Peek() { 396 case '0': 397 panic(newError(sc.CurrentPos(), "invalid capture index")) 398 case '1', '2', '3', '4', '5', '6', '7', '8', '9': 399 pat.Patterns = append(pat.Patterns, &numberPattern{sc.Next() - 48}) 400 case 'b': 401 sc.Next() 402 pat.Patterns = append(pat.Patterns, &bracePattern{sc.Next(), sc.Next()}) 403 default: 404 sc.Restore() 405 pat.Patterns = append(pat.Patterns, &singlePattern{parseClass(sc, true)}) 406 } 407 case '.', '[', ']': 408 pat.Patterns = append(pat.Patterns, &singlePattern{parseClass(sc, true)}) 409 //case ']': 410 // panic(newError(sc.CurrentPos(), "invalid ']'")) 411 case ')': 412 if toplevel { 413 panic(newError(sc.CurrentPos(), "invalid ')'")) 414 } 415 return pat 416 case '(': 417 sc.Next() 418 if sc.Peek() == ')' { 419 sc.Next() 420 pat.Patterns = append(pat.Patterns, &posCapPattern{}) 421 } else { 422 ret := &capPattern{parsePattern(sc, false)} 423 if sc.Peek() != ')' { 424 panic(newError(sc.CurrentPos(), "unfinished capture")) 425 } 426 sc.Next() 427 pat.Patterns = append(pat.Patterns, ret) 428 } 429 case '*', '+', '-', '?': 430 sc.Next() 431 if len(pat.Patterns) > 0 { 432 spat, ok := pat.Patterns[len(pat.Patterns)-1].(*singlePattern) 433 if ok { 434 pat.Patterns = pat.Patterns[0 : len(pat.Patterns)-1] 435 pat.Patterns = append(pat.Patterns, &repeatPattern{ch, spat.Class}) 436 continue 437 } 438 } 439 pat.Patterns = append(pat.Patterns, &singlePattern{&charClass{ch}}) 440 case '$': 441 if toplevel && (sc.NextPos() == sc.Length()-1 || sc.NextPos() == EOS) { 442 pat.MustTail = true 443 } else { 444 pat.Patterns = append(pat.Patterns, &singlePattern{&charClass{ch}}) 445 } 446 sc.Next() 447 case EOS: 448 sc.Next() 449 goto exit 450 default: 451 sc.Next() 452 pat.Patterns = append(pat.Patterns, &singlePattern{&charClass{ch}}) 453 } 454 } 455 exit: 456 return pat 457 } 458 459 type iptr struct { 460 insts []inst 461 capture int 462 } 463 464 func compilePattern(p pattern, ps ...*iptr) []inst { 465 var ptr *iptr 466 toplevel := false 467 if len(ps) == 0 { 468 toplevel = true 469 ptr = &iptr{[]inst{{opSave, nil, 0, -1}}, 2} 470 } else { 471 ptr = ps[0] 472 } 473 switch pat := p.(type) { 474 case *singlePattern: 475 ptr.insts = append(ptr.insts, inst{opChar, pat.Class, -1, -1}) 476 case *seqPattern: 477 for _, cp := range pat.Patterns { 478 compilePattern(cp, ptr) 479 } 480 case *repeatPattern: 481 idx := len(ptr.insts) 482 switch pat.Type { 483 case '*': 484 ptr.insts = append(ptr.insts, 485 inst{opSplit, nil, idx + 1, idx + 3}, 486 inst{opChar, pat.Class, -1, -1}, 487 inst{opJmp, nil, idx, -1}) 488 case '+': 489 ptr.insts = append(ptr.insts, 490 inst{opChar, pat.Class, -1, -1}, 491 inst{opSplit, nil, idx, idx + 2}) 492 case '-': 493 ptr.insts = append(ptr.insts, 494 inst{opSplit, nil, idx + 3, idx + 1}, 495 inst{opChar, pat.Class, -1, -1}, 496 inst{opJmp, nil, idx, -1}) 497 case '?': 498 ptr.insts = append(ptr.insts, 499 inst{opSplit, nil, idx + 1, idx + 2}, 500 inst{opChar, pat.Class, -1, -1}) 501 } 502 case *posCapPattern: 503 ptr.insts = append(ptr.insts, inst{opPSave, nil, ptr.capture, -1}) 504 ptr.capture += 2 505 case *capPattern: 506 c0, c1 := ptr.capture, ptr.capture+1 507 ptr.capture += 2 508 ptr.insts = append(ptr.insts, inst{opSave, nil, c0, -1}) 509 compilePattern(pat.Pattern, ptr) 510 ptr.insts = append(ptr.insts, inst{opSave, nil, c1, -1}) 511 case *bracePattern: 512 ptr.insts = append(ptr.insts, inst{opBrace, nil, pat.Begin, pat.End}) 513 case *numberPattern: 514 ptr.insts = append(ptr.insts, inst{opNumber, nil, pat.N, -1}) 515 } 516 if toplevel { 517 if p.(*seqPattern).MustTail { 518 ptr.insts = append(ptr.insts, inst{opSave, nil, 1, -1}, inst{opTailMatch, nil, -1, -1}) 519 } 520 ptr.insts = append(ptr.insts, inst{opSave, nil, 1, -1}, inst{opMatch, nil, -1, -1}) 521 } 522 return ptr.insts 523 } 524 525 // Simple recursive virtual machine based on the 526 // "Regular Expression Matching: the Virtual Machine Approach" (https://swtch.com/~rsc/regexp/regexp2.html) 527 func recursiveVM(src []byte, insts []inst, pc, sp int, ms ...*MatchData) (bool, int, *MatchData) { 528 var m *MatchData 529 if len(ms) == 0 { 530 m = newMatchState() 531 } else { 532 m = ms[0] 533 } 534 redo: 535 inst := insts[pc] 536 switch inst.OpCode { 537 case opChar: 538 if sp >= len(src) || !inst.Class.Matches(int(src[sp])) { 539 return false, sp, m 540 } 541 pc++ 542 sp++ 543 goto redo 544 case opMatch: 545 return true, sp, m 546 case opTailMatch: 547 return sp >= len(src), sp, m 548 case opJmp: 549 pc = inst.Operand1 550 goto redo 551 case opSplit: 552 if ok, nsp, _ := recursiveVM(src, insts, inst.Operand1, sp, m); ok { 553 return true, nsp, m 554 } 555 pc = inst.Operand2 556 goto redo 557 case opSave: 558 s := m.setCapture(inst.Operand1, sp) 559 if ok, nsp, _ := recursiveVM(src, insts, pc+1, sp, m); ok { 560 return true, nsp, m 561 } 562 m.restoreCapture(inst.Operand1, s) 563 return false, sp, m 564 case opPSave: 565 m.addPosCapture(inst.Operand1, sp+1) 566 pc++ 567 goto redo 568 case opBrace: 569 if sp >= len(src) || int(src[sp]) != inst.Operand1 { 570 return false, sp, m 571 } 572 count := 1 573 for sp = sp + 1; sp < len(src); sp++ { 574 if int(src[sp]) == inst.Operand2 { 575 count-- 576 } 577 if count == 0 { 578 pc++ 579 sp++ 580 goto redo 581 } 582 if int(src[sp]) == inst.Operand1 { 583 count++ 584 } 585 } 586 return false, sp, m 587 case opNumber: 588 idx := inst.Operand1 * 2 589 if idx >= m.CaptureLength()-1 { 590 panic(newError(UNKNOWN, "invalid capture index")) 591 } 592 capture := src[m.Capture(idx):m.Capture(idx+1)] 593 for i := 0; i < len(capture); i++ { 594 if i+sp >= len(src) || capture[i] != src[i+sp] { 595 return false, sp, m 596 } 597 } 598 pc++ 599 sp += len(capture) 600 goto redo 601 } 602 panic("should not reach here") 603 } 604 605 // Find 查询 606 func Find(p string, src []byte, offset, limit int) (matches []*MatchData, err error) { 607 defer func() { 608 if v := recover(); v != nil { 609 if perr, ok := v.(*Error); ok { 610 err = perr 611 } else { 612 panic(v) 613 } 614 } 615 }() 616 pat := parsePattern(newScanner([]byte(p)), true) 617 insts := compilePattern(pat) 618 matches = []*MatchData{} 619 for sp := offset; sp <= len(src); { 620 ok, nsp, ms := recursiveVM(src, insts, 0, sp) 621 sp++ 622 if ok { 623 if sp < nsp { 624 sp = nsp 625 } 626 matches = append(matches, ms) 627 } 628 if len(matches) == limit || pat.MustHead { 629 break 630 } 631 } 632 return 633 }