github.com/pgavlin/text@v0.0.0-20240419000839-8438d0a47805/replace.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package text
     6  
     7  import (
     8  	"io"
     9  	"sync"
    10  )
    11  
    12  // Replacer replaces a list of strings with replacements.
    13  // It is safe for concurrent use by multiple goroutines.
    14  type Replacer[S String] struct {
    15  	once   sync.Once // guards buildOnce method
    16  	r      replacer[S]
    17  	oldnew []S
    18  }
    19  
    20  // replacer is the interface that a replacement algorithm needs to implement.
    21  type replacer[S String] interface {
    22  	Replace(s S) S
    23  	WriteString(w io.Writer, s S) (n int, err error)
    24  }
    25  
    26  // NewReplacer returns a new Replacer from a list of old, new string
    27  // pairs. Replacements are performed in the order they appear in the
    28  // target string, without overlapping matches. The old string
    29  // comparisons are done in argument order.
    30  //
    31  // NewReplacer panics if given an odd number of arguments.
    32  func NewReplacer[S String](oldnew ...S) *Replacer[S] {
    33  	if len(oldnew)%2 == 1 {
    34  		panic("strings.NewReplacer: odd argument count")
    35  	}
    36  	return &Replacer[S]{oldnew: append([]S(nil), oldnew...)}
    37  }
    38  
    39  func (r *Replacer[S]) buildOnce() {
    40  	r.r = r.build()
    41  	r.oldnew = nil
    42  }
    43  
    44  func (b *Replacer[S]) build() replacer[S] {
    45  	oldnew := b.oldnew
    46  	if len(oldnew) == 2 && len(oldnew[0]) > 1 {
    47  		return makeSingleStringReplacer(oldnew[0], oldnew[1])
    48  	}
    49  
    50  	allNewBytes := true
    51  	for i := 0; i < len(oldnew); i += 2 {
    52  		if len(oldnew[i]) != 1 {
    53  			return makeGenericReplacer(oldnew)
    54  		}
    55  		if len(oldnew[i+1]) != 1 {
    56  			allNewBytes = false
    57  		}
    58  	}
    59  
    60  	if allNewBytes {
    61  		r := byteReplacer[S]{}
    62  		for i := range r.m {
    63  			r.m[i] = byte(i)
    64  		}
    65  		// The first occurrence of old->new map takes precedence
    66  		// over the others with the same old string.
    67  		for i := len(oldnew) - 2; i >= 0; i -= 2 {
    68  			o := oldnew[i][0]
    69  			n := oldnew[i+1][0]
    70  			r.m[o] = n
    71  		}
    72  		return &r
    73  	}
    74  
    75  	r := byteStringReplacer[S]{toReplace: make([]S, 0, len(oldnew)/2)}
    76  	// The first occurrence of old->new map takes precedence
    77  	// over the others with the same old string.
    78  	for i := len(oldnew) - 2; i >= 0; i -= 2 {
    79  		o := oldnew[i][0]
    80  		n := oldnew[i+1]
    81  		// To avoid counting repetitions multiple times.
    82  		if r.replacements[o] == nil {
    83  			// We need to use string([]byte{o}) instead of string(o),
    84  			// to avoid utf8 encoding of o.
    85  			// E. g. byte(150) produces string of length 2.
    86  			r.toReplace = append(r.toReplace, S([]byte{o}))
    87  		}
    88  		r.replacements[o] = []byte(n)
    89  
    90  	}
    91  	return &r
    92  }
    93  
    94  // Replace returns a copy of s with all replacements performed.
    95  func (r *Replacer[S]) Replace(s S) S {
    96  	r.once.Do(r.buildOnce)
    97  	return r.r.Replace(s)
    98  }
    99  
   100  // WriteString writes s to w with all replacements performed.
   101  func (r *Replacer[S]) WriteString(w io.Writer, s S) (n int, err error) {
   102  	r.once.Do(r.buildOnce)
   103  	return r.r.WriteString(w, s)
   104  }
   105  
   106  // trieNode is a node in a lookup trie for prioritized key/value pairs. Keys
   107  // and values may be empty. For example, the trie containing keys "ax", "ay",
   108  // "bcbc", "x" and "xy" could have eight nodes:
   109  //
   110  //	n0  -
   111  //	n1  a-
   112  //	n2  .x+
   113  //	n3  .y+
   114  //	n4  b-
   115  //	n5  .cbc+
   116  //	n6  x+
   117  //	n7  .y+
   118  //
   119  // n0 is the root node, and its children are n1, n4 and n6; n1's children are
   120  // n2 and n3; n4's child is n5; n6's child is n7. Nodes n0, n1 and n4 (marked
   121  // with a trailing "-") are partial keys, and nodes n2, n3, n5, n6 and n7
   122  // (marked with a trailing "+") are complete keys.
   123  type trieNode[S String] struct {
   124  	// value is the value of the trie node's key/value pair. It is empty if
   125  	// this node is not a complete key.
   126  	value S
   127  	// priority is the priority (higher is more important) of the trie node's
   128  	// key/value pair; keys are not necessarily matched shortest- or longest-
   129  	// first. Priority is positive if this node is a complete key, and zero
   130  	// otherwise. In the example above, positive/zero priorities are marked
   131  	// with a trailing "+" or "-".
   132  	priority int
   133  
   134  	// A trie node may have zero, one or more child nodes:
   135  	//  * if the remaining fields are zero, there are no children.
   136  	//  * if prefix and next are non-zero, there is one child in next.
   137  	//  * if table is non-zero, it defines all the children.
   138  	//
   139  	// Prefixes are preferred over tables when there is one child, but the
   140  	// root node always uses a table for lookup efficiency.
   141  
   142  	// prefix is the difference in keys between this trie node and the next.
   143  	// In the example above, node n4 has prefix "cbc" and n4's next node is n5.
   144  	// Node n5 has no children and so has zero prefix, next and table fields.
   145  	prefix S
   146  	next   *trieNode[S]
   147  
   148  	// table is a lookup table indexed by the next byte in the key, after
   149  	// remapping that byte through genericReplacer.mapping to create a dense
   150  	// index. In the example above, the keys only use 'a', 'b', 'c', 'x' and
   151  	// 'y', which remap to 0, 1, 2, 3 and 4. All other bytes remap to 5, and
   152  	// genericReplacer.tableSize will be 5. Node n0's table will be
   153  	// []*trieNode[S]{ 0:n1, 1:n4, 3:n6 }, where the 0, 1 and 3 are the remapped
   154  	// 'a', 'b' and 'x'.
   155  	table []*trieNode[S]
   156  }
   157  
   158  func (t *trieNode[S]) add(key, val S, priority int, r *genericReplacer[S]) {
   159  	if IsEmpty(key) {
   160  		if t.priority == 0 {
   161  			t.value = val
   162  			t.priority = priority
   163  		}
   164  		return
   165  	}
   166  
   167  	if !IsEmpty(t.prefix) {
   168  		// Need to split the prefix among multiple nodes.
   169  		var n int // length of the longest common prefix
   170  		for ; n < len(t.prefix) && n < len(key); n++ {
   171  			if t.prefix[n] != key[n] {
   172  				break
   173  			}
   174  		}
   175  		if n == len(t.prefix) {
   176  			t.next.add(key[n:], val, priority, r)
   177  		} else if n == 0 {
   178  			// First byte differs, start a new lookup table here. Looking up
   179  			// what is currently t.prefix[0] will lead to prefixNode, and
   180  			// looking up key[0] will lead to keyNode.
   181  			var prefixNode *trieNode[S]
   182  			if len(t.prefix) == 1 {
   183  				prefixNode = t.next
   184  			} else {
   185  				prefixNode = &trieNode[S]{
   186  					prefix: t.prefix[1:],
   187  					next:   t.next,
   188  				}
   189  			}
   190  			keyNode := new(trieNode[S])
   191  			t.table = make([]*trieNode[S], r.tableSize)
   192  			t.table[r.mapping[t.prefix[0]]] = prefixNode
   193  			t.table[r.mapping[key[0]]] = keyNode
   194  			t.prefix = Empty[S]()
   195  			t.next = nil
   196  			keyNode.add(key[1:], val, priority, r)
   197  		} else {
   198  			// Insert new node after the common section of the prefix.
   199  			next := &trieNode[S]{
   200  				prefix: t.prefix[n:],
   201  				next:   t.next,
   202  			}
   203  			t.prefix = t.prefix[:n]
   204  			t.next = next
   205  			next.add(key[n:], val, priority, r)
   206  		}
   207  	} else if t.table != nil {
   208  		// Insert into existing table.
   209  		m := r.mapping[key[0]]
   210  		if t.table[m] == nil {
   211  			t.table[m] = new(trieNode[S])
   212  		}
   213  		t.table[m].add(key[1:], val, priority, r)
   214  	} else {
   215  		t.prefix = key
   216  		t.next = new(trieNode[S])
   217  		t.next.add(Empty[S](), val, priority, r)
   218  	}
   219  }
   220  
   221  func (r *genericReplacer[S]) lookup(s S, ignoreRoot bool) (val S, keylen int, found bool) {
   222  	// Iterate down the trie to the end, and grab the value and keylen with
   223  	// the highest priority.
   224  	bestPriority := 0
   225  	node := &r.root
   226  	n := 0
   227  	for node != nil {
   228  		if node.priority > bestPriority && !(ignoreRoot && node == &r.root) {
   229  			bestPriority = node.priority
   230  			val = node.value
   231  			keylen = n
   232  			found = true
   233  		}
   234  
   235  		if IsEmpty(s) {
   236  			break
   237  		}
   238  		if node.table != nil {
   239  			index := r.mapping[s[0]]
   240  			if int(index) == r.tableSize {
   241  				break
   242  			}
   243  			node = node.table[index]
   244  			s = s[1:]
   245  			n++
   246  		} else if len(node.prefix) != 0 && HasPrefix(s, node.prefix) {
   247  			n += len(node.prefix)
   248  			s = s[len(node.prefix):]
   249  			node = node.next
   250  		} else {
   251  			break
   252  		}
   253  	}
   254  	return
   255  }
   256  
   257  // genericReplacer is the fully generic algorithm.
   258  // It's used as a fallback when nothing faster can be used.
   259  type genericReplacer[S String] struct {
   260  	root trieNode[S]
   261  	// tableSize is the size of a trie node's lookup table. It is the number
   262  	// of unique key bytes.
   263  	tableSize int
   264  	// mapping maps from key bytes to a dense index for trieNode[S].table.
   265  	mapping [256]byte
   266  }
   267  
   268  func makeGenericReplacer[S String](oldnew []S) *genericReplacer[S] {
   269  	r := new(genericReplacer[S])
   270  	// Find each byte used, then assign them each an index.
   271  	for i := 0; i < len(oldnew); i += 2 {
   272  		key := oldnew[i]
   273  		for j := 0; j < len(key); j++ {
   274  			r.mapping[key[j]] = 1
   275  		}
   276  	}
   277  
   278  	for _, b := range r.mapping {
   279  		r.tableSize += int(b)
   280  	}
   281  
   282  	var index byte
   283  	for i, b := range r.mapping {
   284  		if b == 0 {
   285  			r.mapping[i] = byte(r.tableSize)
   286  		} else {
   287  			r.mapping[i] = index
   288  			index++
   289  		}
   290  	}
   291  	// Ensure root node uses a lookup table (for performance).
   292  	r.root.table = make([]*trieNode[S], r.tableSize)
   293  
   294  	for i := 0; i < len(oldnew); i += 2 {
   295  		r.root.add(oldnew[i], oldnew[i+1], len(oldnew)-i, r)
   296  	}
   297  	return r
   298  }
   299  
   300  type appendSliceWriter[S String] struct {
   301  	b []byte
   302  }
   303  
   304  // Write writes to the buffer to satisfy io.Writer.
   305  func (w *appendSliceWriter[S]) Write(p []byte) (int, error) {
   306  	w.b = append(w.b, p...)
   307  	return len(p), nil
   308  }
   309  
   310  // WriteString writes to the buffer without string->[]byte->string allocations.
   311  func (w *appendSliceWriter[S]) WriteString(s string) (int, error) {
   312  	w.b = append(w.b, s...)
   313  	return len(s), nil
   314  }
   315  
   316  // WriteText writes to the buffer without string->[]byte->string allocations.
   317  func (w *appendSliceWriter[S]) WriteText(s S) (int, error) {
   318  	w.b = append(w.b, s...)
   319  	return len(s), nil
   320  }
   321  
   322  type textWriter[S String] struct {
   323  	w io.Writer
   324  }
   325  
   326  func (w textWriter[S]) WriteText(s S) (int, error) {
   327  	return w.w.Write([]byte(s))
   328  }
   329  
   330  func getTextWriter[S String](w io.Writer) Writer[S] {
   331  	tw, ok := w.(Writer[S])
   332  	if ok {
   333  		return tw
   334  	}
   335  	sw, ok := w.(io.StringWriter)
   336  	if ok {
   337  		return AsWriter[S](sw)
   338  	}
   339  	return textWriter[S]{w: w}
   340  }
   341  
   342  func (r *genericReplacer[S]) Replace(s S) S {
   343  	w := appendSliceWriter[S]{b: make([]byte, 0, len(s))}
   344  	r.WriteString(&w, s)
   345  	return S(w.b)
   346  }
   347  
   348  func (r *genericReplacer[S]) WriteString(w io.Writer, s S) (n int, err error) {
   349  	sw := getTextWriter[S](w)
   350  	var last, wn int
   351  	var prevMatchEmpty bool
   352  	for i := 0; i <= len(s); {
   353  		// Fast path: s[i] is not a prefix of any pattern.
   354  		if i != len(s) && r.root.priority == 0 {
   355  			index := int(r.mapping[s[i]])
   356  			if index == r.tableSize || r.root.table[index] == nil {
   357  				i++
   358  				continue
   359  			}
   360  		}
   361  
   362  		// Ignore the empty match iff the previous loop found the empty match.
   363  		val, keylen, match := r.lookup(s[i:], prevMatchEmpty)
   364  		prevMatchEmpty = match && keylen == 0
   365  		if match {
   366  			wn, err = sw.WriteText(s[last:i])
   367  			n += wn
   368  			if err != nil {
   369  				return
   370  			}
   371  			wn, err = sw.WriteText(val)
   372  			n += wn
   373  			if err != nil {
   374  				return
   375  			}
   376  			i += keylen
   377  			last = i
   378  			continue
   379  		}
   380  		i++
   381  	}
   382  	if last != len(s) {
   383  		wn, err = sw.WriteText(s[last:])
   384  		n += wn
   385  	}
   386  	return
   387  }
   388  
   389  // singleStringReplacer is the implementation that's used when there is only
   390  // one string to replace (and that string has more than one byte).
   391  type singleStringReplacer[S String] struct {
   392  	finder *stringFinder[S]
   393  	// value is the new string that replaces that pattern when it's found.
   394  	value S
   395  }
   396  
   397  func makeSingleStringReplacer[S String](pattern, value S) *singleStringReplacer[S] {
   398  	return &singleStringReplacer[S]{finder: makeStringFinder(pattern), value: value}
   399  }
   400  
   401  func (r *singleStringReplacer[S]) Replace(s S) S {
   402  	var buf Builder[S]
   403  	i, matched := 0, false
   404  	for {
   405  		match := r.finder.next(s[i:])
   406  		if match == -1 {
   407  			break
   408  		}
   409  		matched = true
   410  		buf.Grow(match + len(r.value))
   411  		buf.WriteText(s[i : i+match])
   412  		buf.WriteText(r.value)
   413  		i += match + len(r.finder.pattern)
   414  	}
   415  	if !matched {
   416  		return s
   417  	}
   418  	buf.WriteText(s[i:])
   419  	return buf.Text()
   420  }
   421  
   422  func (r *singleStringReplacer[S]) WriteString(w io.Writer, s S) (n int, err error) {
   423  	sw := getTextWriter[S](w)
   424  	var i, wn int
   425  	for {
   426  		match := r.finder.next(s[i:])
   427  		if match == -1 {
   428  			break
   429  		}
   430  		wn, err = sw.WriteText(s[i : i+match])
   431  		n += wn
   432  		if err != nil {
   433  			return
   434  		}
   435  		wn, err = sw.WriteText(r.value)
   436  		n += wn
   437  		if err != nil {
   438  			return
   439  		}
   440  		i += match + len(r.finder.pattern)
   441  	}
   442  	wn, err = sw.WriteText(s[i:])
   443  	n += wn
   444  	return
   445  }
   446  
   447  // byteReplacer is the implementation that's used when all the "old"
   448  // and "new" values are single ASCII bytes.
   449  // The array contains replacement bytes indexed by old byte.
   450  type byteReplacer[S String] struct {
   451  	m [256]byte
   452  }
   453  
   454  func (r *byteReplacer[S]) Replace(s S) S {
   455  	var buf []byte // lazily allocated
   456  	for i := 0; i < len(s); i++ {
   457  		b := s[i]
   458  		if r.m[b] != b {
   459  			if buf == nil {
   460  				buf = []byte(s)
   461  			}
   462  			buf[i] = r.m[b]
   463  		}
   464  	}
   465  	if buf == nil {
   466  		return s
   467  	}
   468  	return S(buf)
   469  }
   470  
   471  func (r *byteReplacer[S]) WriteString(w io.Writer, s S) (n int, err error) {
   472  	sw := getTextWriter[S](w)
   473  	last := 0
   474  	for i := 0; i < len(s); i++ {
   475  		b := s[i]
   476  		if r.m[b] == b {
   477  			continue
   478  		}
   479  		if last != i {
   480  			wn, err := sw.WriteText(s[last:i])
   481  			n += wn
   482  			if err != nil {
   483  				return n, err
   484  			}
   485  		}
   486  		last = i + 1
   487  		nw, err := w.Write(r.m[b : int(b)+1])
   488  		n += nw
   489  		if err != nil {
   490  			return n, err
   491  		}
   492  	}
   493  	if last != len(s) {
   494  		nw, err := sw.WriteText(s[last:])
   495  		n += nw
   496  		if err != nil {
   497  			return n, err
   498  		}
   499  	}
   500  	return n, nil
   501  }
   502  
   503  // byteStringReplacer is the implementation that's used when all the
   504  // "old" values are single ASCII bytes but the "new" values vary in size.
   505  type byteStringReplacer[S String] struct {
   506  	// replacements contains replacement byte slices indexed by old byte.
   507  	// A nil []byte means that the old byte should not be replaced.
   508  	replacements [256][]byte
   509  	// toReplace keeps a list of bytes to replace. Depending on length of toReplace
   510  	// and length of target string it may be faster to use Count, or a plain loop.
   511  	// We store single byte as a string, because Count takes a string.
   512  	toReplace []S
   513  }
   514  
   515  // countCutOff controls the ratio of a string length to a number of replacements
   516  // at which (*byteStringReplacer[S]).Replace switches algorithms.
   517  // For strings with higher ration of length to replacements than that value,
   518  // we call Count, for each replacement from toReplace.
   519  // For strings, with a lower ratio we use simple loop, because of Count overhead.
   520  // countCutOff is an empirically determined overhead multiplier.
   521  // TODO(tocarip) revisit once we have register-based abi/mid-stack inlining.
   522  const countCutOff = 8
   523  
   524  func (r *byteStringReplacer[S]) Replace(s S) S {
   525  	newSize := len(s)
   526  	anyChanges := false
   527  	// Is it faster to use Count?
   528  	if len(r.toReplace)*countCutOff <= len(s) {
   529  		for _, x := range r.toReplace {
   530  			if c := Count(s, x); c != 0 {
   531  				// The -1 is because we are replacing 1 byte with len(replacements[b]) bytes.
   532  				newSize += c * (len(r.replacements[x[0]]) - 1)
   533  				anyChanges = true
   534  			}
   535  
   536  		}
   537  	} else {
   538  		for i := 0; i < len(s); i++ {
   539  			b := s[i]
   540  			if r.replacements[b] != nil {
   541  				// See above for explanation of -1
   542  				newSize += len(r.replacements[b]) - 1
   543  				anyChanges = true
   544  			}
   545  		}
   546  	}
   547  	if !anyChanges {
   548  		return s
   549  	}
   550  	buf := make([]byte, newSize)
   551  	j := 0
   552  	for i := 0; i < len(s); i++ {
   553  		b := s[i]
   554  		if r.replacements[b] != nil {
   555  			j += copy(buf[j:], r.replacements[b])
   556  		} else {
   557  			buf[j] = b
   558  			j++
   559  		}
   560  	}
   561  	return S(buf)
   562  }
   563  
   564  func (r *byteStringReplacer[S]) WriteString(w io.Writer, s S) (n int, err error) {
   565  	sw := getTextWriter[S](w)
   566  	last := 0
   567  	for i := 0; i < len(s); i++ {
   568  		b := s[i]
   569  		if r.replacements[b] == nil {
   570  			continue
   571  		}
   572  		if last != i {
   573  			nw, err := sw.WriteText(s[last:i])
   574  			n += nw
   575  			if err != nil {
   576  				return n, err
   577  			}
   578  		}
   579  		last = i + 1
   580  		nw, err := w.Write(r.replacements[b])
   581  		n += nw
   582  		if err != nil {
   583  			return n, err
   584  		}
   585  	}
   586  	if last != len(s) {
   587  		var nw int
   588  		nw, err = sw.WriteText(s[last:])
   589  		n += nw
   590  	}
   591  	return
   592  }