github.com/parquet-go/parquet-go@v0.21.1-0.20240501160520-b3c3a0c3ed6f/compare.go (about)

     1  package parquet
     2  
     3  import (
     4  	"encoding/binary"
     5  	"sync"
     6  
     7  	"github.com/parquet-go/parquet-go/deprecated"
     8  )
     9  
    10  // CompareDescending constructs a comparison function which inverses the order
    11  // of values.
    12  //
    13  //go:noinline
    14  func CompareDescending(cmp func(Value, Value) int) func(Value, Value) int {
    15  	return func(a, b Value) int { return -cmp(a, b) }
    16  }
    17  
    18  // CompareNullsFirst constructs a comparison function which assumes that null
    19  // values are smaller than all other values.
    20  //
    21  //go:noinline
    22  func CompareNullsFirst(cmp func(Value, Value) int) func(Value, Value) int {
    23  	return func(a, b Value) int {
    24  		switch {
    25  		case a.IsNull():
    26  			if b.IsNull() {
    27  				return 0
    28  			}
    29  			return -1
    30  		case b.IsNull():
    31  			return +1
    32  		default:
    33  			return cmp(a, b)
    34  		}
    35  	}
    36  }
    37  
    38  // CompareNullsLast constructs a comparison function which assumes that null
    39  // values are greater than all other values.
    40  //
    41  //go:noinline
    42  func CompareNullsLast(cmp func(Value, Value) int) func(Value, Value) int {
    43  	return func(a, b Value) int {
    44  		switch {
    45  		case a.IsNull():
    46  			if b.IsNull() {
    47  				return 0
    48  			}
    49  			return +1
    50  		case b.IsNull():
    51  			return -1
    52  		default:
    53  			return cmp(a, b)
    54  		}
    55  	}
    56  }
    57  
    58  func compareBool(v1, v2 bool) int {
    59  	switch {
    60  	case !v1 && v2:
    61  		return -1
    62  	case v1 && !v2:
    63  		return +1
    64  	default:
    65  		return 0
    66  	}
    67  }
    68  
    69  func compareInt32(v1, v2 int32) int {
    70  	switch {
    71  	case v1 < v2:
    72  		return -1
    73  	case v1 > v2:
    74  		return +1
    75  	default:
    76  		return 0
    77  	}
    78  }
    79  
    80  func compareInt64(v1, v2 int64) int {
    81  	switch {
    82  	case v1 < v2:
    83  		return -1
    84  	case v1 > v2:
    85  		return +1
    86  	default:
    87  		return 0
    88  	}
    89  }
    90  
    91  func compareInt96(v1, v2 deprecated.Int96) int {
    92  	switch {
    93  	case v1.Less(v2):
    94  		return -1
    95  	case v2.Less(v1):
    96  		return +1
    97  	default:
    98  		return 0
    99  	}
   100  }
   101  
   102  func compareFloat32(v1, v2 float32) int {
   103  	switch {
   104  	case v1 < v2:
   105  		return -1
   106  	case v1 > v2:
   107  		return +1
   108  	default:
   109  		return 0
   110  	}
   111  }
   112  
   113  func compareFloat64(v1, v2 float64) int {
   114  	switch {
   115  	case v1 < v2:
   116  		return -1
   117  	case v1 > v2:
   118  		return +1
   119  	default:
   120  		return 0
   121  	}
   122  }
   123  
   124  func compareUint32(v1, v2 uint32) int {
   125  	switch {
   126  	case v1 < v2:
   127  		return -1
   128  	case v1 > v2:
   129  		return +1
   130  	default:
   131  		return 0
   132  	}
   133  }
   134  
   135  func compareUint64(v1, v2 uint64) int {
   136  	switch {
   137  	case v1 < v2:
   138  		return -1
   139  	case v1 > v2:
   140  		return +1
   141  	default:
   142  		return 0
   143  	}
   144  }
   145  
   146  func compareBE128(v1, v2 *[16]byte) int {
   147  	x := binary.BigEndian.Uint64(v1[:8])
   148  	y := binary.BigEndian.Uint64(v2[:8])
   149  	switch {
   150  	case x < y:
   151  		return -1
   152  	case x > y:
   153  		return +1
   154  	}
   155  	x = binary.BigEndian.Uint64(v1[8:])
   156  	y = binary.BigEndian.Uint64(v2[8:])
   157  	switch {
   158  	case x < y:
   159  		return -1
   160  	case x > y:
   161  		return +1
   162  	default:
   163  		return 0
   164  	}
   165  }
   166  
   167  func lessBE128(v1, v2 *[16]byte) bool {
   168  	x := binary.BigEndian.Uint64(v1[:8])
   169  	y := binary.BigEndian.Uint64(v2[:8])
   170  	switch {
   171  	case x < y:
   172  		return true
   173  	case x > y:
   174  		return false
   175  	}
   176  	x = binary.BigEndian.Uint64(v1[8:])
   177  	y = binary.BigEndian.Uint64(v2[8:])
   178  	return x < y
   179  }
   180  
   181  func compareRowsFuncOf(schema *Schema, sortingColumns []SortingColumn) func(Row, Row) int {
   182  	leafColumns := make([]leafColumn, len(sortingColumns))
   183  	canCompareRows := true
   184  
   185  	forEachLeafColumnOf(schema, func(leaf leafColumn) {
   186  		if leaf.maxRepetitionLevel > 0 {
   187  			canCompareRows = false
   188  		}
   189  
   190  		if sortingIndex := searchSortingColumn(sortingColumns, leaf.path); sortingIndex < len(sortingColumns) {
   191  			leafColumns[sortingIndex] = leaf
   192  
   193  			if leaf.maxDefinitionLevel > 0 {
   194  				canCompareRows = false
   195  			}
   196  		}
   197  	})
   198  
   199  	// This is an optimization for the common case where rows
   200  	// are sorted by non-optional, non-repeated columns.
   201  	//
   202  	// The sort function can make the assumption that it will
   203  	// find the column value at the current column index, and
   204  	// does not need to scan the rows looking for values with
   205  	// a matching column index.
   206  	if canCompareRows {
   207  		return compareRowsFuncOfColumnIndexes(leafColumns, sortingColumns)
   208  	}
   209  
   210  	return compareRowsFuncOfColumnValues(leafColumns, sortingColumns)
   211  }
   212  
   213  func compareRowsUnordered(Row, Row) int { return 0 }
   214  
   215  //go:noinline
   216  func compareRowsFuncOfIndexColumns(compareFuncs []func(Row, Row) int) func(Row, Row) int {
   217  	return func(row1, row2 Row) int {
   218  		for _, compare := range compareFuncs {
   219  			if cmp := compare(row1, row2); cmp != 0 {
   220  				return cmp
   221  			}
   222  		}
   223  		return 0
   224  	}
   225  }
   226  
   227  //go:noinline
   228  func compareRowsFuncOfIndexAscending(columnIndex int16, typ Type) func(Row, Row) int {
   229  	return func(row1, row2 Row) int { return typ.Compare(row1[columnIndex], row2[columnIndex]) }
   230  }
   231  
   232  //go:noinline
   233  func compareRowsFuncOfIndexDescending(columnIndex int16, typ Type) func(Row, Row) int {
   234  	return func(row1, row2 Row) int { return -typ.Compare(row1[columnIndex], row2[columnIndex]) }
   235  }
   236  
   237  //go:noinline
   238  func compareRowsFuncOfColumnIndexes(leafColumns []leafColumn, sortingColumns []SortingColumn) func(Row, Row) int {
   239  	compareFuncs := make([]func(Row, Row) int, len(sortingColumns))
   240  
   241  	for sortingIndex, sortingColumn := range sortingColumns {
   242  		leaf := leafColumns[sortingIndex]
   243  		typ := leaf.node.Type()
   244  
   245  		if sortingColumn.Descending() {
   246  			compareFuncs[sortingIndex] = compareRowsFuncOfIndexDescending(leaf.columnIndex, typ)
   247  		} else {
   248  			compareFuncs[sortingIndex] = compareRowsFuncOfIndexAscending(leaf.columnIndex, typ)
   249  		}
   250  	}
   251  
   252  	switch len(compareFuncs) {
   253  	case 0:
   254  		return compareRowsUnordered
   255  	case 1:
   256  		return compareFuncs[0]
   257  	default:
   258  		return compareRowsFuncOfIndexColumns(compareFuncs)
   259  	}
   260  }
   261  
   262  var columnPool = &sync.Pool{New: func() any { return make([][2]int32, 0, 128) }}
   263  
   264  //go:noinline
   265  func compareRowsFuncOfColumnValues(leafColumns []leafColumn, sortingColumns []SortingColumn) func(Row, Row) int {
   266  	highestColumnIndex := int16(0)
   267  	columnIndexes := make([]int16, len(sortingColumns))
   268  	compareFuncs := make([]func(Value, Value) int, len(sortingColumns))
   269  
   270  	for sortingIndex, sortingColumn := range sortingColumns {
   271  		leaf := leafColumns[sortingIndex]
   272  		compare := leaf.node.Type().Compare
   273  
   274  		if sortingColumn.Descending() {
   275  			compare = CompareDescending(compare)
   276  		}
   277  
   278  		if leaf.maxDefinitionLevel > 0 {
   279  			if sortingColumn.NullsFirst() {
   280  				compare = CompareNullsFirst(compare)
   281  			} else {
   282  				compare = CompareNullsLast(compare)
   283  			}
   284  		}
   285  
   286  		columnIndexes[sortingIndex] = leaf.columnIndex
   287  		compareFuncs[sortingIndex] = compare
   288  
   289  		if leaf.columnIndex > highestColumnIndex {
   290  			highestColumnIndex = leaf.columnIndex
   291  		}
   292  	}
   293  
   294  	return func(row1, row2 Row) int {
   295  		columns1 := columnPool.Get().([][2]int32)
   296  		columns2 := columnPool.Get().([][2]int32)
   297  		defer func() {
   298  			columns1 = columns1[:0]
   299  			columns2 = columns2[:0]
   300  			columnPool.Put(columns1)
   301  			columnPool.Put(columns2)
   302  		}()
   303  
   304  		i1 := 0
   305  		i2 := 0
   306  
   307  		for columnIndex := int16(0); columnIndex <= highestColumnIndex; columnIndex++ {
   308  			j1 := i1 + 1
   309  			j2 := i2 + 1
   310  
   311  			for j1 < len(row1) && row1[j1].columnIndex == ^columnIndex {
   312  				j1++
   313  			}
   314  
   315  			for j2 < len(row2) && row2[j2].columnIndex == ^columnIndex {
   316  				j2++
   317  			}
   318  
   319  			columns1 = append(columns1, [2]int32{int32(i1), int32(j1)})
   320  			columns2 = append(columns2, [2]int32{int32(i2), int32(j2)})
   321  			i1 = j1
   322  			i2 = j2
   323  		}
   324  
   325  		for i, compare := range compareFuncs {
   326  			columnIndex := columnIndexes[i]
   327  			offsets1 := columns1[columnIndex]
   328  			offsets2 := columns2[columnIndex]
   329  			values1 := row1[offsets1[0]:offsets1[1]:offsets1[1]]
   330  			values2 := row2[offsets2[0]:offsets2[1]:offsets2[1]]
   331  			i1 := 0
   332  			i2 := 0
   333  
   334  			for i1 < len(values1) && i2 < len(values2) {
   335  				if cmp := compare(values1[i1], values2[i2]); cmp != 0 {
   336  					return cmp
   337  				}
   338  				i1++
   339  				i2++
   340  			}
   341  
   342  			if i1 < len(values1) {
   343  				return +1
   344  			}
   345  			if i2 < len(values2) {
   346  				return -1
   347  			}
   348  		}
   349  		return 0
   350  	}
   351  }