github.com/dfcfw/lua@v0.0.0-20230325031207-0cc7ffb7b8b9/pm/pm.go (about)

     1  // Lua pattern match functions for Go
     2  package pm
     3  
     4  import (
     5  	"fmt"
     6  )
     7  
// Sentinel values used where a real byte value or position would otherwise go.
const (
	EOS      = -1 // end of input: returned by scanner.Next, Peek and NextPos, and special-cased by Error.Error
	_UNKNOWN = -2 // Error.Pos value for which Error.Error prints no position suffix
)
    12  
    13  /* Error {{{ */
    14  
    15  type Error struct {
    16  	Pos     int
    17  	Message string
    18  }
    19  
    20  func newError(pos int, message string, args ...interface{}) *Error {
    21  	if len(args) == 0 {
    22  		return &Error{pos, message}
    23  	}
    24  	return &Error{pos, fmt.Sprintf(message, args...)}
    25  }
    26  
    27  func (e *Error) Error() string {
    28  	switch e.Pos {
    29  	case EOS:
    30  		return fmt.Sprintf("%s at EOS", e.Message)
    31  	case _UNKNOWN:
    32  		return fmt.Sprintf("%s", e.Message)
    33  	default:
    34  		return fmt.Sprintf("%s at %d", e.Message, e.Pos)
    35  	}
    36  }
    37  
    38  /* }}} */
    39  
    40  /* MatchData {{{ */
    41  
    42  type MatchData struct {
    43  	// captured positions
    44  	// layout
    45  	// xxxx xxxx xxxx xxx0 : caputured positions
    46  	// xxxx xxxx xxxx xxx1 : position captured positions
    47  	captures []uint32
    48  }
    49  
    50  func newMatchState() *MatchData { return &MatchData{[]uint32{}} }
    51  
    52  func (st *MatchData) addPosCapture(s, pos int) {
    53  	for s+1 >= len(st.captures) {
    54  		st.captures = append(st.captures, 0)
    55  	}
    56  	st.captures[s] = (uint32(pos) << 1) | 1
    57  	st.captures[s+1] = (uint32(pos) << 1) | 1
    58  }
    59  
    60  func (st *MatchData) setCapture(s, pos int) uint32 {
    61  	for s >= len(st.captures) {
    62  		st.captures = append(st.captures, 0)
    63  	}
    64  	v := st.captures[s]
    65  	st.captures[s] = (uint32(pos) << 1)
    66  	return v
    67  }
    68  
    69  func (st *MatchData) restoreCapture(s int, pos uint32) { st.captures[s] = pos }
    70  
    71  func (st *MatchData) CaptureLength() int { return len(st.captures) }
    72  
    73  func (st *MatchData) IsPosCapture(idx int) bool { return (st.captures[idx] & 1) == 1 }
    74  
    75  func (st *MatchData) Capture(idx int) int { return int(st.captures[idx] >> 1) }
    76  
    77  /* }}} */
    78  
    79  /* scanner {{{ */
    80  
    81  type scannerState struct {
    82  	Pos     int
    83  	started bool
    84  }
    85  
    86  type scanner struct {
    87  	src   []byte
    88  	State scannerState
    89  	saved scannerState
    90  }
    91  
    92  func newScanner(src []byte) *scanner {
    93  	return &scanner{
    94  		src: src,
    95  		State: scannerState{
    96  			Pos:     0,
    97  			started: false,
    98  		},
    99  		saved: scannerState{},
   100  	}
   101  }
   102  
   103  func (sc *scanner) Length() int { return len(sc.src) }
   104  
   105  func (sc *scanner) Next() int {
   106  	if !sc.State.started {
   107  		sc.State.started = true
   108  		if len(sc.src) == 0 {
   109  			sc.State.Pos = EOS
   110  		}
   111  	} else {
   112  		sc.State.Pos = sc.NextPos()
   113  	}
   114  	if sc.State.Pos == EOS {
   115  		return EOS
   116  	}
   117  	return int(sc.src[sc.State.Pos])
   118  }
   119  
   120  func (sc *scanner) CurrentPos() int {
   121  	return sc.State.Pos
   122  }
   123  
   124  func (sc *scanner) NextPos() int {
   125  	if sc.State.Pos == EOS || sc.State.Pos >= len(sc.src)-1 {
   126  		return EOS
   127  	}
   128  	if !sc.State.started {
   129  		return 0
   130  	} else {
   131  		return sc.State.Pos + 1
   132  	}
   133  }
   134  
   135  func (sc *scanner) Peek() int {
   136  	cureof := sc.State.Pos == EOS
   137  	ch := sc.Next()
   138  	if !cureof {
   139  		if sc.State.Pos == EOS {
   140  			sc.State.Pos = len(sc.src) - 1
   141  		} else {
   142  			sc.State.Pos--
   143  			if sc.State.Pos < 0 {
   144  				sc.State.Pos = 0
   145  				sc.State.started = false
   146  			}
   147  		}
   148  	}
   149  	return ch
   150  }
   151  
   152  func (sc *scanner) Save() { sc.saved = sc.State }
   153  
   154  func (sc *scanner) Restore() { sc.State = sc.saved }
   155  
   156  /* }}} */
   157  
   158  /* bytecode {{{ */
   159  
// opCode identifies the operation carried out by one VM instruction.
type opCode int

const (
	// opChar consumes one source byte if the instruction's Class matches it.
	opChar opCode = iota
	// opMatch ends execution with success.
	opMatch
	// opTailMatch ends execution, succeeding only at the end of the source.
	opTailMatch
	// opJmp continues at instruction Operand1.
	opJmp
	// opSplit tries instruction Operand1 first and, if that fails, Operand2.
	opSplit
	// opSave records the current source position in capture slot Operand1.
	opSave
	// opPSave records a position capture in slots Operand1 and Operand1+1.
	opPSave
	// opBrace consumes a balanced run opened by byte Operand1 and closed by
	// byte Operand2.
	opBrace
	// opNumber matches the text of capture number Operand1 again.
	opNumber
)
   173  
// inst is one instruction of a compiled pattern program.
type inst struct {
	OpCode   opCode // operation to perform
	Class    class  // byte class tested by opChar; compilePattern leaves it nil for other opcodes
	Operand1 int    // jump target, capture slot, opening byte or capture number, depending on OpCode; -1 when unused
	Operand2 int    // second opSplit target or opBrace closing byte; -1 when unused
}
   180  
   181  /* }}} */
   182  
   183  /* classes {{{ */
   184  
   185  type class interface {
   186  	Matches(ch int) bool
   187  }
   188  
   189  type dotClass struct{}
   190  
   191  func (pn *dotClass) Matches(ch int) bool { return true }
   192  
   193  type charClass struct {
   194  	Ch int
   195  }
   196  
   197  func (pn *charClass) Matches(ch int) bool { return pn.Ch == ch }
   198  
   199  type singleClass struct {
   200  	Class int
   201  }
   202  
   203  func (pn *singleClass) Matches(ch int) bool {
   204  	ret := false
   205  	switch pn.Class {
   206  	case 'a', 'A':
   207  		ret = 'A' <= ch && ch <= 'Z' || 'a' <= ch && ch <= 'z'
   208  	case 'c', 'C':
   209  		ret = (0x00 <= ch && ch <= 0x1F) || ch == 0x7F
   210  	case 'd', 'D':
   211  		ret = '0' <= ch && ch <= '9'
   212  	case 'l', 'L':
   213  		ret = 'a' <= ch && ch <= 'z'
   214  	case 'p', 'P':
   215  		ret = (0x21 <= ch && ch <= 0x2f) || (0x3a <= ch && ch <= 0x40) || (0x5b <= ch && ch <= 0x60) || (0x7b <= ch && ch <= 0x7e)
   216  	case 's', 'S':
   217  		switch ch {
   218  		case ' ', '\f', '\n', '\r', '\t', '\v':
   219  			ret = true
   220  		}
   221  	case 'u', 'U':
   222  		ret = 'A' <= ch && ch <= 'Z'
   223  	case 'w', 'W':
   224  		ret = '0' <= ch && ch <= '9' || 'A' <= ch && ch <= 'Z' || 'a' <= ch && ch <= 'z'
   225  	case 'x', 'X':
   226  		ret = '0' <= ch && ch <= '9' || 'a' <= ch && ch <= 'f' || 'A' <= ch && ch <= 'F'
   227  	case 'z', 'Z':
   228  		ret = ch == 0
   229  	default:
   230  		return ch == pn.Class
   231  	}
   232  	if 'A' <= pn.Class && pn.Class <= 'Z' {
   233  		return !ret
   234  	}
   235  	return ret
   236  }
   237  
   238  type setClass struct {
   239  	IsNot   bool
   240  	Classes []class
   241  }
   242  
   243  func (pn *setClass) Matches(ch int) bool {
   244  	for _, class := range pn.Classes {
   245  		if class.Matches(ch) {
   246  			return !pn.IsNot
   247  		}
   248  	}
   249  	return pn.IsNot
   250  }
   251  
   252  type rangeClass struct {
   253  	Begin class
   254  	End   class
   255  }
   256  
   257  func (pn *rangeClass) Matches(ch int) bool {
   258  	switch begin := pn.Begin.(type) {
   259  	case *charClass:
   260  		end, ok := pn.End.(*charClass)
   261  		if !ok {
   262  			return false
   263  		}
   264  		return begin.Ch <= ch && ch <= end.Ch
   265  	}
   266  	return false
   267  }
   268  
   269  // }}}
   270  
   271  // patterns {{{
   272  
// pattern is a node of the parsed pattern tree. compilePattern handles the
// concrete node types declared below.
type pattern interface{}

// singlePattern matches exactly one byte accepted by Class.
type singlePattern struct {
	Class class
}

// seqPattern is a sequence of nodes matched one after another.
type seqPattern struct {
	MustHead bool      // set by parsePattern when a top-level pattern starts with '^'
	MustTail bool      // set by parsePattern when a top-level pattern ends with '$'
	Patterns []pattern // child nodes in source order
}

// repeatPattern is a single-byte class followed by a repetition operator.
type repeatPattern struct {
	Type  int   // the operator byte: '*', '+', '-' or '?'
	Class class // class being repeated
}

// posCapPattern is the empty capture "()", which records a position.
type posCapPattern struct{}

// capPattern is a parenthesised capture group around Pattern.
type capPattern struct {
	Pattern pattern
}

// numberPattern is a back-reference written as %1 through %9.
type numberPattern struct {
	N int // the digit that followed '%'
}

// bracePattern is a balanced match written as %b followed by two bytes.
type bracePattern struct {
	Begin int // first byte after %b (the opening delimiter)
	End   int // second byte after %b (the closing delimiter)
}
   304  
   305  // }}}
   306  
   307  /* parse {{{ */
   308  
   309  func parseClass(sc *scanner, allowset bool) class {
   310  	ch := sc.Next()
   311  	switch ch {
   312  	case '%':
   313  		return &singleClass{sc.Next()}
   314  	case '.':
   315  		if allowset {
   316  			return &dotClass{}
   317  		}
   318  		return &charClass{ch}
   319  	case '[':
   320  		if allowset {
   321  			return parseClassSet(sc)
   322  		}
   323  		return &charClass{ch}
   324  	// case '^' '$', '(', ')', ']', '*', '+', '-', '?':
   325  	//	panic(newError(sc.CurrentPos(), "invalid %c", ch))
   326  	case EOS:
   327  		panic(newError(sc.CurrentPos(), "unexpected EOS"))
   328  	default:
   329  		return &charClass{ch}
   330  	}
   331  }
   332  
   333  func parseClassSet(sc *scanner) class {
   334  	set := &setClass{false, []class{}}
   335  	if sc.Peek() == '^' {
   336  		set.IsNot = true
   337  		sc.Next()
   338  	}
   339  	isrange := false
   340  	for {
   341  		ch := sc.Peek()
   342  		switch ch {
   343  		// case '[':
   344  		// 	panic(newError(sc.CurrentPos(), "'[' can not be nested"))
   345  		case EOS:
   346  			panic(newError(sc.CurrentPos(), "unexpected EOS"))
   347  		case ']':
   348  			if len(set.Classes) > 0 {
   349  				sc.Next()
   350  				goto exit
   351  			}
   352  			fallthrough
   353  		case '-':
   354  			if len(set.Classes) > 0 {
   355  				sc.Next()
   356  				isrange = true
   357  				continue
   358  			}
   359  			fallthrough
   360  		default:
   361  			set.Classes = append(set.Classes, parseClass(sc, false))
   362  		}
   363  		if isrange {
   364  			begin := set.Classes[len(set.Classes)-2]
   365  			end := set.Classes[len(set.Classes)-1]
   366  			set.Classes = set.Classes[0 : len(set.Classes)-2]
   367  			set.Classes = append(set.Classes, &rangeClass{begin, end})
   368  			isrange = false
   369  		}
   370  	}
   371  exit:
   372  	if isrange {
   373  		set.Classes = append(set.Classes, &charClass{'-'})
   374  	}
   375  
   376  	return set
   377  }
   378  
   379  func parsePattern(sc *scanner, toplevel bool) *seqPattern {
   380  	pat := &seqPattern{}
   381  	if toplevel {
   382  		if sc.Peek() == '^' {
   383  			sc.Next()
   384  			pat.MustHead = true
   385  		}
   386  	}
   387  	for {
   388  		ch := sc.Peek()
   389  		switch ch {
   390  		case '%':
   391  			sc.Save()
   392  			sc.Next()
   393  			switch sc.Peek() {
   394  			case '0':
   395  				panic(newError(sc.CurrentPos(), "invalid capture index"))
   396  			case '1', '2', '3', '4', '5', '6', '7', '8', '9':
   397  				pat.Patterns = append(pat.Patterns, &numberPattern{sc.Next() - 48})
   398  			case 'b':
   399  				sc.Next()
   400  				pat.Patterns = append(pat.Patterns, &bracePattern{sc.Next(), sc.Next()})
   401  			default:
   402  				sc.Restore()
   403  				pat.Patterns = append(pat.Patterns, &singlePattern{parseClass(sc, true)})
   404  			}
   405  		case '.', '[', ']':
   406  			pat.Patterns = append(pat.Patterns, &singlePattern{parseClass(sc, true)})
   407  		// case ']':
   408  		//	panic(newError(sc.CurrentPos(), "invalid ']'"))
   409  		case ')':
   410  			if toplevel {
   411  				panic(newError(sc.CurrentPos(), "invalid ')'"))
   412  			}
   413  			return pat
   414  		case '(':
   415  			sc.Next()
   416  			if sc.Peek() == ')' {
   417  				sc.Next()
   418  				pat.Patterns = append(pat.Patterns, &posCapPattern{})
   419  			} else {
   420  				ret := &capPattern{parsePattern(sc, false)}
   421  				if sc.Peek() != ')' {
   422  					panic(newError(sc.CurrentPos(), "unfinished capture"))
   423  				}
   424  				sc.Next()
   425  				pat.Patterns = append(pat.Patterns, ret)
   426  			}
   427  		case '*', '+', '-', '?':
   428  			sc.Next()
   429  			if len(pat.Patterns) > 0 {
   430  				spat, ok := pat.Patterns[len(pat.Patterns)-1].(*singlePattern)
   431  				if ok {
   432  					pat.Patterns = pat.Patterns[0 : len(pat.Patterns)-1]
   433  					pat.Patterns = append(pat.Patterns, &repeatPattern{ch, spat.Class})
   434  					continue
   435  				}
   436  			}
   437  			pat.Patterns = append(pat.Patterns, &singlePattern{&charClass{ch}})
   438  		case '$':
   439  			if toplevel && (sc.NextPos() == sc.Length()-1 || sc.NextPos() == EOS) {
   440  				pat.MustTail = true
   441  			} else {
   442  				pat.Patterns = append(pat.Patterns, &singlePattern{&charClass{ch}})
   443  			}
   444  			sc.Next()
   445  		case EOS:
   446  			sc.Next()
   447  			goto exit
   448  		default:
   449  			sc.Next()
   450  			pat.Patterns = append(pat.Patterns, &singlePattern{&charClass{ch}})
   451  		}
   452  	}
   453  exit:
   454  	return pat
   455  }
   456  
// iptr is the mutable state shared by the recursive calls of compilePattern.
type iptr struct {
	insts   []inst // instructions emitted so far
	capture int    // next capture slot to hand out; compilePattern starts it at 2 and advances it by 2 per capture
}
   461  
   462  func compilePattern(p pattern, ps ...*iptr) []inst {
   463  	var ptr *iptr
   464  	toplevel := false
   465  	if len(ps) == 0 {
   466  		toplevel = true
   467  		ptr = &iptr{[]inst{{opSave, nil, 0, -1}}, 2}
   468  	} else {
   469  		ptr = ps[0]
   470  	}
   471  	switch pat := p.(type) {
   472  	case *singlePattern:
   473  		ptr.insts = append(ptr.insts, inst{opChar, pat.Class, -1, -1})
   474  	case *seqPattern:
   475  		for _, cp := range pat.Patterns {
   476  			compilePattern(cp, ptr)
   477  		}
   478  	case *repeatPattern:
   479  		idx := len(ptr.insts)
   480  		switch pat.Type {
   481  		case '*':
   482  			ptr.insts = append(ptr.insts,
   483  				inst{opSplit, nil, idx + 1, idx + 3},
   484  				inst{opChar, pat.Class, -1, -1},
   485  				inst{opJmp, nil, idx, -1})
   486  		case '+':
   487  			ptr.insts = append(ptr.insts,
   488  				inst{opChar, pat.Class, -1, -1},
   489  				inst{opSplit, nil, idx, idx + 2})
   490  		case '-':
   491  			ptr.insts = append(ptr.insts,
   492  				inst{opSplit, nil, idx + 3, idx + 1},
   493  				inst{opChar, pat.Class, -1, -1},
   494  				inst{opJmp, nil, idx, -1})
   495  		case '?':
   496  			ptr.insts = append(ptr.insts,
   497  				inst{opSplit, nil, idx + 1, idx + 2},
   498  				inst{opChar, pat.Class, -1, -1})
   499  		}
   500  	case *posCapPattern:
   501  		ptr.insts = append(ptr.insts, inst{opPSave, nil, ptr.capture, -1})
   502  		ptr.capture += 2
   503  	case *capPattern:
   504  		c0, c1 := ptr.capture, ptr.capture+1
   505  		ptr.capture += 2
   506  		ptr.insts = append(ptr.insts, inst{opSave, nil, c0, -1})
   507  		compilePattern(pat.Pattern, ptr)
   508  		ptr.insts = append(ptr.insts, inst{opSave, nil, c1, -1})
   509  	case *bracePattern:
   510  		ptr.insts = append(ptr.insts, inst{opBrace, nil, pat.Begin, pat.End})
   511  	case *numberPattern:
   512  		ptr.insts = append(ptr.insts, inst{opNumber, nil, pat.N, -1})
   513  	}
   514  	if toplevel {
   515  		if p.(*seqPattern).MustTail {
   516  			ptr.insts = append(ptr.insts, inst{opSave, nil, 1, -1}, inst{opTailMatch, nil, -1, -1})
   517  		}
   518  		ptr.insts = append(ptr.insts, inst{opSave, nil, 1, -1}, inst{opMatch, nil, -1, -1})
   519  	}
   520  	return ptr.insts
   521  }
   522  
   523  /* }}} parse */
   524  
   525  /* VM {{{ */
   526  
   527  // Simple recursive virtual machine based on the
   528  // "Regular Expression Matching: the Virtual Machine Approach" (https://swtch.com/~rsc/regexp/regexp2.html)
   529  func recursiveVM(src []byte, insts []inst, pc, sp int, ms ...*MatchData) (bool, int, *MatchData) {
   530  	var m *MatchData
   531  	if len(ms) == 0 {
   532  		m = newMatchState()
   533  	} else {
   534  		m = ms[0]
   535  	}
   536  redo:
   537  	inst := insts[pc]
   538  	switch inst.OpCode {
   539  	case opChar:
   540  		if sp >= len(src) || !inst.Class.Matches(int(src[sp])) {
   541  			return false, sp, m
   542  		}
   543  		pc++
   544  		sp++
   545  		goto redo
   546  	case opMatch:
   547  		return true, sp, m
   548  	case opTailMatch:
   549  		return sp >= len(src), sp, m
   550  	case opJmp:
   551  		pc = inst.Operand1
   552  		goto redo
   553  	case opSplit:
   554  		if ok, nsp, _ := recursiveVM(src, insts, inst.Operand1, sp, m); ok {
   555  			return true, nsp, m
   556  		}
   557  		pc = inst.Operand2
   558  		goto redo
   559  	case opSave:
   560  		s := m.setCapture(inst.Operand1, sp)
   561  		if ok, nsp, _ := recursiveVM(src, insts, pc+1, sp, m); ok {
   562  			return true, nsp, m
   563  		}
   564  		m.restoreCapture(inst.Operand1, s)
   565  		return false, sp, m
   566  	case opPSave:
   567  		m.addPosCapture(inst.Operand1, sp+1)
   568  		pc++
   569  		goto redo
   570  	case opBrace:
   571  		if sp >= len(src) || int(src[sp]) != inst.Operand1 {
   572  			return false, sp, m
   573  		}
   574  		count := 1
   575  		for sp = sp + 1; sp < len(src); sp++ {
   576  			if int(src[sp]) == inst.Operand2 {
   577  				count--
   578  			}
   579  			if count == 0 {
   580  				pc++
   581  				sp++
   582  				goto redo
   583  			}
   584  			if int(src[sp]) == inst.Operand1 {
   585  				count++
   586  			}
   587  		}
   588  		return false, sp, m
   589  	case opNumber:
   590  		idx := inst.Operand1 * 2
   591  		if idx >= m.CaptureLength()-1 {
   592  			panic(newError(_UNKNOWN, "invalid capture index"))
   593  		}
   594  		capture := src[m.Capture(idx):m.Capture(idx+1)]
   595  		for i := 0; i < len(capture); i++ {
   596  			if i+sp >= len(src) || capture[i] != src[i+sp] {
   597  				return false, sp, m
   598  			}
   599  		}
   600  		pc++
   601  		sp += len(capture)
   602  		goto redo
   603  	}
   604  	panic("should not reach here")
   605  }
   606  
   607  /* }}} */
   608  
   609  /* API {{{ */
   610  
   611  func Find(p string, src []byte, offset, limit int) (matches []*MatchData, err error) {
   612  	defer func() {
   613  		if v := recover(); v != nil {
   614  			if perr, ok := v.(*Error); ok {
   615  				err = perr
   616  			} else {
   617  				panic(v)
   618  			}
   619  		}
   620  	}()
   621  	pat := parsePattern(newScanner([]byte(p)), true)
   622  	insts := compilePattern(pat)
   623  	matches = []*MatchData{}
   624  	for sp := offset; sp <= len(src); {
   625  		ok, nsp, ms := recursiveVM(src, insts, 0, sp)
   626  		sp++
   627  		if ok {
   628  			if sp < nsp {
   629  				sp = nsp
   630  			}
   631  			matches = append(matches, ms)
   632  		}
   633  		if len(matches) == limit || pat.MustHead {
   634  			break
   635  		}
   636  	}
   637  	return
   638  }
   639  
   640  /* }}} */