github.com/arnodel/golua@v0.0.0-20230215163904-e0b5347eaaa1/lib/stringlib/matching.go (about)

     1  package stringlib
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"regexp"
     7  	"strconv"
     8  	"strings"
     9  
    10  	"github.com/arnodel/golua/lib/stringlib/pattern"
    11  	"github.com/arnodel/golua/luastrings"
    12  	rt "github.com/arnodel/golua/runtime"
    13  )
    14  
    15  func find(t *rt.Thread, c *rt.GoCont) (rt.Cont, error) {
    16  	var (
    17  		s, ptn string
    18  		init   int64 = 1
    19  		plain  bool
    20  	)
    21  	err := c.CheckNArgs(2)
    22  	if err == nil {
    23  		s, err = c.StringArg(0)
    24  	}
    25  	if err == nil {
    26  		ptn, err = c.StringArg(1)
    27  	}
    28  	if err == nil && c.NArgs() >= 3 {
    29  		init, err = c.IntArg(2)
    30  		if err == nil && c.NArgs() >= 4 {
    31  			plain = rt.Truth(c.Arg(3))
    32  		}
    33  	}
    34  	if err != nil {
    35  		return nil, err
    36  	}
    37  	si := luastrings.StringNormPos(s, int(init)) - 1
    38  	if si < 0 {
    39  		si = 0
    40  	}
    41  	next := c.Next()
    42  	switch {
    43  	case si < 0 || si > len(s):
    44  		t.Push1(next, rt.NilValue)
    45  	case plain || len(ptn) == 0:
    46  		// strings.Index is linear
    47  		t.RequireCPU(uint64(len(s) - si))
    48  		i := strings.Index(s[si:], ptn)
    49  		if i == -1 {
    50  			t.Push1(next, rt.NilValue)
    51  		} else {
    52  			t.Push1(next, rt.IntValue(int64(i+1)))
    53  			t.Push1(next, rt.IntValue(int64(i+len(ptn))))
    54  		}
    55  	default:
    56  		pat, err := pattern.New(string(ptn))
    57  		if err != nil {
    58  			return nil, err
    59  		}
    60  		captures, usedCPU := pat.MatchFromStart(string(s), si, t.UnusedCPU())
    61  		t.RequireCPU(usedCPU)
    62  		if len(captures) == 0 {
    63  			t.Push1(next, rt.NilValue)
    64  		} else {
    65  			first := captures[0]
    66  			t.Push1(next, rt.IntValue(int64(first.Start()+1)))
    67  			t.Push1(next, rt.IntValue(int64(first.End())))
    68  			pushExtraCaptures(t.Runtime, captures, s, next)
    69  		}
    70  	}
    71  	return next, nil
    72  }
    73  
    74  func match(t *rt.Thread, c *rt.GoCont) (rt.Cont, error) {
    75  	var (
    76  		s, ptn string
    77  		init   int64 = 1
    78  	)
    79  	err := c.CheckNArgs(2)
    80  	if err == nil {
    81  		s, err = c.StringArg(0)
    82  	}
    83  	if err == nil {
    84  		ptn, err = c.StringArg(1)
    85  	}
    86  	if err == nil && c.NArgs() >= 3 {
    87  		init, err = c.IntArg(2)
    88  	}
    89  	if err != nil {
    90  		return nil, err
    91  	}
    92  	si := luastrings.StringNormPos(s, int(init)) - 1
    93  	if si < 0 {
    94  		si = 0
    95  	}
    96  	next := c.Next()
    97  	pat, ptnErr := pattern.New(string(ptn))
    98  	if ptnErr != nil {
    99  		return nil, ptnErr
   100  	}
   101  	captures, usedCPU := pat.MatchFromStart(string(s), si, t.UnusedCPU())
   102  	t.RequireCPU(usedCPU)
   103  	pushCaptures(t.Runtime, captures, s, next)
   104  	return next, nil
   105  }
   106  
   107  func pushCaptures(r *rt.Runtime, captures []pattern.Capture, s string, next rt.Cont) {
   108  	switch len(captures) {
   109  	case 0:
   110  		r.Push1(next, rt.NilValue)
   111  	case 1:
   112  		c := captures[0]
   113  		r.RequireBytes(c.End() - c.Start())
   114  		r.Push1(next, rt.StringValue(s[c.Start():c.End()]))
   115  	default:
   116  		pushExtraCaptures(r, captures, s, next)
   117  	}
   118  }
   119  
   120  func pushExtraCaptures(r *rt.Runtime, captures []pattern.Capture, s string, next rt.Cont) {
   121  	if len(captures) < 2 {
   122  		return
   123  	}
   124  	for _, c := range captures[1:] {
   125  		r.Push1(next, captureValue(r, c, s))
   126  	}
   127  }
   128  
   129  func captureValue(r *rt.Runtime, c pattern.Capture, s string) rt.Value {
   130  	if c.IsEmpty() {
   131  		return rt.IntValue(int64(c.Start() + 1))
   132  	}
   133  	r.RequireBytes(c.End() - c.Start())
   134  	return rt.StringValue(s[c.Start():c.End()])
   135  }
   136  
   137  func gmatch(t *rt.Thread, c *rt.GoCont) (rt.Cont, error) {
   138  	var (
   139  		s, ptn string
   140  		err          = c.CheckNArgs(2)
   141  		init   int64 = 1
   142  	)
   143  	if err == nil {
   144  		s, err = c.StringArg(0)
   145  	}
   146  	if err == nil {
   147  		ptn, err = c.StringArg(1)
   148  	}
   149  	if err == nil && c.NArgs() >= 3 {
   150  		init, err = c.IntArg(2)
   151  	}
   152  	if err != nil {
   153  		return nil, err
   154  	}
   155  	pat, ptnErr := pattern.New(string(ptn))
   156  	if ptnErr != nil {
   157  		return nil, ptnErr
   158  	}
   159  	si := luastrings.StringNormPos(s, int(init)) - 1
   160  	if si < 0 {
   161  		si = 0
   162  	}
   163  	allowEmpty := true
   164  	var iterator = func(t *rt.Thread, c *rt.GoCont) (rt.Cont, error) {
   165  		next := c.Next()
   166  		var (
   167  			captures []pattern.Capture
   168  			usedCPU  uint64
   169  		)
   170  		for {
   171  			captures, usedCPU = pat.Match(s, si, t.UnusedCPU())
   172  			t.RequireCPU(usedCPU)
   173  			if len(captures) == 0 {
   174  				break
   175  			}
   176  			gc := captures[0]
   177  			start, end := gc.Start(), gc.End()
   178  			if allowEmpty || start != si || end != si {
   179  				allowEmpty = start >= end
   180  				if allowEmpty {
   181  					si = start + 1
   182  				} else {
   183  					si = end
   184  				}
   185  				break
   186  			}
   187  			si++
   188  			allowEmpty = true
   189  		}
   190  		pushCaptures(t.Runtime, captures, s, next)
   191  		return next, nil
   192  	}
   193  	iterGof := rt.NewGoFunction(iterator, "gmatchiterator", 0, false)
   194  	iterGof.SolemnlyDeclareCompliance(rt.ComplyCpuSafe | rt.ComplyMemSafe | rt.ComplyTimeSafe | rt.ComplyIoSafe)
   195  	return c.PushingNext(t.Runtime, rt.FunctionValue(iterGof)), nil
   196  }
   197  
   198  func gsub(t *rt.Thread, c *rt.GoCont) (rt.Cont, error) {
   199  	var (
   200  		s, ptn string
   201  		n      int64 = -1
   202  		repl   rt.Value
   203  	)
   204  	err := c.CheckNArgs(3)
   205  	if err == nil {
   206  		s, err = c.StringArg(0)
   207  	}
   208  	if err == nil {
   209  		ptn, err = c.StringArg(1)
   210  	}
   211  	if err == nil && c.NArgs() >= 4 {
   212  		n, err = c.IntArg(3)
   213  	}
   214  	if err != nil {
   215  		return nil, err
   216  	}
   217  	repl = c.Arg(2)
   218  	pat, ptnErr := pattern.New(string(ptn))
   219  	if ptnErr != nil {
   220  		return nil, ptnErr
   221  	}
   222  
   223  	// replF will be the function that does the substitution of the match given
   224  	// the captures in the match.  It must require the memory for the
   225  	// substitution string.  It returns the substituted string, true if a
   226  	// substitution was actually made, and a non-nil error if something went
   227  	// wrong.
   228  	var replF func([]pattern.Capture) (string, bool, error)
   229  
   230  	if replString, ok := repl.TryString(); ok {
   231  		replF = func(captures []pattern.Capture) (string, bool, error) {
   232  			cStrings := [10]string{}
   233  			maxIndex := len(captures) - 1
   234  			for i, c := range captures {
   235  				v := captureValue(t.Runtime, c, s)
   236  				switch v.Type() {
   237  				case rt.StringType:
   238  					cStrings[i] = v.AsString()
   239  				case rt.IntType:
   240  					cStrings[i] = strconv.Itoa(int(v.AsInt()))
   241  				}
   242  			}
   243  			if len(captures) == 1 {
   244  				cStrings[1] = cStrings[0]
   245  				maxIndex = 1
   246  			}
   247  			var err error
   248  			t.RequireCPU(uint64(len(replString)))
   249  			t.RequireBytes(len(replString))
   250  			return gsubPtn.ReplaceAllStringFunc(replString, func(x string) string {
   251  				if err != nil {
   252  					return ""
   253  				}
   254  				b := x[1]
   255  				switch {
   256  				case '0' <= b && b <= '9':
   257  					idx := int(b - '0')
   258  					if idx > maxIndex {
   259  						err = pattern.ErrInvalidCaptureIdx(idx)
   260  						return ""
   261  					}
   262  					s := cStrings[b-'0']
   263  					if len(s) > 2 {
   264  						t.RequireBytes(len(s) - 2)
   265  					}
   266  					return s
   267  				case b == '%':
   268  					return x[1:]
   269  				default:
   270  					err = pattern.ErrInvalidPct
   271  				}
   272  				return x[1:]
   273  			}), false, err
   274  		}
   275  	} else if replTable, ok := repl.TryTable(); ok {
   276  		replF = func(captures []pattern.Capture) (string, bool, error) {
   277  			gc := captures[0]
   278  			i := 0
   279  			if len(captures) >= 2 {
   280  				i = 1
   281  			}
   282  			c := captures[i]
   283  			val, err := rt.Index(t, rt.TableValue(replTable), captureValue(t.Runtime, c, s))
   284  			if err != nil {
   285  				return "", false, err
   286  			}
   287  			return subToString(t.Runtime, s[gc.Start():gc.End()], val)
   288  		}
   289  	} else if replC, ok := repl.TryCallable(); ok {
   290  		replF = func(captures []pattern.Capture) (string, bool, error) {
   291  			term := rt.NewTerminationWith(c, 1, false)
   292  			cont := replC.Continuation(t, term)
   293  			gc := captures[0]
   294  			i := 0
   295  			if len(captures) >= 2 {
   296  				i = 1
   297  			}
   298  			for _, c := range captures[i:] {
   299  				t.Push1(cont, captureValue(t.Runtime, c, s))
   300  			}
   301  			err := t.RunContinuation(cont)
   302  			if err != nil {
   303  				return "", false, err
   304  			}
   305  			return subToString(t.Runtime, s[gc.Start():gc.End()], term.Get(0))
   306  		}
   307  	} else {
   308  		return nil, errors.New("#3 must be a string, table or function")
   309  	}
   310  	var (
   311  		si         int             // Index in s where to start finding the next match
   312  		sj         int             // Index in s of the first byte not yet copied
   313  		sb         strings.Builder // Build the result string into this
   314  		matchCount int64
   315  		allowEmpty = true
   316  	)
   317  	// We require memory for the string we build as we go along.  In order to
   318  	// save allocations in case there are no substitutions, we do not start
   319  	// copying the string until one substitution has actually taken place.  This
   320  	// is achieved by keeping the variable sj the same until bytes are written
   321  	// in the string builder.
   322  	for ; matchCount != n; matchCount++ {
   323  		captures, usedCPU := pat.Match(string(s), si, t.UnusedCPU())
   324  		t.RequireCPU(usedCPU)
   325  		if len(captures) == 0 {
   326  			break
   327  		}
   328  		gc := captures[0]
   329  		start, end := gc.Start(), gc.End()
   330  		if allowEmpty || start != si || end != si {
   331  			sub, same, err := replF(captures)
   332  			if err != nil {
   333  				return nil, err
   334  			}
   335  			if !same {
   336  				t.RequireBytes(start - sj)
   337  				// No need to require memory for sub as that has been done already
   338  				// by replF
   339  				_, _ = sb.WriteString(s[sj:start])
   340  				_, _ = sb.WriteString(sub)
   341  				sj = end
   342  			}
   343  		}
   344  		allowEmpty = start >= end
   345  		if allowEmpty {
   346  			si = start + 1
   347  		} else {
   348  			si = end
   349  		}
   350  	}
   351  	var res rt.Value
   352  	switch {
   353  	case sb.Len() == 0:
   354  		// We return the input string to save an allocation.
   355  		res = c.Arg(0)
   356  	case sj < len(s):
   357  		t.RequireBytes(len(s) - sj)
   358  		_, _ = sb.WriteString(s[sj:])
   359  		fallthrough
   360  	default:
   361  		res = rt.StringValue(sb.String())
   362  	}
   363  	next := c.Next()
   364  	// Already required memory for the string below.
   365  	t.Push1(next, res)
   366  	t.Push1(next, rt.IntValue(matchCount))
   367  	return next, nil
   368  }
   369  
   370  var gsubPtn = regexp.MustCompile("%.")
   371  
   372  func subToString(r *rt.Runtime, key string, val rt.Value) (string, bool, error) {
   373  	if !rt.Truth(val) {
   374  		return key, true, nil
   375  	}
   376  	res, ok := val.ToString()
   377  	if ok {
   378  		r.RequireBytes(len(res))
   379  		return res, false, nil
   380  	}
   381  	return "", false, fmt.Errorf("invalid replacement value (a %s)", val.TypeName())
   382  }