github.com/shawnclovie/gopher-lua@v0.0.0-20200520092726-90b44ec0e2f2/pm/pm.go (about)

     1  // Lua pattern match functions for Go
     2  package pm
     3  
     4  import (
     5  	"fmt"
     6  )
     7  
     8  const EOS = -1
     9  const _UNKNOWN = -2
    10  
    11  /* Error {{{ */
    12  
    13  type Error struct {
    14  	Pos     int
    15  	Message string
    16  }
    17  
    18  func newError(pos int, message string, args ...interface{}) *Error {
    19  	if len(args) == 0 {
    20  		return &Error{pos, message}
    21  	}
    22  	return &Error{pos, fmt.Sprintf(message, args...)}
    23  }
    24  
    25  func (e *Error) Error() string {
    26  	switch e.Pos {
    27  	case EOS:
    28  		return fmt.Sprintf("%s at EOS", e.Message)
    29  	case _UNKNOWN:
    30  		return fmt.Sprintf("%s", e.Message)
    31  	default:
    32  		return fmt.Sprintf("%s at %d", e.Message, e.Pos)
    33  	}
    34  }
    35  
    36  /* }}} */
    37  
    38  /* MatchData {{{ */
    39  
    40  type MatchData struct {
    41  	// captured positions
    42  	// layout
    43  	// xxxx xxxx xxxx xxx0 : caputured positions
    44  	// xxxx xxxx xxxx xxx1 : position captured positions
    45  	captures []uint32
    46  }
    47  
    48  func newMatchState() *MatchData { return &MatchData{[]uint32{}} }
    49  
    50  func (st *MatchData) addPosCapture(s, pos int) {
    51  	for s+1 >= len(st.captures) {
    52  		st.captures = append(st.captures, 0)
    53  	}
    54  	st.captures[s] = (uint32(pos) << 1) | 1
    55  	st.captures[s+1] = (uint32(pos) << 1) | 1
    56  }
    57  
    58  func (st *MatchData) setCapture(s, pos int) uint32 {
    59  	for s >= len(st.captures) {
    60  		st.captures = append(st.captures, 0)
    61  	}
    62  	v := st.captures[s]
    63  	st.captures[s] = (uint32(pos) << 1)
    64  	return v
    65  }
    66  
    67  func (st *MatchData) restoreCapture(s int, pos uint32) { st.captures[s] = pos }
    68  
    69  func (st *MatchData) CaptureLength() int { return len(st.captures) }
    70  
    71  func (st *MatchData) IsPosCapture(idx int) bool { return (st.captures[idx] & 1) == 1 }
    72  
    73  func (st *MatchData) Capture(idx int) int { return int(st.captures[idx] >> 1) }
    74  
    75  /* }}} */
    76  
    77  /* scanner {{{ */
    78  
    79  type scannerState struct {
    80  	Pos     int
    81  	started bool
    82  }
    83  
    84  type scanner struct {
    85  	src   []byte
    86  	State scannerState
    87  	saved scannerState
    88  }
    89  
    90  func newScanner(src []byte) *scanner {
    91  	return &scanner{
    92  		src: src,
    93  		State: scannerState{
    94  			Pos:     0,
    95  			started: false,
    96  		},
    97  		saved: scannerState{},
    98  	}
    99  }
   100  
   101  func (sc *scanner) Length() int { return len(sc.src) }
   102  
   103  func (sc *scanner) Next() int {
   104  	if !sc.State.started {
   105  		sc.State.started = true
   106  		if len(sc.src) == 0 {
   107  			sc.State.Pos = EOS
   108  		}
   109  	} else {
   110  		sc.State.Pos = sc.NextPos()
   111  	}
   112  	if sc.State.Pos == EOS {
   113  		return EOS
   114  	}
   115  	return int(sc.src[sc.State.Pos])
   116  }
   117  
   118  func (sc *scanner) CurrentPos() int {
   119  	return sc.State.Pos
   120  }
   121  
   122  func (sc *scanner) NextPos() int {
   123  	if sc.State.Pos == EOS || sc.State.Pos >= len(sc.src)-1 {
   124  		return EOS
   125  	}
   126  	if !sc.State.started {
   127  		return 0
   128  	} else {
   129  		return sc.State.Pos + 1
   130  	}
   131  }
   132  
   133  func (sc *scanner) Peek() int {
   134  	cureof := sc.State.Pos == EOS
   135  	ch := sc.Next()
   136  	if !cureof {
   137  		if sc.State.Pos == EOS {
   138  			sc.State.Pos = len(sc.src) - 1
   139  		} else {
   140  			sc.State.Pos--
   141  			if sc.State.Pos < 0 {
   142  				sc.State.Pos = 0
   143  				sc.State.started = false
   144  			}
   145  		}
   146  	}
   147  	return ch
   148  }
   149  
   150  func (sc *scanner) Save() { sc.saved = sc.State }
   151  
   152  func (sc *scanner) Restore() { sc.State = sc.saved }
   153  
   154  /* }}} */
   155  
   156  /* bytecode {{{ */
   157  
// opCode identifies one instruction kind of the matching virtual machine.
type opCode int

const (
	// opChar consumes one input byte if the instruction's Class accepts it;
	// otherwise the current attempt fails.
	opChar opCode = iota
	// opMatch ends the program successfully.
	opMatch
	// opTailMatch succeeds only when the whole input has been consumed.
	opTailMatch
	// opJmp continues at instruction Operand1.
	opJmp
	// opSplit tries the branch at Operand1 first and, if that fails,
	// continues at Operand2.
	opSplit
	// opSave records the current input offset in capture slot Operand1; the
	// slot is restored if the rest of the program fails.
	opSave
	// opPSave records a position capture (current offset + 1) in capture
	// slots Operand1 and Operand1+1.
	opPSave
	// opBrace matches a balanced run that opens with byte Operand1 and
	// closes with byte Operand2.
	opBrace
	// opNumber matches the text of capture number Operand1 again.
	opNumber
)

// inst is one instruction of a compiled pattern.
//
// Class is only set for opChar. Operand1 and Operand2 are interpreted per
// opcode (see the opcode constants); the compiler stores -1 in operands an
// instruction does not use.
type inst struct {
	OpCode   opCode
	Class    class
	Operand1 int
	Operand2 int
}
   178  
   179  /* }}} */
   180  
   181  /* classes {{{ */
   182  
   183  type class interface {
   184  	Matches(ch int) bool
   185  }
   186  
   187  type dotClass struct{}
   188  
   189  func (pn *dotClass) Matches(ch int) bool { return true }
   190  
   191  type charClass struct {
   192  	Ch int
   193  }
   194  
   195  func (pn *charClass) Matches(ch int) bool { return pn.Ch == ch }
   196  
   197  type singleClass struct {
   198  	Class int
   199  }
   200  
   201  func (pn *singleClass) Matches(ch int) bool {
   202  	ret := false
   203  	switch pn.Class {
   204  	case 'a', 'A':
   205  		ret = 'A' <= ch && ch <= 'Z' || 'a' <= ch && ch <= 'z'
   206  	case 'c', 'C':
   207  		ret = (0x00 <= ch && ch <= 0x1F) || ch == 0x7F
   208  	case 'd', 'D':
   209  		ret = '0' <= ch && ch <= '9'
   210  	case 'l', 'L':
   211  		ret = 'a' <= ch && ch <= 'z'
   212  	case 'p', 'P':
   213  		ret = (0x21 <= ch && ch <= 0x2f) || (0x30 <= ch && ch <= 0x40) || (0x5b <= ch && ch <= 0x60) || (0x7b <= ch && ch <= 0x7e)
   214  	case 's', 'S':
   215  		switch ch {
   216  		case ' ', '\f', '\n', '\r', '\t', '\v':
   217  			ret = true
   218  		}
   219  	case 'u', 'U':
   220  		ret = 'A' <= ch && ch <= 'Z'
   221  	case 'w', 'W':
   222  		ret = '0' <= ch && ch <= '9' || 'A' <= ch && ch <= 'Z' || 'a' <= ch && ch <= 'z'
   223  	case 'x', 'X':
   224  		ret = '0' <= ch && ch <= '9' || 'a' <= ch && ch <= 'f' || 'A' <= ch && ch <= 'F'
   225  	case 'z', 'Z':
   226  		ret = ch == 0
   227  	default:
   228  		return ch == pn.Class
   229  	}
   230  	if 'A' <= pn.Class && pn.Class <= 'Z' {
   231  		return !ret
   232  	}
   233  	return ret
   234  }
   235  
   236  type setClass struct {
   237  	IsNot   bool
   238  	Classes []class
   239  }
   240  
   241  func (pn *setClass) Matches(ch int) bool {
   242  	for _, class := range pn.Classes {
   243  		if class.Matches(ch) {
   244  			return !pn.IsNot
   245  		}
   246  	}
   247  	return pn.IsNot
   248  }
   249  
   250  type rangeClass struct {
   251  	Begin class
   252  	End   class
   253  }
   254  
   255  func (pn *rangeClass) Matches(ch int) bool {
   256  	switch begin := pn.Begin.(type) {
   257  	case *charClass:
   258  		end, ok := pn.End.(*charClass)
   259  		if !ok {
   260  			return false
   261  		}
   262  		return begin.Ch <= ch && ch <= end.Ch
   263  	}
   264  	return false
   265  }
   266  
   267  // }}}
   268  
   269  // patterns {{{
   270  
// pattern is a node of the parsed pattern tree. The compiler handles the
// concrete node types declared below.
type pattern interface{}

// singlePattern matches exactly one character accepted by Class.
type singlePattern struct {
	Class class
}

// seqPattern matches its sub-patterns one after another. MustHead is set
// by the parser when a top-level pattern starts with '^', MustTail when a
// top-level pattern ends with '$'.
type seqPattern struct {
	MustHead bool
	MustTail bool
	Patterns []pattern
}

// repeatPattern applies a repetition operator to a single class. Type is
// the operator character: '*', '+', '-' or '?'.
type repeatPattern struct {
	Type  int
	Class class
}

// posCapPattern is the empty capture "()", which captures a position.
type posCapPattern struct{}

// capPattern is a parenthesized capture around Pattern.
type capPattern struct {
	Pattern pattern
}

// numberPattern is a back-reference "%N" to capture number N (1-9).
type numberPattern struct {
	N int
}

// bracePattern is the balanced match "%bxy": Begin is the opening
// character x, End the closing character y.
type bracePattern struct {
	Begin int
	End   int
}
   302  
   303  // }}}
   304  
   305  /* parse {{{ */
   306  
   307  func parseClass(sc *scanner, allowset bool) class {
   308  	ch := sc.Next()
   309  	switch ch {
   310  	case '%':
   311  		return &singleClass{sc.Next()}
   312  	case '.':
   313  		if allowset {
   314  			return &dotClass{}
   315  		}
   316  		return &charClass{ch}
   317  	case '[':
   318  		if allowset {
   319  			return parseClassSet(sc)
   320  		}
   321  		return &charClass{ch}
   322  	//case '^' '$', '(', ')', ']', '*', '+', '-', '?':
   323  	//	panic(newError(sc.CurrentPos(), "invalid %c", ch))
   324  	case EOS:
   325  		panic(newError(sc.CurrentPos(), "unexpected EOS"))
   326  	default:
   327  		return &charClass{ch}
   328  	}
   329  }
   330  
// parseClassSet parses the inside of a bracketed set and returns it as a
// *setClass. The caller (parseClass) has already consumed the opening '[';
// this function consumes everything up to and including the closing ']'.
//
// It panics with an *Error if the input ends before the set is closed.
func parseClassSet(sc *scanner) class {
	set := &setClass{false, []class{}}
	// A leading '^' complements the set.
	if sc.Peek() == '^' {
		set.IsNot = true
		sc.Next()
	}
	// isrange is true after a '-' has been consumed: the next item parsed
	// becomes the upper bound of a range whose lower bound is the item
	// parsed before the '-'.
	isrange := false
	for {
		ch := sc.Peek()
		switch ch {
		// Nested '[' is deliberately not rejected; it is parsed as a literal.
		// case '[':
		// 	panic(newError(sc.CurrentPos(), "'[' can not be nested"))
		case EOS:
			panic(newError(sc.CurrentPos(), "unexpected EOS"))
		case ']':
			// ']' closes the set only once the set has at least one item;
			// as the first item it falls through and ends up as a literal.
			if len(set.Classes) > 0 {
				sc.Next()
				goto exit
			}
			fallthrough
		case '-':
			// '-' after at least one item announces a range; as the first
			// item it falls through and is parsed as a literal.
			if len(set.Classes) > 0 {
				sc.Next()
				isrange = true
				continue
			}
			fallthrough
		default:
			// allowset is false: '.' and '[' are literals inside a set.
			set.Classes = append(set.Classes, parseClass(sc, false))
		}
		if isrange {
			// Replace the last two items (lower bound, upper bound) with a
			// single range item.
			begin := set.Classes[len(set.Classes)-2]
			end := set.Classes[len(set.Classes)-1]
			set.Classes = set.Classes[0 : len(set.Classes)-2]
			set.Classes = append(set.Classes, &rangeClass{begin, end})
			isrange = false
		}
	}
exit:
	// A '-' directly before the closing ']' has no upper bound, so it is a
	// literal '-'.
	if isrange {
		set.Classes = append(set.Classes, &charClass{'-'})
	}

	return set
}
   376  
   377  func parsePattern(sc *scanner, toplevel bool) *seqPattern {
   378  	pat := &seqPattern{}
   379  	if toplevel {
   380  		if sc.Peek() == '^' {
   381  			sc.Next()
   382  			pat.MustHead = true
   383  		}
   384  	}
   385  	for {
   386  		ch := sc.Peek()
   387  		switch ch {
   388  		case '%':
   389  			sc.Save()
   390  			sc.Next()
   391  			switch sc.Peek() {
   392  			case '0':
   393  				panic(newError(sc.CurrentPos(), "invalid capture index"))
   394  			case '1', '2', '3', '4', '5', '6', '7', '8', '9':
   395  				pat.Patterns = append(pat.Patterns, &numberPattern{sc.Next() - 48})
   396  			case 'b':
   397  				sc.Next()
   398  				pat.Patterns = append(pat.Patterns, &bracePattern{sc.Next(), sc.Next()})
   399  			default:
   400  				sc.Restore()
   401  				pat.Patterns = append(pat.Patterns, &singlePattern{parseClass(sc, true)})
   402  			}
   403  		case '.', '[', ']':
   404  			pat.Patterns = append(pat.Patterns, &singlePattern{parseClass(sc, true)})
   405  		//case ']':
   406  		//	panic(newError(sc.CurrentPos(), "invalid ']'"))
   407  		case ')':
   408  			if toplevel {
   409  				panic(newError(sc.CurrentPos(), "invalid ')'"))
   410  			}
   411  			return pat
   412  		case '(':
   413  			sc.Next()
   414  			if sc.Peek() == ')' {
   415  				sc.Next()
   416  				pat.Patterns = append(pat.Patterns, &posCapPattern{})
   417  			} else {
   418  				ret := &capPattern{parsePattern(sc, false)}
   419  				if sc.Peek() != ')' {
   420  					panic(newError(sc.CurrentPos(), "unfinished capture"))
   421  				}
   422  				sc.Next()
   423  				pat.Patterns = append(pat.Patterns, ret)
   424  			}
   425  		case '*', '+', '-', '?':
   426  			sc.Next()
   427  			if len(pat.Patterns) > 0 {
   428  				spat, ok := pat.Patterns[len(pat.Patterns)-1].(*singlePattern)
   429  				if ok {
   430  					pat.Patterns = pat.Patterns[0 : len(pat.Patterns)-1]
   431  					pat.Patterns = append(pat.Patterns, &repeatPattern{ch, spat.Class})
   432  					continue
   433  				}
   434  			}
   435  			pat.Patterns = append(pat.Patterns, &singlePattern{&charClass{ch}})
   436  		case '$':
   437  			if toplevel && (sc.NextPos() == sc.Length()-1 || sc.NextPos() == EOS) {
   438  				pat.MustTail = true
   439  			} else {
   440  				pat.Patterns = append(pat.Patterns, &singlePattern{&charClass{ch}})
   441  			}
   442  			sc.Next()
   443  		case EOS:
   444  			sc.Next()
   445  			goto exit
   446  		default:
   447  			sc.Next()
   448  			pat.Patterns = append(pat.Patterns, &singlePattern{&charClass{ch}})
   449  		}
   450  	}
   451  exit:
   452  	return pat
   453  }
   454  
   455  type iptr struct {
   456  	insts   []inst
   457  	capture int
   458  }
   459  
   460  func compilePattern(p pattern, ps ...*iptr) []inst {
   461  	var ptr *iptr
   462  	toplevel := false
   463  	if len(ps) == 0 {
   464  		toplevel = true
   465  		ptr = &iptr{[]inst{inst{opSave, nil, 0, -1}}, 2}
   466  	} else {
   467  		ptr = ps[0]
   468  	}
   469  	switch pat := p.(type) {
   470  	case *singlePattern:
   471  		ptr.insts = append(ptr.insts, inst{opChar, pat.Class, -1, -1})
   472  	case *seqPattern:
   473  		for _, cp := range pat.Patterns {
   474  			compilePattern(cp, ptr)
   475  		}
   476  	case *repeatPattern:
   477  		idx := len(ptr.insts)
   478  		switch pat.Type {
   479  		case '*':
   480  			ptr.insts = append(ptr.insts,
   481  				inst{opSplit, nil, idx + 1, idx + 3},
   482  				inst{opChar, pat.Class, -1, -1},
   483  				inst{opJmp, nil, idx, -1})
   484  		case '+':
   485  			ptr.insts = append(ptr.insts,
   486  				inst{opChar, pat.Class, -1, -1},
   487  				inst{opSplit, nil, idx, idx + 2})
   488  		case '-':
   489  			ptr.insts = append(ptr.insts,
   490  				inst{opSplit, nil, idx + 3, idx + 1},
   491  				inst{opChar, pat.Class, -1, -1},
   492  				inst{opJmp, nil, idx, -1})
   493  		case '?':
   494  			ptr.insts = append(ptr.insts,
   495  				inst{opSplit, nil, idx + 1, idx + 2},
   496  				inst{opChar, pat.Class, -1, -1})
   497  		}
   498  	case *posCapPattern:
   499  		ptr.insts = append(ptr.insts, inst{opPSave, nil, ptr.capture, -1})
   500  		ptr.capture += 2
   501  	case *capPattern:
   502  		c0, c1 := ptr.capture, ptr.capture+1
   503  		ptr.capture += 2
   504  		ptr.insts = append(ptr.insts, inst{opSave, nil, c0, -1})
   505  		compilePattern(pat.Pattern, ptr)
   506  		ptr.insts = append(ptr.insts, inst{opSave, nil, c1, -1})
   507  	case *bracePattern:
   508  		ptr.insts = append(ptr.insts, inst{opBrace, nil, pat.Begin, pat.End})
   509  	case *numberPattern:
   510  		ptr.insts = append(ptr.insts, inst{opNumber, nil, pat.N, -1})
   511  	}
   512  	if toplevel {
   513  		if p.(*seqPattern).MustTail {
   514  			ptr.insts = append(ptr.insts, inst{opSave, nil, 1, -1}, inst{opTailMatch, nil, -1, -1})
   515  		}
   516  		ptr.insts = append(ptr.insts, inst{opSave, nil, 1, -1}, inst{opMatch, nil, -1, -1})
   517  	}
   518  	return ptr.insts
   519  }
   520  
   521  /* }}} parse */
   522  
   523  /* VM {{{ */
   524  
   525  // Simple recursive virtual machine based on the
   526  // "Regular Expression Matching: the Virtual Machine Approach" (https://swtch.com/~rsc/regexp/regexp2.html)
   527  func recursiveVM(src []byte, insts []inst, pc, sp int, ms ...*MatchData) (bool, int, *MatchData) {
   528  	var m *MatchData
   529  	if len(ms) == 0 {
   530  		m = newMatchState()
   531  	} else {
   532  		m = ms[0]
   533  	}
   534  redo:
   535  	inst := insts[pc]
   536  	switch inst.OpCode {
   537  	case opChar:
   538  		if sp >= len(src) || !inst.Class.Matches(int(src[sp])) {
   539  			return false, sp, m
   540  		}
   541  		pc++
   542  		sp++
   543  		goto redo
   544  	case opMatch:
   545  		return true, sp, m
   546  	case opTailMatch:
   547  		return sp >= len(src), sp, m
   548  	case opJmp:
   549  		pc = inst.Operand1
   550  		goto redo
   551  	case opSplit:
   552  		if ok, nsp, _ := recursiveVM(src, insts, inst.Operand1, sp, m); ok {
   553  			return true, nsp, m
   554  		}
   555  		pc = inst.Operand2
   556  		goto redo
   557  	case opSave:
   558  		s := m.setCapture(inst.Operand1, sp)
   559  		if ok, nsp, _ := recursiveVM(src, insts, pc+1, sp, m); ok {
   560  			return true, nsp, m
   561  		}
   562  		m.restoreCapture(inst.Operand1, s)
   563  		return false, sp, m
   564  	case opPSave:
   565  		m.addPosCapture(inst.Operand1, sp+1)
   566  		pc++
   567  		goto redo
   568  	case opBrace:
   569  		if sp >= len(src) || int(src[sp]) != inst.Operand1 {
   570  			return false, sp, m
   571  		}
   572  		count := 1
   573  		for sp = sp + 1; sp < len(src); sp++ {
   574  			if int(src[sp]) == inst.Operand2 {
   575  				count--
   576  			}
   577  			if count == 0 {
   578  				pc++
   579  				sp++
   580  				goto redo
   581  			}
   582  			if int(src[sp]) == inst.Operand1 {
   583  				count++
   584  			}
   585  		}
   586  		return false, sp, m
   587  	case opNumber:
   588  		idx := inst.Operand1 * 2
   589  		if idx >= m.CaptureLength()-1 {
   590  			panic(newError(_UNKNOWN, "invalid capture index"))
   591  		}
   592  		capture := src[m.Capture(idx):m.Capture(idx+1)]
   593  		for i := 0; i < len(capture); i++ {
   594  			if i+sp >= len(src) || capture[i] != src[i+sp] {
   595  				return false, sp, m
   596  			}
   597  		}
   598  		pc++
   599  		sp += len(capture)
   600  		goto redo
   601  	}
   602  	panic("should not reach here")
   603  }
   604  
   605  /* }}} */
   606  
   607  /* API {{{ */
   608  
   609  func Find(p string, src []byte, offset, limit int) (matches []*MatchData, err error) {
   610  	defer func() {
   611  		if v := recover(); v != nil {
   612  			if perr, ok := v.(*Error); ok {
   613  				err = perr
   614  			} else {
   615  				panic(v)
   616  			}
   617  		}
   618  	}()
   619  	pat := parsePattern(newScanner([]byte(p)), true)
   620  	insts := compilePattern(pat)
   621  	matches = []*MatchData{}
   622  	for sp := offset; sp <= len(src); {
   623  		ok, nsp, ms := recursiveVM(src, insts, 0, sp)
   624  		sp++
   625  		if ok {
   626  			if sp < nsp {
   627  				sp = nsp
   628  			}
   629  			matches = append(matches, ms)
   630  		}
   631  		if len(matches) == limit || pat.MustHead {
   632  			break
   633  		}
   634  	}
   635  	return
   636  }
   637  
   638  /* }}} */