github.com/jslyzt/glua@v0.0.0-20210819023911-4030c8e0234a/pm/pm.go (about)

     1  package pm
     2  
     3  import (
     4  	"fmt"
     5  )
     6  
     7  // 常量定义
     8  const (
     9  	EOS     = -1
    10  	UNKNOWN = -2
    11  )
    12  
    13  // Error 错误
    14  type Error struct {
    15  	Pos     int
    16  	Message string
    17  }
    18  
    19  func newError(pos int, message string, args ...interface{}) *Error {
    20  	if len(args) == 0 {
    21  		return &Error{pos, message}
    22  	}
    23  	return &Error{pos, fmt.Sprintf(message, args...)}
    24  }
    25  
    26  func (e *Error) Error() string {
    27  	switch e.Pos {
    28  	case EOS:
    29  		return fmt.Sprintf("%s at EOS", e.Message)
    30  	case UNKNOWN:
    31  		return e.Message
    32  	default:
    33  		return fmt.Sprintf("%s at %d", e.Message, e.Pos)
    34  	}
    35  }
    36  
    37  // MatchData 匹配数据
    38  type MatchData struct {
    39  	// captured positions
    40  	// layout
    41  	// xxxx xxxx xxxx xxx0 : caputured positions
    42  	// xxxx xxxx xxxx xxx1 : position captured positions
    43  	captures []uint32
    44  }
    45  
    46  func newMatchState() *MatchData { return &MatchData{[]uint32{}} }
    47  
    48  func (st *MatchData) addPosCapture(s, pos int) {
    49  	for s+1 >= len(st.captures) {
    50  		st.captures = append(st.captures, 0)
    51  	}
    52  	st.captures[s] = (uint32(pos) << 1) | 1
    53  	st.captures[s+1] = (uint32(pos) << 1) | 1
    54  }
    55  
    56  func (st *MatchData) setCapture(s, pos int) uint32 {
    57  	for s >= len(st.captures) {
    58  		st.captures = append(st.captures, 0)
    59  	}
    60  	v := st.captures[s]
    61  	st.captures[s] = (uint32(pos) << 1)
    62  	return v
    63  }
    64  
    65  func (st *MatchData) restoreCapture(s int, pos uint32) {
    66  	st.captures[s] = pos
    67  }
    68  
    69  // CaptureLength 长度
    70  func (st *MatchData) CaptureLength() int {
    71  	return len(st.captures)
    72  }
    73  
    74  // IsPosCapture is pos capture
    75  func (st *MatchData) IsPosCapture(idx int) bool {
    76  	return (st.captures[idx] & 1) == 1
    77  }
    78  
    79  // Capture capture
    80  func (st *MatchData) Capture(idx int) int {
    81  	return int(st.captures[idx] >> 1)
    82  }
    83  
    84  type scannerState struct {
    85  	Pos     int
    86  	started bool
    87  }
    88  
    89  type scanner struct {
    90  	src   []byte
    91  	State scannerState
    92  	saved scannerState
    93  }
    94  
    95  func newScanner(src []byte) *scanner {
    96  	return &scanner{
    97  		src: src,
    98  		State: scannerState{
    99  			Pos:     0,
   100  			started: false,
   101  		},
   102  		saved: scannerState{},
   103  	}
   104  }
   105  
   106  func (sc *scanner) Length() int { return len(sc.src) }
   107  
   108  func (sc *scanner) Next() int {
   109  	if !sc.State.started {
   110  		sc.State.started = true
   111  		if len(sc.src) == 0 {
   112  			sc.State.Pos = EOS
   113  		}
   114  	} else {
   115  		sc.State.Pos = sc.NextPos()
   116  	}
   117  	if sc.State.Pos == EOS {
   118  		return EOS
   119  	}
   120  	return int(sc.src[sc.State.Pos])
   121  }
   122  
   123  func (sc *scanner) CurrentPos() int {
   124  	return sc.State.Pos
   125  }
   126  
   127  func (sc *scanner) NextPos() int {
   128  	if sc.State.Pos == EOS || sc.State.Pos >= len(sc.src)-1 {
   129  		return EOS
   130  	}
   131  	if !sc.State.started {
   132  		return 0
   133  	}
   134  	return sc.State.Pos + 1
   135  }
   136  
   137  func (sc *scanner) Peek() int {
   138  	cureof := sc.State.Pos == EOS
   139  	ch := sc.Next()
   140  	if !cureof {
   141  		if sc.State.Pos == EOS {
   142  			sc.State.Pos = len(sc.src) - 1
   143  		} else {
   144  			sc.State.Pos--
   145  			if sc.State.Pos < 0 {
   146  				sc.State.Pos = 0
   147  				sc.State.started = false
   148  			}
   149  		}
   150  	}
   151  	return ch
   152  }
   153  
   154  func (sc *scanner) Save() { sc.saved = sc.State }
   155  
   156  func (sc *scanner) Restore() { sc.State = sc.saved }
   157  
   158  //////////////////////////////////////////////////////////////////////////////////////////////////
   159  type opCode int
   160  
   161  const (
   162  	opChar opCode = iota
   163  	opMatch
   164  	opTailMatch
   165  	opJmp
   166  	opSplit
   167  	opSave
   168  	opPSave
   169  	opBrace
   170  	opNumber
   171  )
   172  
   173  type inst struct {
   174  	OpCode   opCode
   175  	Class    class
   176  	Operand1 int
   177  	Operand2 int
   178  }
   179  
   180  type class interface {
   181  	Matches(ch int) bool
   182  }
   183  
   184  //////////////////////////////////////////////////////////////////////////////////////////////////
   185  type dotClass struct{}
   186  
   187  // Matches 匹配
   188  func (pn *dotClass) Matches(ch int) bool {
   189  	return true
   190  }
   191  
   192  //////////////////////////////////////////////////////////////////////////////////////////////////
   193  type charClass struct {
   194  	Ch int
   195  }
   196  
   197  // Matches 匹配
   198  func (pn *charClass) Matches(ch int) bool {
   199  	return pn.Ch == ch
   200  }
   201  
   202  //////////////////////////////////////////////////////////////////////////////////////////////////
   203  type singleClass struct {
   204  	Class int
   205  }
   206  
   207  // Matches 匹配
   208  func (pn *singleClass) Matches(ch int) bool {
   209  	ret := false
   210  	switch pn.Class {
   211  	case 'a', 'A':
   212  		ret = 'A' <= ch && ch <= 'Z' || 'a' <= ch && ch <= 'z'
   213  	case 'c', 'C':
   214  		ret = (0x00 <= ch && ch <= 0x1F) || ch == 0x7F
   215  	case 'd', 'D':
   216  		ret = '0' <= ch && ch <= '9'
   217  	case 'l', 'L':
   218  		ret = 'a' <= ch && ch <= 'z'
   219  	case 'p', 'P':
   220  		ret = (0x21 <= ch && ch <= 0x2f) || (0x30 <= ch && ch <= 0x40) || (0x5b <= ch && ch <= 0x60) || (0x7b <= ch && ch <= 0x7e)
   221  	case 's', 'S':
   222  		switch ch {
   223  		case ' ', '\f', '\n', '\r', '\t', '\v':
   224  			ret = true
   225  		}
   226  	case 'u', 'U':
   227  		ret = 'A' <= ch && ch <= 'Z'
   228  	case 'w', 'W':
   229  		ret = '0' <= ch && ch <= '9' || 'A' <= ch && ch <= 'Z' || 'a' <= ch && ch <= 'z'
   230  	case 'x', 'X':
   231  		ret = '0' <= ch && ch <= '9' || 'a' <= ch && ch <= 'f' || 'A' <= ch && ch <= 'F'
   232  	case 'z', 'Z':
   233  		ret = ch == 0
   234  	default:
   235  		return ch == pn.Class
   236  	}
   237  	if 'A' <= pn.Class && pn.Class <= 'Z' {
   238  		return !ret
   239  	}
   240  	return ret
   241  }
   242  
   243  //////////////////////////////////////////////////////////////////////////////////////////////////
   244  type setClass struct {
   245  	IsNot   bool
   246  	Classes []class
   247  }
   248  
   249  // Matches 匹配
   250  func (pn *setClass) Matches(ch int) bool {
   251  	for _, class := range pn.Classes {
   252  		if class.Matches(ch) {
   253  			return !pn.IsNot
   254  		}
   255  	}
   256  	return pn.IsNot
   257  }
   258  
   259  //////////////////////////////////////////////////////////////////////////////////////////////////
   260  type rangeClass struct {
   261  	Begin class
   262  	End   class
   263  }
   264  
   265  // Matches 匹配
   266  func (pn *rangeClass) Matches(ch int) bool {
   267  	switch begin := pn.Begin.(type) {
   268  	case *charClass:
   269  		end, ok := pn.End.(*charClass)
   270  		if !ok {
   271  			return false
   272  		}
   273  		return begin.Ch <= ch && ch <= end.Ch
   274  	}
   275  	return false
   276  }
   277  
   278  //////////////////////////////////////////////////////////////////////////////////////////////////
   279  type pattern interface{}
   280  
   281  type singlePattern struct {
   282  	Class class
   283  }
   284  
   285  type seqPattern struct {
   286  	MustHead bool
   287  	MustTail bool
   288  	Patterns []pattern
   289  }
   290  
   291  type repeatPattern struct {
   292  	Type  int
   293  	Class class
   294  }
   295  
   296  type posCapPattern struct{}
   297  
   298  type capPattern struct {
   299  	Pattern pattern
   300  }
   301  
   302  type numberPattern struct {
   303  	N int
   304  }
   305  
   306  type bracePattern struct {
   307  	Begin int
   308  	End   int
   309  }
   310  
   311  func parseClass(sc *scanner, allowset bool) class {
   312  	ch := sc.Next()
   313  	switch ch {
   314  	case '%':
   315  		return &singleClass{sc.Next()}
   316  	case '.':
   317  		if allowset {
   318  			return &dotClass{}
   319  		}
   320  		return &charClass{ch}
   321  	case '[':
   322  		if allowset {
   323  			return parseClassSet(sc)
   324  		}
   325  		return &charClass{ch}
   326  	//case '^' '$', '(', ')', ']', '*', '+', '-', '?':
   327  	//	panic(newError(sc.CurrentPos(), "invalid %c", ch))
   328  	case EOS:
   329  		panic(newError(sc.CurrentPos(), "unexpected EOS"))
   330  	default:
   331  		return &charClass{ch}
   332  	}
   333  }
   334  
   335  func parseClassSet(sc *scanner) class {
   336  	set := &setClass{false, []class{}}
   337  	if sc.Peek() == '^' {
   338  		set.IsNot = true
   339  		sc.Next()
   340  	}
   341  	isrange := false
   342  	for {
   343  		ch := sc.Peek()
   344  		switch ch {
   345  		// case '[':
   346  		// 	panic(newError(sc.CurrentPos(), "'[' can not be nested"))
   347  		case EOS:
   348  			panic(newError(sc.CurrentPos(), "unexpected EOS"))
   349  		case ']':
   350  			if len(set.Classes) > 0 {
   351  				sc.Next()
   352  				goto exit
   353  			}
   354  			fallthrough
   355  		case '-':
   356  			if len(set.Classes) > 0 {
   357  				sc.Next()
   358  				isrange = true
   359  				continue
   360  			}
   361  			fallthrough
   362  		default:
   363  			set.Classes = append(set.Classes, parseClass(sc, false))
   364  		}
   365  		if isrange {
   366  			begin := set.Classes[len(set.Classes)-2]
   367  			end := set.Classes[len(set.Classes)-1]
   368  			set.Classes = set.Classes[0 : len(set.Classes)-2]
   369  			set.Classes = append(set.Classes, &rangeClass{begin, end})
   370  			isrange = false
   371  		}
   372  	}
   373  exit:
   374  	if isrange {
   375  		set.Classes = append(set.Classes, &charClass{'-'})
   376  	}
   377  
   378  	return set
   379  }
   380  
   381  func parsePattern(sc *scanner, toplevel bool) *seqPattern {
   382  	pat := &seqPattern{}
   383  	if toplevel {
   384  		if sc.Peek() == '^' {
   385  			sc.Next()
   386  			pat.MustHead = true
   387  		}
   388  	}
   389  	for {
   390  		ch := sc.Peek()
   391  		switch ch {
   392  		case '%':
   393  			sc.Save()
   394  			sc.Next()
   395  			switch sc.Peek() {
   396  			case '0':
   397  				panic(newError(sc.CurrentPos(), "invalid capture index"))
   398  			case '1', '2', '3', '4', '5', '6', '7', '8', '9':
   399  				pat.Patterns = append(pat.Patterns, &numberPattern{sc.Next() - 48})
   400  			case 'b':
   401  				sc.Next()
   402  				pat.Patterns = append(pat.Patterns, &bracePattern{sc.Next(), sc.Next()})
   403  			default:
   404  				sc.Restore()
   405  				pat.Patterns = append(pat.Patterns, &singlePattern{parseClass(sc, true)})
   406  			}
   407  		case '.', '[', ']':
   408  			pat.Patterns = append(pat.Patterns, &singlePattern{parseClass(sc, true)})
   409  		//case ']':
   410  		//	panic(newError(sc.CurrentPos(), "invalid ']'"))
   411  		case ')':
   412  			if toplevel {
   413  				panic(newError(sc.CurrentPos(), "invalid ')'"))
   414  			}
   415  			return pat
   416  		case '(':
   417  			sc.Next()
   418  			if sc.Peek() == ')' {
   419  				sc.Next()
   420  				pat.Patterns = append(pat.Patterns, &posCapPattern{})
   421  			} else {
   422  				ret := &capPattern{parsePattern(sc, false)}
   423  				if sc.Peek() != ')' {
   424  					panic(newError(sc.CurrentPos(), "unfinished capture"))
   425  				}
   426  				sc.Next()
   427  				pat.Patterns = append(pat.Patterns, ret)
   428  			}
   429  		case '*', '+', '-', '?':
   430  			sc.Next()
   431  			if len(pat.Patterns) > 0 {
   432  				spat, ok := pat.Patterns[len(pat.Patterns)-1].(*singlePattern)
   433  				if ok {
   434  					pat.Patterns = pat.Patterns[0 : len(pat.Patterns)-1]
   435  					pat.Patterns = append(pat.Patterns, &repeatPattern{ch, spat.Class})
   436  					continue
   437  				}
   438  			}
   439  			pat.Patterns = append(pat.Patterns, &singlePattern{&charClass{ch}})
   440  		case '$':
   441  			if toplevel && (sc.NextPos() == sc.Length()-1 || sc.NextPos() == EOS) {
   442  				pat.MustTail = true
   443  			} else {
   444  				pat.Patterns = append(pat.Patterns, &singlePattern{&charClass{ch}})
   445  			}
   446  			sc.Next()
   447  		case EOS:
   448  			sc.Next()
   449  			goto exit
   450  		default:
   451  			sc.Next()
   452  			pat.Patterns = append(pat.Patterns, &singlePattern{&charClass{ch}})
   453  		}
   454  	}
   455  exit:
   456  	return pat
   457  }
   458  
   459  type iptr struct {
   460  	insts   []inst
   461  	capture int
   462  }
   463  
   464  func compilePattern(p pattern, ps ...*iptr) []inst {
   465  	var ptr *iptr
   466  	toplevel := false
   467  	if len(ps) == 0 {
   468  		toplevel = true
   469  		ptr = &iptr{[]inst{{opSave, nil, 0, -1}}, 2}
   470  	} else {
   471  		ptr = ps[0]
   472  	}
   473  	switch pat := p.(type) {
   474  	case *singlePattern:
   475  		ptr.insts = append(ptr.insts, inst{opChar, pat.Class, -1, -1})
   476  	case *seqPattern:
   477  		for _, cp := range pat.Patterns {
   478  			compilePattern(cp, ptr)
   479  		}
   480  	case *repeatPattern:
   481  		idx := len(ptr.insts)
   482  		switch pat.Type {
   483  		case '*':
   484  			ptr.insts = append(ptr.insts,
   485  				inst{opSplit, nil, idx + 1, idx + 3},
   486  				inst{opChar, pat.Class, -1, -1},
   487  				inst{opJmp, nil, idx, -1})
   488  		case '+':
   489  			ptr.insts = append(ptr.insts,
   490  				inst{opChar, pat.Class, -1, -1},
   491  				inst{opSplit, nil, idx, idx + 2})
   492  		case '-':
   493  			ptr.insts = append(ptr.insts,
   494  				inst{opSplit, nil, idx + 3, idx + 1},
   495  				inst{opChar, pat.Class, -1, -1},
   496  				inst{opJmp, nil, idx, -1})
   497  		case '?':
   498  			ptr.insts = append(ptr.insts,
   499  				inst{opSplit, nil, idx + 1, idx + 2},
   500  				inst{opChar, pat.Class, -1, -1})
   501  		}
   502  	case *posCapPattern:
   503  		ptr.insts = append(ptr.insts, inst{opPSave, nil, ptr.capture, -1})
   504  		ptr.capture += 2
   505  	case *capPattern:
   506  		c0, c1 := ptr.capture, ptr.capture+1
   507  		ptr.capture += 2
   508  		ptr.insts = append(ptr.insts, inst{opSave, nil, c0, -1})
   509  		compilePattern(pat.Pattern, ptr)
   510  		ptr.insts = append(ptr.insts, inst{opSave, nil, c1, -1})
   511  	case *bracePattern:
   512  		ptr.insts = append(ptr.insts, inst{opBrace, nil, pat.Begin, pat.End})
   513  	case *numberPattern:
   514  		ptr.insts = append(ptr.insts, inst{opNumber, nil, pat.N, -1})
   515  	}
   516  	if toplevel {
   517  		if p.(*seqPattern).MustTail {
   518  			ptr.insts = append(ptr.insts, inst{opSave, nil, 1, -1}, inst{opTailMatch, nil, -1, -1})
   519  		}
   520  		ptr.insts = append(ptr.insts, inst{opSave, nil, 1, -1}, inst{opMatch, nil, -1, -1})
   521  	}
   522  	return ptr.insts
   523  }
   524  
   525  // Simple recursive virtual machine based on the
   526  // "Regular Expression Matching: the Virtual Machine Approach" (https://swtch.com/~rsc/regexp/regexp2.html)
   527  func recursiveVM(src []byte, insts []inst, pc, sp int, ms ...*MatchData) (bool, int, *MatchData) {
   528  	var m *MatchData
   529  	if len(ms) == 0 {
   530  		m = newMatchState()
   531  	} else {
   532  		m = ms[0]
   533  	}
   534  redo:
   535  	inst := insts[pc]
   536  	switch inst.OpCode {
   537  	case opChar:
   538  		if sp >= len(src) || !inst.Class.Matches(int(src[sp])) {
   539  			return false, sp, m
   540  		}
   541  		pc++
   542  		sp++
   543  		goto redo
   544  	case opMatch:
   545  		return true, sp, m
   546  	case opTailMatch:
   547  		return sp >= len(src), sp, m
   548  	case opJmp:
   549  		pc = inst.Operand1
   550  		goto redo
   551  	case opSplit:
   552  		if ok, nsp, _ := recursiveVM(src, insts, inst.Operand1, sp, m); ok {
   553  			return true, nsp, m
   554  		}
   555  		pc = inst.Operand2
   556  		goto redo
   557  	case opSave:
   558  		s := m.setCapture(inst.Operand1, sp)
   559  		if ok, nsp, _ := recursiveVM(src, insts, pc+1, sp, m); ok {
   560  			return true, nsp, m
   561  		}
   562  		m.restoreCapture(inst.Operand1, s)
   563  		return false, sp, m
   564  	case opPSave:
   565  		m.addPosCapture(inst.Operand1, sp+1)
   566  		pc++
   567  		goto redo
   568  	case opBrace:
   569  		if sp >= len(src) || int(src[sp]) != inst.Operand1 {
   570  			return false, sp, m
   571  		}
   572  		count := 1
   573  		for sp = sp + 1; sp < len(src); sp++ {
   574  			if int(src[sp]) == inst.Operand2 {
   575  				count--
   576  			}
   577  			if count == 0 {
   578  				pc++
   579  				sp++
   580  				goto redo
   581  			}
   582  			if int(src[sp]) == inst.Operand1 {
   583  				count++
   584  			}
   585  		}
   586  		return false, sp, m
   587  	case opNumber:
   588  		idx := inst.Operand1 * 2
   589  		if idx >= m.CaptureLength()-1 {
   590  			panic(newError(UNKNOWN, "invalid capture index"))
   591  		}
   592  		capture := src[m.Capture(idx):m.Capture(idx+1)]
   593  		for i := 0; i < len(capture); i++ {
   594  			if i+sp >= len(src) || capture[i] != src[i+sp] {
   595  				return false, sp, m
   596  			}
   597  		}
   598  		pc++
   599  		sp += len(capture)
   600  		goto redo
   601  	}
   602  	panic("should not reach here")
   603  }
   604  
   605  // Find 查询
   606  func Find(p string, src []byte, offset, limit int) (matches []*MatchData, err error) {
   607  	defer func() {
   608  		if v := recover(); v != nil {
   609  			if perr, ok := v.(*Error); ok {
   610  				err = perr
   611  			} else {
   612  				panic(v)
   613  			}
   614  		}
   615  	}()
   616  	pat := parsePattern(newScanner([]byte(p)), true)
   617  	insts := compilePattern(pat)
   618  	matches = []*MatchData{}
   619  	for sp := offset; sp <= len(src); {
   620  		ok, nsp, ms := recursiveVM(src, insts, 0, sp)
   621  		sp++
   622  		if ok {
   623  			if sp < nsp {
   624  				sp = nsp
   625  			}
   626  			matches = append(matches, ms)
   627  		}
   628  		if len(matches) == limit || pat.MustHead {
   629  			break
   630  		}
   631  	}
   632  	return
   633  }