github.com/qiniu/gopher-lua@v0.2017.11/pm/pm.go (about)

     1  // Lua pattern match functions for Go
     2  package pm
     3  
     4  import (
     5  	"fmt"
     6  )
     7  
     8  const EOS = -1
     9  const _UNKNOWN = -2
    10  
    11  /* Error {{{ */
    12  
    13  type Error struct {
    14  	Pos     int
    15  	Message string
    16  }
    17  
    18  func newError(pos int, message string, args ...interface{}) *Error {
    19  	if len(args) == 0 {
    20  		return &Error{pos, message}
    21  	}
    22  	return &Error{pos, fmt.Sprintf(message, args...)}
    23  }
    24  
    25  func (e *Error) Error() string {
    26  	switch e.Pos {
    27  	case EOS:
    28  		return fmt.Sprintf("%s at EOS", e.Message)
    29  	case _UNKNOWN:
    30  		return fmt.Sprintf("%s", e.Message)
    31  	default:
    32  		return fmt.Sprintf("%s at %d", e.Message, e.Pos)
    33  	}
    34  }
    35  
    36  /* }}} */
    37  
    38  /* MatchData {{{ */
    39  
    40  type MatchData struct {
    41  	// captured positions
    42  	// layout
    43  	// xxxx xxxx xxxx xxx0 : caputured positions
    44  	// xxxx xxxx xxxx xxx1 : position captured positions
    45  	captures []uint32
    46  }
    47  
    48  func newMatchState() *MatchData { return &MatchData{[]uint32{}} }
    49  
    50  func (st *MatchData) addPosCapture(s, pos int) {
    51  	for s+1 >= len(st.captures) {
    52  		st.captures = append(st.captures, 0)
    53  	}
    54  	st.captures[s] = (uint32(pos) << 1) | 1
    55  	st.captures[s+1] = (uint32(pos) << 1) | 1
    56  }
    57  
    58  func (st *MatchData) setCapture(s, pos int) uint32 {
    59  	for s >= len(st.captures) {
    60  		st.captures = append(st.captures, 0)
    61  	}
    62  	v := st.captures[s]
    63  	st.captures[s] = (uint32(pos) << 1)
    64  	return v
    65  }
    66  
    67  func (st *MatchData) restoreCapture(s int, pos uint32) { st.captures[s] = pos }
    68  
    69  func (st *MatchData) CaptureLength() int { return len(st.captures) }
    70  
    71  func (st *MatchData) IsPosCapture(idx int) bool { return (st.captures[idx] & 1) == 1 }
    72  
    73  func (st *MatchData) Capture(idx int) int { return int(st.captures[idx] >> 1) }
    74  
    75  /* }}} */
    76  
    77  /* scanner {{{ */
    78  
    79  type scannerState struct {
    80  	Pos     int
    81  	started bool
    82  }
    83  
    84  type scanner struct {
    85  	src   []byte
    86  	State scannerState
    87  	saved scannerState
    88  }
    89  
    90  func newScanner(src []byte) *scanner {
    91  	return &scanner{
    92  		src: src,
    93  		State: scannerState{
    94  			Pos:     0,
    95  			started: false,
    96  		},
    97  		saved: scannerState{},
    98  	}
    99  }
   100  
   101  func (sc *scanner) Length() int { return len(sc.src) }
   102  
   103  func (sc *scanner) Next() int {
   104  	if !sc.State.started {
   105  		sc.State.started = true
   106  		if len(sc.src) == 0 {
   107  			sc.State.Pos = EOS
   108  		}
   109  	} else {
   110  		sc.State.Pos = sc.NextPos()
   111  	}
   112  	if sc.State.Pos == EOS {
   113  		return EOS
   114  	}
   115  	return int(sc.src[sc.State.Pos])
   116  }
   117  
   118  func (sc *scanner) CurrentPos() int {
   119  	return sc.State.Pos
   120  }
   121  
   122  func (sc *scanner) NextPos() int {
   123  	if sc.State.Pos == EOS || sc.State.Pos >= len(sc.src)-1 {
   124  		return EOS
   125  	}
   126  	if !sc.State.started {
   127  		return 0
   128  	} else {
   129  		return sc.State.Pos + 1
   130  	}
   131  }
   132  
   133  func (sc *scanner) Peek() int {
   134  	cureof := sc.State.Pos == EOS
   135  	ch := sc.Next()
   136  	if !cureof {
   137  		if sc.State.Pos == EOS {
   138  			sc.State.Pos = len(sc.src) - 1
   139  		} else {
   140  			sc.State.Pos--
   141  			if sc.State.Pos < 0 {
   142  				sc.State.Pos = 0
   143  				sc.State.started = false
   144  			}
   145  		}
   146  	}
   147  	return ch
   148  }
   149  
   150  func (sc *scanner) Save() { sc.saved = sc.State }
   151  
   152  func (sc *scanner) Restore() { sc.State = sc.saved }
   153  
   154  /* }}} */
   155  
   156  /* bytecode {{{ */
   157  
   158  type opCode int
   159  
   160  const (
   161  	opChar opCode = iota
   162  	opMatch
   163  	opTailMatch
   164  	opJmp
   165  	opSplit
   166  	opSave
   167  	opPSave
   168  	opBrace
   169  	opNumber
   170  )
   171  
   172  type inst struct {
   173  	OpCode   opCode
   174  	Class    class
   175  	Operand1 int
   176  	Operand2 int
   177  }
   178  
   179  /* }}} */
   180  
   181  /* classes {{{ */
   182  
   183  type class interface {
   184  	Matches(ch int) bool
   185  }
   186  
   187  type dotClass struct{}
   188  
   189  func (pn *dotClass) Matches(ch int) bool { return true }
   190  
   191  type charClass struct {
   192  	Ch int
   193  }
   194  
   195  func (pn *charClass) Matches(ch int) bool { return pn.Ch == ch }
   196  
   197  type singleClass struct {
   198  	Class int
   199  }
   200  
   201  func (pn *singleClass) Matches(ch int) bool {
   202  	ret := false
   203  	switch pn.Class {
   204  	case 'a', 'A':
   205  		ret = 'A' <= ch && ch <= 'Z' || 'a' <= ch && ch <= 'z'
   206  	case 'c', 'C':
   207  		ret = (0x00 <= ch && ch <= 0x1F) || ch == 0x7F
   208  	case 'd', 'D':
   209  		ret = '0' <= ch && ch <= '9'
   210  	case 'l', 'L':
   211  		ret = 'a' <= ch && ch <= 'z'
   212  	case 'p', 'P':
   213  		ret = (0x21 <= ch && ch <= 0x2f) || (0x30 <= ch && ch <= 0x40) || (0x5b <= ch && ch <= 0x60) || (0x7b <= ch && ch <= 0x7e)
   214  	case 's', 'S':
   215  		switch ch {
   216  		case ' ', '\f', '\n', '\r', '\t', '\v':
   217  			ret = true
   218  		}
   219  	case 'u', 'U':
   220  		ret = 'A' <= ch && ch <= 'Z'
   221  	case 'w', 'W':
   222  		ret = '0' <= ch && ch <= '9' || 'A' <= ch && ch <= 'Z' || 'a' <= ch && ch <= 'z'
   223  	case 'x', 'X':
   224  		ret = '0' <= ch && ch <= '9' || 'a' <= ch && ch <= 'f' || 'A' <= ch && ch <= 'F'
   225  	case 'z', 'Z':
   226  		ret = ch == 0
   227  	default:
   228  		return ch == pn.Class
   229  	}
   230  	if 'A' <= pn.Class && pn.Class <= 'Z' {
   231  		return !ret
   232  	}
   233  	return ret
   234  }
   235  
   236  type setClass struct {
   237  	IsNot   bool
   238  	Classes []class
   239  }
   240  
   241  func (pn *setClass) Matches(ch int) bool {
   242  	for _, class := range pn.Classes {
   243  		if class.Matches(ch) {
   244  			return !pn.IsNot
   245  		}
   246  	}
   247  	return pn.IsNot
   248  }
   249  
   250  type rangeClass struct {
   251  	Begin class
   252  	End   class
   253  }
   254  
   255  func (pn *rangeClass) Matches(ch int) bool {
   256  	switch begin := pn.Begin.(type) {
   257  	case *charClass:
   258  		end, ok := pn.End.(*charClass)
   259  		if !ok {
   260  			return false
   261  		}
   262  		return begin.Ch <= ch && ch <= end.Ch
   263  	}
   264  	return false
   265  }
   266  
   267  // }}}
   268  
   269  // patterns {{{
   270  
   271  type pattern interface{}
   272  
   273  type singlePattern struct {
   274  	Class class
   275  }
   276  
   277  type seqPattern struct {
   278  	MustHead bool
   279  	MustTail bool
   280  	Patterns []pattern
   281  }
   282  
   283  type repeatPattern struct {
   284  	Type  int
   285  	Class class
   286  }
   287  
   288  type posCapPattern struct{}
   289  
   290  type capPattern struct {
   291  	Pattern pattern
   292  }
   293  
   294  type numberPattern struct {
   295  	N int
   296  }
   297  
   298  type bracePattern struct {
   299  	Begin int
   300  	End   int
   301  }
   302  
   303  // }}}
   304  
   305  /* parse {{{ */
   306  
   307  func parseClass(sc *scanner, allowset bool) class {
   308  	ch := sc.Next()
   309  	switch ch {
   310  	case '%':
   311  		return &singleClass{sc.Next()}
   312  	case '.':
   313  		if allowset {
   314  			return &dotClass{}
   315  		} else {
   316  			return &charClass{ch}
   317  		}
   318  	case '[':
   319  		if !allowset {
   320  			panic(newError(sc.CurrentPos(), "invalid '['"))
   321  		}
   322  		return parseClassSet(sc)
   323  	//case '^' '$', '(', ')', ']', '*', '+', '-', '?':
   324  	//	panic(newError(sc.CurrentPos(), "invalid %c", ch))
   325  	case EOS:
   326  		panic(newError(sc.CurrentPos(), "unexpected EOS"))
   327  	default:
   328  		return &charClass{ch}
   329  	}
   330  }
   331  
   332  func parseClassSet(sc *scanner) class {
   333  	set := &setClass{false, []class{}}
   334  	if sc.Peek() == '^' {
   335  		set.IsNot = true
   336  		sc.Next()
   337  	}
   338  	isrange := false
   339  	for {
   340  		ch := sc.Peek()
   341  		switch ch {
   342  		case '[':
   343  			panic(newError(sc.CurrentPos(), "'[' can not be nested"))
   344  		case ']':
   345  			sc.Next()
   346  			goto exit
   347  		case EOS:
   348  			panic(newError(sc.CurrentPos(), "unexpected EOS"))
   349  		case '-':
   350  			if len(set.Classes) > 0 {
   351  				sc.Next()
   352  				isrange = true
   353  				continue
   354  			}
   355  			fallthrough
   356  		default:
   357  			set.Classes = append(set.Classes, parseClass(sc, false))
   358  		}
   359  		if isrange {
   360  			begin := set.Classes[len(set.Classes)-2]
   361  			end := set.Classes[len(set.Classes)-1]
   362  			set.Classes = set.Classes[0 : len(set.Classes)-2]
   363  			set.Classes = append(set.Classes, &rangeClass{begin, end})
   364  			isrange = false
   365  		}
   366  	}
   367  exit:
   368  	if isrange {
   369  		set.Classes = append(set.Classes, &charClass{'-'})
   370  	}
   371  
   372  	return set
   373  }
   374  
   375  func parsePattern(sc *scanner, toplevel bool) *seqPattern {
   376  	pat := &seqPattern{}
   377  	if toplevel {
   378  		if sc.Peek() == '^' {
   379  			sc.Next()
   380  			pat.MustHead = true
   381  		}
   382  	}
   383  	for {
   384  		ch := sc.Peek()
   385  		switch ch {
   386  		case '%':
   387  			sc.Save()
   388  			sc.Next()
   389  			switch sc.Peek() {
   390  			case '0':
   391  				panic(newError(sc.CurrentPos(), "invalid capture index"))
   392  			case '1', '2', '3', '4', '5', '6', '7', '8', '9':
   393  				pat.Patterns = append(pat.Patterns, &numberPattern{sc.Next() - 48})
   394  			case 'b':
   395  				sc.Next()
   396  				pat.Patterns = append(pat.Patterns, &bracePattern{sc.Next(), sc.Next()})
   397  			default:
   398  				sc.Restore()
   399  				pat.Patterns = append(pat.Patterns, &singlePattern{parseClass(sc, true)})
   400  			}
   401  		case '.', '[':
   402  			pat.Patterns = append(pat.Patterns, &singlePattern{parseClass(sc, true)})
   403  		case ']':
   404  			panic(newError(sc.CurrentPos(), "invalid ']'"))
   405  		case ')':
   406  			if toplevel {
   407  				panic(newError(sc.CurrentPos(), "invalid ')'"))
   408  			}
   409  			return pat
   410  		case '(':
   411  			sc.Next()
   412  			if sc.Peek() == ')' {
   413  				sc.Next()
   414  				pat.Patterns = append(pat.Patterns, &posCapPattern{})
   415  			} else {
   416  				ret := &capPattern{parsePattern(sc, false)}
   417  				if sc.Peek() != ')' {
   418  					panic(newError(sc.CurrentPos(), "unfinished capture"))
   419  				}
   420  				sc.Next()
   421  				pat.Patterns = append(pat.Patterns, ret)
   422  			}
   423  		case '*', '+', '-', '?':
   424  			sc.Next()
   425  			if len(pat.Patterns) > 0 {
   426  				spat, ok := pat.Patterns[len(pat.Patterns)-1].(*singlePattern)
   427  				if ok {
   428  					pat.Patterns = pat.Patterns[0 : len(pat.Patterns)-1]
   429  					pat.Patterns = append(pat.Patterns, &repeatPattern{ch, spat.Class})
   430  					continue
   431  				}
   432  			}
   433  			pat.Patterns = append(pat.Patterns, &singlePattern{&charClass{ch}})
   434  		case '$':
   435  			if toplevel && (sc.NextPos() == sc.Length()-1 || sc.NextPos() == EOS) {
   436  				pat.MustTail = true
   437  			} else {
   438  				pat.Patterns = append(pat.Patterns, &singlePattern{&charClass{ch}})
   439  			}
   440  			sc.Next()
   441  		case EOS:
   442  			sc.Next()
   443  			goto exit
   444  		default:
   445  			sc.Next()
   446  			pat.Patterns = append(pat.Patterns, &singlePattern{&charClass{ch}})
   447  		}
   448  	}
   449  exit:
   450  	return pat
   451  }
   452  
   453  type iptr struct {
   454  	insts   []inst
   455  	capture int
   456  }
   457  
   458  func compilePattern(p pattern, ps ...*iptr) []inst {
   459  	var ptr *iptr
   460  	toplevel := false
   461  	if len(ps) == 0 {
   462  		toplevel = true
   463  		ptr = &iptr{[]inst{inst{opSave, nil, 0, -1}}, 2}
   464  	} else {
   465  		ptr = ps[0]
   466  	}
   467  	switch pat := p.(type) {
   468  	case *singlePattern:
   469  		ptr.insts = append(ptr.insts, inst{opChar, pat.Class, -1, -1})
   470  	case *seqPattern:
   471  		for _, cp := range pat.Patterns {
   472  			compilePattern(cp, ptr)
   473  		}
   474  	case *repeatPattern:
   475  		idx := len(ptr.insts)
   476  		switch pat.Type {
   477  		case '*':
   478  			ptr.insts = append(ptr.insts,
   479  				inst{opSplit, nil, idx + 1, idx + 3},
   480  				inst{opChar, pat.Class, -1, -1},
   481  				inst{opJmp, nil, idx, -1})
   482  		case '+':
   483  			ptr.insts = append(ptr.insts,
   484  				inst{opChar, pat.Class, -1, -1},
   485  				inst{opSplit, nil, idx, idx + 2})
   486  		case '-':
   487  			ptr.insts = append(ptr.insts,
   488  				inst{opSplit, nil, idx + 3, idx + 1},
   489  				inst{opChar, pat.Class, -1, -1},
   490  				inst{opJmp, nil, idx, -1})
   491  		case '?':
   492  			ptr.insts = append(ptr.insts,
   493  				inst{opSplit, nil, idx + 1, idx + 2},
   494  				inst{opChar, pat.Class, -1, -1})
   495  		}
   496  	case *posCapPattern:
   497  		ptr.insts = append(ptr.insts, inst{opPSave, nil, ptr.capture, -1})
   498  		ptr.capture += 2
   499  	case *capPattern:
   500  		c0, c1 := ptr.capture, ptr.capture+1
   501  		ptr.capture += 2
   502  		ptr.insts = append(ptr.insts, inst{opSave, nil, c0, -1})
   503  		compilePattern(pat.Pattern, ptr)
   504  		ptr.insts = append(ptr.insts, inst{opSave, nil, c1, -1})
   505  	case *bracePattern:
   506  		ptr.insts = append(ptr.insts, inst{opBrace, nil, pat.Begin, pat.End})
   507  	case *numberPattern:
   508  		ptr.insts = append(ptr.insts, inst{opNumber, nil, pat.N, -1})
   509  	}
   510  	if toplevel {
   511  		if p.(*seqPattern).MustTail {
   512  			ptr.insts = append(ptr.insts, inst{opSave, nil, 1, -1}, inst{opTailMatch, nil, -1, -1})
   513  		}
   514  		ptr.insts = append(ptr.insts, inst{opSave, nil, 1, -1}, inst{opMatch, nil, -1, -1})
   515  	}
   516  	return ptr.insts
   517  }
   518  
   519  /* }}} parse */
   520  
   521  /* VM {{{ */
   522  
   523  // Simple recursive virtual machine based on the
   524  // "Regular Expression Matching: the Virtual Machine Approach" (https://swtch.com/~rsc/regexp/regexp2.html)
   525  func recursiveVM(src []byte, insts []inst, pc, sp int, ms ...*MatchData) (bool, int, *MatchData) {
   526  	var m *MatchData
   527  	if len(ms) == 0 {
   528  		m = newMatchState()
   529  	} else {
   530  		m = ms[0]
   531  	}
   532  redo:
   533  	inst := insts[pc]
   534  	switch inst.OpCode {
   535  	case opChar:
   536  		if sp >= len(src) || !inst.Class.Matches(int(src[sp])) {
   537  			return false, sp, m
   538  		}
   539  		pc++
   540  		sp++
   541  		goto redo
   542  	case opMatch:
   543  		return true, sp, m
   544  	case opTailMatch:
   545  		return sp >= len(src), sp, m
   546  	case opJmp:
   547  		pc = inst.Operand1
   548  		goto redo
   549  	case opSplit:
   550  		if ok, nsp, _ := recursiveVM(src, insts, inst.Operand1, sp, m); ok {
   551  			return true, nsp, m
   552  		}
   553  		pc = inst.Operand2
   554  		goto redo
   555  	case opSave:
   556  		s := m.setCapture(inst.Operand1, sp)
   557  		if ok, nsp, _ := recursiveVM(src, insts, pc+1, sp, m); ok {
   558  			return true, nsp, m
   559  		}
   560  		m.restoreCapture(inst.Operand1, s)
   561  		return false, sp, m
   562  	case opPSave:
   563  		m.addPosCapture(inst.Operand1, sp+1)
   564  		pc++
   565  		goto redo
   566  	case opBrace:
   567  		if sp >= len(src) || int(src[sp]) != inst.Operand1 {
   568  			return false, sp, m
   569  		}
   570  		count := 1
   571  		for sp = sp + 1; sp < len(src); sp++ {
   572  			if int(src[sp]) == inst.Operand2 {
   573  				count--
   574  			}
   575  			if count == 0 {
   576  				pc++
   577  				sp++
   578  				goto redo
   579  			}
   580  			if int(src[sp]) == inst.Operand1 {
   581  				count++
   582  			}
   583  		}
   584  		return false, sp, m
   585  	case opNumber:
   586  		idx := inst.Operand1 * 2
   587  		if idx >= m.CaptureLength()-1 {
   588  			panic(newError(_UNKNOWN, "invalid capture index"))
   589  		}
   590  		capture := src[m.Capture(idx):m.Capture(idx+1)]
   591  		for i := 0; i < len(capture); i++ {
   592  			if i+sp >= len(src) || capture[i] != src[i+sp] {
   593  				return false, sp, m
   594  			}
   595  		}
   596  		pc++
   597  		sp += len(capture)
   598  		goto redo
   599  	}
   600  	panic("should not reach here")
   601  	return false, sp, m
   602  }
   603  
   604  /* }}} */
   605  
   606  /* API {{{ */
   607  
   608  func Find(p string, src []byte, offset, limit int) (matches []*MatchData, err error) {
   609  	defer func() {
   610  		if v := recover(); v != nil {
   611  			if perr, ok := v.(*Error); ok {
   612  				err = perr
   613  			} else {
   614  				panic(v)
   615  			}
   616  		}
   617  	}()
   618  	pat := parsePattern(newScanner([]byte(p)), true)
   619  	insts := compilePattern(pat)
   620  	matches = []*MatchData{}
   621  	for sp := offset; sp <= len(src); {
   622  		ok, nsp, ms := recursiveVM(src, insts, 0, sp)
   623  		sp++
   624  		if ok {
   625  			if sp < nsp {
   626  				sp = nsp
   627  			}
   628  			matches = append(matches, ms)
   629  		}
   630  		if len(matches) == limit || pat.MustHead {
   631  			break
   632  		}
   633  	}
   634  	return
   635  }
   636  
   637  /* }}} */