github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/schema/collation_comparator.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package schema
    16  
    17  import (
    18  	"bytes"
    19  	"unicode/utf8"
    20  
    21  	"github.com/dolthub/dolt/go/store/val"
    22  
    23  	"github.com/dolthub/go-mysql-server/sql"
    24  )
    25  
    26  type CollationTupleComparator struct {
    27  	Collations []sql.CollationID // CollationIDs are implemented as uint16
    28  }
    29  
    30  var _ val.TupleComparator = CollationTupleComparator{}
    31  
    32  // Compare implements TupleComparator
    33  func (c CollationTupleComparator) Compare(left, right val.Tuple, desc val.TupleDesc) (cmp int) {
    34  	fast := desc.GetFixedAccess()
    35  	for i := range fast {
    36  		start, stop := fast[i][0], fast[i][1]
    37  		cmp = collationCompare(desc.Types[i], c.Collations[i], left[start:stop], right[start:stop])
    38  		if cmp != 0 {
    39  			return cmp
    40  		}
    41  	}
    42  
    43  	off := len(fast)
    44  	for i, typ := range desc.Types[off:] {
    45  		j := i + off
    46  		cmp = collationCompare(typ, c.Collations[j], left.GetField(j), right.GetField(j))
    47  		if cmp != 0 {
    48  			return cmp
    49  		}
    50  	}
    51  	return
    52  }
    53  
    54  // CompareValues implements TupleComparator
    55  func (c CollationTupleComparator) CompareValues(index int, left, right []byte, typ val.Type) int {
    56  	return collationCompare(typ, c.Collations[index], left, right)
    57  }
    58  
    59  // Prefix implements TupleComparator
    60  func (c CollationTupleComparator) Prefix(n int) val.TupleComparator {
    61  	newCollations := make([]sql.CollationID, n)
    62  	copy(newCollations, c.Collations)
    63  	return CollationTupleComparator{newCollations}
    64  }
    65  
    66  // Suffix implements TupleComparator
    67  func (c CollationTupleComparator) Suffix(n int) val.TupleComparator {
    68  	newCollations := make([]sql.CollationID, n)
    69  	copy(newCollations, c.Collations[len(c.Collations)-n:])
    70  	return CollationTupleComparator{newCollations}
    71  }
    72  
    73  // Validated implements TupleComparator
    74  func (c CollationTupleComparator) Validated(types []val.Type) val.TupleComparator {
    75  	if len(c.Collations) > len(types) {
    76  		panic("too many collations compared to type encoding")
    77  	}
    78  	i := 0
    79  	for ; i < len(c.Collations); i++ {
    80  		if types[i].Enc == val.StringEnc && c.Collations[i] == sql.Collation_Unspecified {
    81  			c.Collations[i] = sql.Collation_Default
    82  		}
    83  	}
    84  	if len(c.Collations) == len(types) {
    85  		return c
    86  	}
    87  	newCollations := make([]sql.CollationID, len(types))
    88  	copy(newCollations, c.Collations)
    89  	for ; i < len(newCollations); i++ {
    90  		if types[i].Enc == val.StringEnc {
    91  			panic("string type encoding is missing its collation")
    92  		}
    93  		newCollations[i] = sql.Collation_Unspecified
    94  	}
    95  	return CollationTupleComparator{Collations: newCollations}
    96  }
    97  
    98  func collationCompare(typ val.Type, collation sql.CollationID, left, right []byte) int {
    99  	// order NULLs first
   100  	if left == nil || right == nil {
   101  		if bytes.Equal(left, right) {
   102  			return 0
   103  		} else if left == nil {
   104  			return -1
   105  		} else {
   106  			return 1
   107  		}
   108  	}
   109  
   110  	if typ.Enc == val.StringEnc {
   111  		return compareCollatedStrings(collation, left[:len(left)-1], right[:len(right)-1])
   112  	} else {
   113  		return val.DefaultTupleComparator{}.CompareValues(0, left, right, typ)
   114  	}
   115  }
   116  
   117  func compareCollatedStrings(collation sql.CollationID, left, right []byte) int {
   118  	i := 0
   119  	for i < len(left) && i < len(right) {
   120  		if left[i] != right[i] {
   121  			break
   122  		}
   123  		i++
   124  	}
   125  	if i >= len(left) || i >= len(right) {
   126  		if len(left) < len(right) {
   127  			return -1
   128  		} else if len(left) > len(right) {
   129  			return 1
   130  		} else {
   131  			return 0
   132  		}
   133  	}
   134  
   135  	li := i
   136  	for ; li > 0 && !utf8.RuneStart(left[li]); li-- {
   137  	}
   138  	left = left[li:]
   139  
   140  	ri := i
   141  	for ; ri > 0 && !utf8.RuneStart(right[ri]); ri-- {
   142  	}
   143  	right = right[ri:]
   144  
   145  	getRuneWeight := collation.Sorter()
   146  	for len(left) > 0 && len(right) > 0 {
   147  		// Binary strings aren't handled through this function, so it is safe to use the utf8 functions
   148  		leftRune, leftRead := utf8.DecodeRune(left)
   149  		rightRune, rightRead := utf8.DecodeRune(right)
   150  		if leftRead == utf8.RuneError || rightRead == utf8.RuneError {
   151  			// Malformed strings sort after well-formed strings, and we consider two malformed strings to be equal
   152  			if leftRead == utf8.RuneError && rightRead != utf8.RuneError {
   153  				return 1
   154  			} else if leftRead != utf8.RuneError && rightRead == utf8.RuneError {
   155  				return -1
   156  			} else {
   157  				return 0
   158  			}
   159  		}
   160  		if leftRune != rightRune {
   161  			leftWeight := getRuneWeight(leftRune)
   162  			rightWeight := getRuneWeight(rightRune)
   163  			if leftWeight < rightWeight {
   164  				return -1
   165  			} else if leftWeight > rightWeight {
   166  				return 1
   167  			}
   168  		}
   169  		left = left[leftRead:]
   170  		right = right[rightRead:]
   171  	}
   172  
   173  	// Strings are equal up to the compared length, so shorter strings sort before longer strings
   174  	if len(left) < len(right) {
   175  		return -1
   176  	} else if len(left) > len(right) {
   177  		return 1
   178  	} else {
   179  		return 0
   180  	}
   181  }