github.com/parquet-go/parquet-go@v0.21.1-0.20240501160520-b3c3a0c3ed6f/compare.go (about) 1 package parquet 2 3 import ( 4 "encoding/binary" 5 "sync" 6 7 "github.com/parquet-go/parquet-go/deprecated" 8 ) 9 10 // CompareDescending constructs a comparison function which inverses the order 11 // of values. 12 // 13 //go:noinline 14 func CompareDescending(cmp func(Value, Value) int) func(Value, Value) int { 15 return func(a, b Value) int { return -cmp(a, b) } 16 } 17 18 // CompareNullsFirst constructs a comparison function which assumes that null 19 // values are smaller than all other values. 20 // 21 //go:noinline 22 func CompareNullsFirst(cmp func(Value, Value) int) func(Value, Value) int { 23 return func(a, b Value) int { 24 switch { 25 case a.IsNull(): 26 if b.IsNull() { 27 return 0 28 } 29 return -1 30 case b.IsNull(): 31 return +1 32 default: 33 return cmp(a, b) 34 } 35 } 36 } 37 38 // CompareNullsLast constructs a comparison function which assumes that null 39 // values are greater than all other values. 40 // 41 //go:noinline 42 func CompareNullsLast(cmp func(Value, Value) int) func(Value, Value) int { 43 return func(a, b Value) int { 44 switch { 45 case a.IsNull(): 46 if b.IsNull() { 47 return 0 48 } 49 return +1 50 case b.IsNull(): 51 return -1 52 default: 53 return cmp(a, b) 54 } 55 } 56 } 57 58 func compareBool(v1, v2 bool) int { 59 switch { 60 case !v1 && v2: 61 return -1 62 case v1 && !v2: 63 return +1 64 default: 65 return 0 66 } 67 } 68 69 func compareInt32(v1, v2 int32) int { 70 switch { 71 case v1 < v2: 72 return -1 73 case v1 > v2: 74 return +1 75 default: 76 return 0 77 } 78 } 79 80 func compareInt64(v1, v2 int64) int { 81 switch { 82 case v1 < v2: 83 return -1 84 case v1 > v2: 85 return +1 86 default: 87 return 0 88 } 89 } 90 91 func compareInt96(v1, v2 deprecated.Int96) int { 92 switch { 93 case v1.Less(v2): 94 return -1 95 case v2.Less(v1): 96 return +1 97 default: 98 return 0 99 } 100 } 101 102 func compareFloat32(v1, v2 float32) int { 103 switch { 104 case v1 < v2: 105 return -1 106 case v1 > v2: 107 return +1 108 default: 109 return 0 110 } 111 } 112 113 func compareFloat64(v1, v2 float64) int { 114 switch { 115 case v1 < v2: 116 return -1 117 case v1 > v2: 118 return +1 119 default: 120 return 0 121 } 122 } 123 124 func compareUint32(v1, v2 uint32) int { 125 switch { 126 case v1 < v2: 127 return -1 128 case v1 > v2: 129 return +1 130 default: 131 return 0 132 } 133 } 134 135 func compareUint64(v1, v2 uint64) int { 136 switch { 137 case v1 < v2: 138 return -1 139 case v1 > v2: 140 return +1 141 default: 142 return 0 143 } 144 } 145 146 func compareBE128(v1, v2 *[16]byte) int { 147 x := binary.BigEndian.Uint64(v1[:8]) 148 y := binary.BigEndian.Uint64(v2[:8]) 149 switch { 150 case x < y: 151 return -1 152 case x > y: 153 return +1 154 } 155 x = binary.BigEndian.Uint64(v1[8:]) 156 y = binary.BigEndian.Uint64(v2[8:]) 157 switch { 158 case x < y: 159 return -1 160 case x > y: 161 return +1 162 default: 163 return 0 164 } 165 } 166 167 func lessBE128(v1, v2 *[16]byte) bool { 168 x := binary.BigEndian.Uint64(v1[:8]) 169 y := binary.BigEndian.Uint64(v2[:8]) 170 switch { 171 case x < y: 172 return true 173 case x > y: 174 return false 175 } 176 x = binary.BigEndian.Uint64(v1[8:]) 177 y = binary.BigEndian.Uint64(v2[8:]) 178 return x < y 179 } 180 181 func compareRowsFuncOf(schema *Schema, sortingColumns []SortingColumn) func(Row, Row) int { 182 leafColumns := make([]leafColumn, len(sortingColumns)) 183 canCompareRows := true 184 185 forEachLeafColumnOf(schema, func(leaf leafColumn) { 186 if leaf.maxRepetitionLevel > 0 { 187 canCompareRows = false 188 } 189 190 if sortingIndex := searchSortingColumn(sortingColumns, leaf.path); sortingIndex < len(sortingColumns) { 191 leafColumns[sortingIndex] = leaf 192 193 if leaf.maxDefinitionLevel > 0 { 194 canCompareRows = false 195 } 196 } 197 }) 198 199 // This is an optimization for the common case where rows 200 // are sorted by non-optional, non-repeated columns. 201 // 202 // The sort function can make the assumption that it will 203 // find the column value at the current column index, and 204 // does not need to scan the rows looking for values with 205 // a matching column index. 206 if canCompareRows { 207 return compareRowsFuncOfColumnIndexes(leafColumns, sortingColumns) 208 } 209 210 return compareRowsFuncOfColumnValues(leafColumns, sortingColumns) 211 } 212 213 func compareRowsUnordered(Row, Row) int { return 0 } 214 215 //go:noinline 216 func compareRowsFuncOfIndexColumns(compareFuncs []func(Row, Row) int) func(Row, Row) int { 217 return func(row1, row2 Row) int { 218 for _, compare := range compareFuncs { 219 if cmp := compare(row1, row2); cmp != 0 { 220 return cmp 221 } 222 } 223 return 0 224 } 225 } 226 227 //go:noinline 228 func compareRowsFuncOfIndexAscending(columnIndex int16, typ Type) func(Row, Row) int { 229 return func(row1, row2 Row) int { return typ.Compare(row1[columnIndex], row2[columnIndex]) } 230 } 231 232 //go:noinline 233 func compareRowsFuncOfIndexDescending(columnIndex int16, typ Type) func(Row, Row) int { 234 return func(row1, row2 Row) int { return -typ.Compare(row1[columnIndex], row2[columnIndex]) } 235 } 236 237 //go:noinline 238 func compareRowsFuncOfColumnIndexes(leafColumns []leafColumn, sortingColumns []SortingColumn) func(Row, Row) int { 239 compareFuncs := make([]func(Row, Row) int, len(sortingColumns)) 240 241 for sortingIndex, sortingColumn := range sortingColumns { 242 leaf := leafColumns[sortingIndex] 243 typ := leaf.node.Type() 244 245 if sortingColumn.Descending() { 246 compareFuncs[sortingIndex] = compareRowsFuncOfIndexDescending(leaf.columnIndex, typ) 247 } else { 248 compareFuncs[sortingIndex] = compareRowsFuncOfIndexAscending(leaf.columnIndex, typ) 249 } 250 } 251 252 switch len(compareFuncs) { 253 case 0: 254 return compareRowsUnordered 255 case 1: 256 return compareFuncs[0] 257 default: 258 return compareRowsFuncOfIndexColumns(compareFuncs) 259 } 260 } 261 262 var columnPool = &sync.Pool{New: func() any { return make([][2]int32, 0, 128) }} 263 264 //go:noinline 265 func compareRowsFuncOfColumnValues(leafColumns []leafColumn, sortingColumns []SortingColumn) func(Row, Row) int { 266 highestColumnIndex := int16(0) 267 columnIndexes := make([]int16, len(sortingColumns)) 268 compareFuncs := make([]func(Value, Value) int, len(sortingColumns)) 269 270 for sortingIndex, sortingColumn := range sortingColumns { 271 leaf := leafColumns[sortingIndex] 272 compare := leaf.node.Type().Compare 273 274 if sortingColumn.Descending() { 275 compare = CompareDescending(compare) 276 } 277 278 if leaf.maxDefinitionLevel > 0 { 279 if sortingColumn.NullsFirst() { 280 compare = CompareNullsFirst(compare) 281 } else { 282 compare = CompareNullsLast(compare) 283 } 284 } 285 286 columnIndexes[sortingIndex] = leaf.columnIndex 287 compareFuncs[sortingIndex] = compare 288 289 if leaf.columnIndex > highestColumnIndex { 290 highestColumnIndex = leaf.columnIndex 291 } 292 } 293 294 return func(row1, row2 Row) int { 295 columns1 := columnPool.Get().([][2]int32) 296 columns2 := columnPool.Get().([][2]int32) 297 defer func() { 298 columns1 = columns1[:0] 299 columns2 = columns2[:0] 300 columnPool.Put(columns1) 301 columnPool.Put(columns2) 302 }() 303 304 i1 := 0 305 i2 := 0 306 307 for columnIndex := int16(0); columnIndex <= highestColumnIndex; columnIndex++ { 308 j1 := i1 + 1 309 j2 := i2 + 1 310 311 for j1 < len(row1) && row1[j1].columnIndex == ^columnIndex { 312 j1++ 313 } 314 315 for j2 < len(row2) && row2[j2].columnIndex == ^columnIndex { 316 j2++ 317 } 318 319 columns1 = append(columns1, [2]int32{int32(i1), int32(j1)}) 320 columns2 = append(columns2, [2]int32{int32(i2), int32(j2)}) 321 i1 = j1 322 i2 = j2 323 } 324 325 for i, compare := range compareFuncs { 326 columnIndex := columnIndexes[i] 327 offsets1 := columns1[columnIndex] 328 offsets2 := columns2[columnIndex] 329 values1 := row1[offsets1[0]:offsets1[1]:offsets1[1]] 330 values2 := row2[offsets2[0]:offsets2[1]:offsets2[1]] 331 i1 := 0 332 i2 := 0 333 334 for i1 < len(values1) && i2 < len(values2) { 335 if cmp := compare(values1[i1], values2[i2]); cmp != 0 { 336 return cmp 337 } 338 i1++ 339 i2++ 340 } 341 342 if i1 < len(values1) { 343 return +1 344 } 345 if i2 < len(values2) { 346 return -1 347 } 348 } 349 return 0 350 } 351 }