vitess.io/vitess@v0.16.2/go/mysql/collations/uca_contraction_test.go (about)

     1  /*
     2  Copyright 2021 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package collations
    18  
    19  import (
    20  	"encoding/json"
    21  	"fmt"
    22  	"math/rand"
    23  	"os"
    24  	"reflect"
    25  	"sort"
    26  	"testing"
    27  	"unicode/utf8"
    28  
    29  	"github.com/stretchr/testify/assert"
    30  
    31  	"vitess.io/vitess/go/mysql/collations/internal/charset"
    32  	"vitess.io/vitess/go/mysql/collations/internal/uca"
    33  )
    34  
    35  type CollationWithContractions struct {
    36  	Collation    Collation
    37  	Contractions []uca.Contraction
    38  	ContractFast uca.Contractor
    39  	ContractTrie uca.Contractor
    40  }
    41  
    42  func findContractedCollations(t testing.TB, unique bool) (result []CollationWithContractions) {
    43  	type collationMetadata struct {
    44  		Contractions []uca.Contraction
    45  	}
    46  
    47  	var seen = make(map[string]bool)
    48  
    49  	for _, collation := range testall() {
    50  		var contract uca.Contractor
    51  		if uca, ok := collation.(*Collation_utf8mb4_uca_0900); ok {
    52  			contract = uca.contract
    53  		}
    54  		if uca, ok := collation.(*Collation_uca_legacy); ok {
    55  			contract = uca.contract
    56  		}
    57  		if contract == nil {
    58  			continue
    59  		}
    60  
    61  		rf, err := os.Open(fmt.Sprintf("testdata/mysqldata/%s.json", collation.Name()))
    62  		if err != nil {
    63  			t.Skipf("failed to open JSON metadata (%v). did you run colldump?", err)
    64  		}
    65  
    66  		var meta collationMetadata
    67  		if err := json.NewDecoder(rf).Decode(&meta); err != nil {
    68  			t.Fatal(err)
    69  		}
    70  		rf.Close()
    71  
    72  		if unique {
    73  			raw := fmt.Sprintf("%#v", meta.Contractions)
    74  			if seen[raw] {
    75  				continue
    76  			}
    77  			seen[raw] = true
    78  		}
    79  
    80  		for n := range meta.Contractions {
    81  			ctr := &meta.Contractions[n]
    82  			for i := 0; i < len(ctr.Weights)-3; i += 3 {
    83  				if ctr.Weights[i] == 0x0 && ctr.Weights[i+1] == 0x0 && ctr.Weights[i+2] == 0x0 {
    84  					ctr.Weights = ctr.Weights[:i]
    85  					break
    86  				}
    87  			}
    88  		}
    89  
    90  		result = append(result, CollationWithContractions{
    91  			Collation:    collation,
    92  			Contractions: meta.Contractions,
    93  			ContractFast: contract,
    94  			ContractTrie: uca.NewTrieContractor(meta.Contractions),
    95  		})
    96  	}
    97  	return
    98  }
    99  
   100  func testMatch(t *testing.T, name string, cnt uca.Contraction, result []uint16, remainder []byte, skip int) {
   101  	assert.True(t, reflect.DeepEqual(cnt.Weights, result), "%s didn't match: expected %#v, got %#v", name, cnt.Weights, result)
   102  	assert.Equal(t, 0, len(remainder), "%s bad remainder: %#v", name, remainder)
   103  	assert.Equal(t, len(cnt.Path), skip, "%s bad skipped length %d for %#v", name, skip, cnt.Path)
   104  
   105  }
   106  
   107  func TestUCAContractions(t *testing.T) {
   108  	for _, cwc := range findContractedCollations(t, false) {
   109  		t.Run(cwc.Collation.Name(), func(t *testing.T) {
   110  			for _, cnt := range cwc.Contractions {
   111  				if cnt.Contextual {
   112  					head := cnt.Path[0]
   113  					tail := cnt.Path[1]
   114  
   115  					result := cwc.ContractTrie.FindContextual(head, tail)
   116  					testMatch(t, "ContractTrie", cnt, result, nil, 2)
   117  
   118  					result = cwc.ContractFast.FindContextual(head, tail)
   119  					testMatch(t, "ContractFast", cnt, result, nil, 2)
   120  					continue
   121  				}
   122  
   123  				head := cnt.Path[0]
   124  				tail := string(cnt.Path[1:])
   125  
   126  				result, remainder, skip := cwc.ContractTrie.Find(charset.Charset_utf8mb4{}, head, []byte(tail))
   127  				testMatch(t, "ContractTrie", cnt, result, remainder, skip)
   128  
   129  				result, remainder, skip = cwc.ContractFast.Find(charset.Charset_utf8mb4{}, head, []byte(tail))
   130  				testMatch(t, "ContractFast", cnt, result, remainder, skip)
   131  			}
   132  		})
   133  	}
   134  }
   135  
   136  func benchmarkFind(b *testing.B, input []byte, contract uca.Contractor) {
   137  	b.SetBytes(int64(len(input)))
   138  	b.ReportAllocs()
   139  	b.ResetTimer()
   140  
   141  	for n := 0; n < b.N; n++ {
   142  		in := input
   143  		for len(in) > 0 {
   144  			cp, width := utf8.DecodeRune(in)
   145  			in = in[width:]
   146  			_, _, _ = contract.Find(charset.Charset_utf8mb4{}, cp, in)
   147  		}
   148  	}
   149  }
   150  
   151  func benchmarkFindJA(b *testing.B, input []byte, contract uca.Contractor) {
   152  	b.SetBytes(int64(len(input)))
   153  	b.ReportAllocs()
   154  	b.ResetTimer()
   155  
   156  	for n := 0; n < b.N; n++ {
   157  		prev := rune(0)
   158  		in := input
   159  		for len(in) > 0 {
   160  			cp, width := utf8.DecodeRune(in)
   161  			_ = contract.FindContextual(cp, prev)
   162  			prev = cp
   163  			in = in[width:]
   164  		}
   165  	}
   166  }
   167  
   168  type strgen struct {
   169  	repertoire   map[rune]struct{}
   170  	contractions []string
   171  }
   172  
   173  func newStrgen() *strgen {
   174  	return &strgen{repertoire: make(map[rune]struct{})}
   175  }
   176  
   177  func (s *strgen) withASCII() *strgen {
   178  	for r := rune(0); r < utf8.RuneSelf; r++ {
   179  		s.repertoire[r] = struct{}{}
   180  	}
   181  	return s
   182  }
   183  
   184  func (s *strgen) withContractions(all []uca.Contraction) *strgen {
   185  	for _, cnt := range all {
   186  		for _, r := range cnt.Path {
   187  			s.repertoire[r] = struct{}{}
   188  		}
   189  
   190  		if cnt.Contextual {
   191  			s.contractions = append(s.contractions, string([]rune{cnt.Path[1], cnt.Path[0]}))
   192  		} else {
   193  			s.contractions = append(s.contractions, string(cnt.Path))
   194  		}
   195  	}
   196  	return s
   197  }
   198  
   199  func (s *strgen) withText(in string) *strgen {
   200  	for _, r := range in {
   201  		s.repertoire[r] = struct{}{}
   202  	}
   203  	return s
   204  }
   205  
   206  func (s *strgen) generate(length int, freq float64) (out []byte) {
   207  	var flat []rune
   208  	for r := range s.repertoire {
   209  		flat = append(flat, r)
   210  	}
   211  	sort.Slice(flat, func(i, j int) bool {
   212  		return flat[i] < flat[j]
   213  	})
   214  
   215  	gen := rand.New(rand.NewSource(0xDEADBEEF))
   216  	out = make([]byte, 0, length)
   217  	for len(out) < length {
   218  		if gen.Float64() < freq {
   219  			cnt := s.contractions[rand.Intn(len(s.contractions))]
   220  			out = append(out, cnt...)
   221  		} else {
   222  			cp := flat[rand.Intn(len(flat))]
   223  			out = append(out, string(cp)...)
   224  		}
   225  	}
   226  	return
   227  }
   228  
   229  func BenchmarkUCAContractions(b *testing.B) {
   230  	for _, cwc := range findContractedCollations(b, true) {
   231  		if cwc.Contractions[0].Contextual {
   232  			continue
   233  		}
   234  
   235  		gen := newStrgen().withASCII().withContractions(cwc.Contractions)
   236  		frequency := 0.05
   237  		input := gen.generate(1024*32, 0.05)
   238  
   239  		b.Run(fmt.Sprintf("%s-%.02f-fast", cwc.Collation.Name(), frequency), func(b *testing.B) {
   240  			benchmarkFind(b, input, cwc.ContractFast)
   241  		})
   242  
   243  		b.Run(fmt.Sprintf("%s-%.02f-trie", cwc.Collation.Name(), frequency), func(b *testing.B) {
   244  			benchmarkFind(b, input, cwc.ContractTrie)
   245  		})
   246  	}
   247  }
   248  
   249  func BenchmarkUCAContractionsJA(b *testing.B) {
   250  	for _, cwc := range findContractedCollations(b, true) {
   251  		if !cwc.Contractions[0].Contextual {
   252  			continue
   253  		}
   254  
   255  		gen := newStrgen().withASCII().withText(JapaneseString).withText(JapaneseString2).withContractions(cwc.Contractions)
   256  		frequency := 0.05
   257  		input := gen.generate(1024*32, 0.05)
   258  
   259  		b.Run(fmt.Sprintf("%s-%.02f-fast", cwc.Collation.Name(), frequency), func(b *testing.B) {
   260  			benchmarkFindJA(b, input, cwc.ContractFast)
   261  		})
   262  
   263  		b.Run(fmt.Sprintf("%s-%.02f-trie", cwc.Collation.Name(), frequency), func(b *testing.B) {
   264  			benchmarkFindJA(b, input, cwc.ContractTrie)
   265  		})
   266  	}
   267  }