github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/schema/collation_comparator.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package schema 16 17 import ( 18 "bytes" 19 "unicode/utf8" 20 21 "github.com/dolthub/dolt/go/store/val" 22 23 "github.com/dolthub/go-mysql-server/sql" 24 ) 25 26 type CollationTupleComparator struct { 27 Collations []sql.CollationID // CollationIDs are implemented as uint16 28 } 29 30 var _ val.TupleComparator = CollationTupleComparator{} 31 32 // Compare implements TupleComparator 33 func (c CollationTupleComparator) Compare(left, right val.Tuple, desc val.TupleDesc) (cmp int) { 34 fast := desc.GetFixedAccess() 35 for i := range fast { 36 start, stop := fast[i][0], fast[i][1] 37 cmp = collationCompare(desc.Types[i], c.Collations[i], left[start:stop], right[start:stop]) 38 if cmp != 0 { 39 return cmp 40 } 41 } 42 43 off := len(fast) 44 for i, typ := range desc.Types[off:] { 45 j := i + off 46 cmp = collationCompare(typ, c.Collations[j], left.GetField(j), right.GetField(j)) 47 if cmp != 0 { 48 return cmp 49 } 50 } 51 return 52 } 53 54 // CompareValues implements TupleComparator 55 func (c CollationTupleComparator) CompareValues(index int, left, right []byte, typ val.Type) int { 56 return collationCompare(typ, c.Collations[index], left, right) 57 } 58 59 // Prefix implements TupleComparator 60 func (c CollationTupleComparator) Prefix(n int) val.TupleComparator { 61 newCollations := make([]sql.CollationID, n) 62 copy(newCollations, c.Collations) 63 return CollationTupleComparator{newCollations} 64 } 65 66 // Suffix implements TupleComparator 67 func (c CollationTupleComparator) Suffix(n int) val.TupleComparator { 68 newCollations := make([]sql.CollationID, n) 69 copy(newCollations, c.Collations[len(c.Collations)-n:]) 70 return CollationTupleComparator{newCollations} 71 } 72 73 // Validated implements TupleComparator 74 func (c CollationTupleComparator) Validated(types []val.Type) val.TupleComparator { 75 if len(c.Collations) > len(types) { 76 panic("too many collations compared to type encoding") 77 } 78 i := 0 79 for ; i < len(c.Collations); i++ { 80 if types[i].Enc == val.StringEnc && c.Collations[i] == sql.Collation_Unspecified { 81 c.Collations[i] = sql.Collation_Default 82 } 83 } 84 if len(c.Collations) == len(types) { 85 return c 86 } 87 newCollations := make([]sql.CollationID, len(types)) 88 copy(newCollations, c.Collations) 89 for ; i < len(newCollations); i++ { 90 if types[i].Enc == val.StringEnc { 91 panic("string type encoding is missing its collation") 92 } 93 newCollations[i] = sql.Collation_Unspecified 94 } 95 return CollationTupleComparator{Collations: newCollations} 96 } 97 98 func collationCompare(typ val.Type, collation sql.CollationID, left, right []byte) int { 99 // order NULLs first 100 if left == nil || right == nil { 101 if bytes.Equal(left, right) { 102 return 0 103 } else if left == nil { 104 return -1 105 } else { 106 return 1 107 } 108 } 109 110 if typ.Enc == val.StringEnc { 111 return compareCollatedStrings(collation, left[:len(left)-1], right[:len(right)-1]) 112 } else { 113 return val.DefaultTupleComparator{}.CompareValues(0, left, right, typ) 114 } 115 } 116 117 func compareCollatedStrings(collation sql.CollationID, left, right []byte) int { 118 i := 0 119 for i < len(left) && i < len(right) { 120 if left[i] != right[i] { 121 break 122 } 123 i++ 124 } 125 if i >= len(left) || i >= len(right) { 126 if len(left) < len(right) { 127 return -1 128 } else if len(left) > len(right) { 129 return 1 130 } else { 131 return 0 132 } 133 } 134 135 li := i 136 for ; li > 0 && !utf8.RuneStart(left[li]); li-- { 137 } 138 left = left[li:] 139 140 ri := i 141 for ; ri > 0 && !utf8.RuneStart(right[ri]); ri-- { 142 } 143 right = right[ri:] 144 145 getRuneWeight := collation.Sorter() 146 for len(left) > 0 && len(right) > 0 { 147 // Binary strings aren't handled through this function, so it is safe to use the utf8 functions 148 leftRune, leftRead := utf8.DecodeRune(left) 149 rightRune, rightRead := utf8.DecodeRune(right) 150 if leftRead == utf8.RuneError || rightRead == utf8.RuneError { 151 // Malformed strings sort after well-formed strings, and we consider two malformed strings to be equal 152 if leftRead == utf8.RuneError && rightRead != utf8.RuneError { 153 return 1 154 } else if leftRead != utf8.RuneError && rightRead == utf8.RuneError { 155 return -1 156 } else { 157 return 0 158 } 159 } 160 if leftRune != rightRune { 161 leftWeight := getRuneWeight(leftRune) 162 rightWeight := getRuneWeight(rightRune) 163 if leftWeight < rightWeight { 164 return -1 165 } else if leftWeight > rightWeight { 166 return 1 167 } 168 } 169 left = left[leftRead:] 170 right = right[rightRead:] 171 } 172 173 // Strings are equal up to the compared length, so shorter strings sort before longer strings 174 if len(left) < len(right) { 175 return -1 176 } else if len(left) > len(right) { 177 return 1 178 } else { 179 return 0 180 } 181 }