github.com/segmentio/parquet-go@v0.0.0-20230712180008-5d42db8f0d47/compare.go (about)

     1  package parquet
     2  
     3  import (
     4  	"encoding/binary"
     5  
     6  	"github.com/segmentio/parquet-go/deprecated"
     7  )
     8  
     9  // CompareDescending constructs a comparison function which inverses the order
    10  // of values.
    11  //
    12  //go:noinline
    13  func CompareDescending(cmp func(Value, Value) int) func(Value, Value) int {
    14  	return func(a, b Value) int { return -cmp(a, b) }
    15  }
    16  
    17  // CompareNullsFirst constructs a comparison function which assumes that null
    18  // values are smaller than all other values.
    19  //
    20  //go:noinline
    21  func CompareNullsFirst(cmp func(Value, Value) int) func(Value, Value) int {
    22  	return func(a, b Value) int {
    23  		switch {
    24  		case a.IsNull():
    25  			if b.IsNull() {
    26  				return 0
    27  			}
    28  			return -1
    29  		case b.IsNull():
    30  			return +1
    31  		default:
    32  			return cmp(a, b)
    33  		}
    34  	}
    35  }
    36  
    37  // CompareNullsLast constructs a comparison function which assumes that null
    38  // values are greater than all other values.
    39  //
    40  //go:noinline
    41  func CompareNullsLast(cmp func(Value, Value) int) func(Value, Value) int {
    42  	return func(a, b Value) int {
    43  		switch {
    44  		case a.IsNull():
    45  			if b.IsNull() {
    46  				return 0
    47  			}
    48  			return +1
    49  		case b.IsNull():
    50  			return -1
    51  		default:
    52  			return cmp(a, b)
    53  		}
    54  	}
    55  }
    56  
    57  func compareBool(v1, v2 bool) int {
    58  	switch {
    59  	case !v1 && v2:
    60  		return -1
    61  	case v1 && !v2:
    62  		return +1
    63  	default:
    64  		return 0
    65  	}
    66  }
    67  
    68  func compareInt32(v1, v2 int32) int {
    69  	switch {
    70  	case v1 < v2:
    71  		return -1
    72  	case v1 > v2:
    73  		return +1
    74  	default:
    75  		return 0
    76  	}
    77  }
    78  
    79  func compareInt64(v1, v2 int64) int {
    80  	switch {
    81  	case v1 < v2:
    82  		return -1
    83  	case v1 > v2:
    84  		return +1
    85  	default:
    86  		return 0
    87  	}
    88  }
    89  
    90  func compareInt96(v1, v2 deprecated.Int96) int {
    91  	switch {
    92  	case v1.Less(v2):
    93  		return -1
    94  	case v2.Less(v1):
    95  		return +1
    96  	default:
    97  		return 0
    98  	}
    99  }
   100  
   101  func compareFloat32(v1, v2 float32) int {
   102  	switch {
   103  	case v1 < v2:
   104  		return -1
   105  	case v1 > v2:
   106  		return +1
   107  	default:
   108  		return 0
   109  	}
   110  }
   111  
   112  func compareFloat64(v1, v2 float64) int {
   113  	switch {
   114  	case v1 < v2:
   115  		return -1
   116  	case v1 > v2:
   117  		return +1
   118  	default:
   119  		return 0
   120  	}
   121  }
   122  
   123  func compareUint32(v1, v2 uint32) int {
   124  	switch {
   125  	case v1 < v2:
   126  		return -1
   127  	case v1 > v2:
   128  		return +1
   129  	default:
   130  		return 0
   131  	}
   132  }
   133  
   134  func compareUint64(v1, v2 uint64) int {
   135  	switch {
   136  	case v1 < v2:
   137  		return -1
   138  	case v1 > v2:
   139  		return +1
   140  	default:
   141  		return 0
   142  	}
   143  }
   144  
   145  func compareBE128(v1, v2 *[16]byte) int {
   146  	x := binary.BigEndian.Uint64(v1[:8])
   147  	y := binary.BigEndian.Uint64(v2[:8])
   148  	switch {
   149  	case x < y:
   150  		return -1
   151  	case x > y:
   152  		return +1
   153  	}
   154  	x = binary.BigEndian.Uint64(v1[8:])
   155  	y = binary.BigEndian.Uint64(v2[8:])
   156  	switch {
   157  	case x < y:
   158  		return -1
   159  	case x > y:
   160  		return +1
   161  	default:
   162  		return 0
   163  	}
   164  }
   165  
   166  func lessBE128(v1, v2 *[16]byte) bool {
   167  	x := binary.BigEndian.Uint64(v1[:8])
   168  	y := binary.BigEndian.Uint64(v2[:8])
   169  	switch {
   170  	case x < y:
   171  		return true
   172  	case x > y:
   173  		return false
   174  	}
   175  	x = binary.BigEndian.Uint64(v1[8:])
   176  	y = binary.BigEndian.Uint64(v2[8:])
   177  	return x < y
   178  }
   179  
   180  func compareRowsFuncOf(schema *Schema, sortingColumns []SortingColumn) func(Row, Row) int {
   181  	leafColumns := make([]leafColumn, len(sortingColumns))
   182  	canCompareRows := true
   183  
   184  	forEachLeafColumnOf(schema, func(leaf leafColumn) {
   185  		if leaf.maxRepetitionLevel > 0 {
   186  			canCompareRows = false
   187  		}
   188  
   189  		if sortingIndex := searchSortingColumn(sortingColumns, leaf.path); sortingIndex < len(sortingColumns) {
   190  			leafColumns[sortingIndex] = leaf
   191  
   192  			if leaf.maxDefinitionLevel > 0 {
   193  				canCompareRows = false
   194  			}
   195  		}
   196  	})
   197  
   198  	// This is an optimization for the common case where rows
   199  	// are sorted by non-optional, non-repeated columns.
   200  	//
   201  	// The sort function can make the assumption that it will
   202  	// find the column value at the current column index, and
   203  	// does not need to scan the rows looking for values with
   204  	// a matching column index.
   205  	if canCompareRows {
   206  		return compareRowsFuncOfColumnIndexes(leafColumns, sortingColumns)
   207  	}
   208  
   209  	return compareRowsFuncOfColumnValues(leafColumns, sortingColumns)
   210  }
   211  
   212  func compareRowsUnordered(Row, Row) int { return 0 }
   213  
   214  //go:noinline
   215  func compareRowsFuncOfIndexColumns(compareFuncs []func(Row, Row) int) func(Row, Row) int {
   216  	return func(row1, row2 Row) int {
   217  		for _, compare := range compareFuncs {
   218  			if cmp := compare(row1, row2); cmp != 0 {
   219  				return cmp
   220  			}
   221  		}
   222  		return 0
   223  	}
   224  }
   225  
   226  //go:noinline
   227  func compareRowsFuncOfIndexAscending(columnIndex int16, typ Type) func(Row, Row) int {
   228  	return func(row1, row2 Row) int { return typ.Compare(row1[columnIndex], row2[columnIndex]) }
   229  }
   230  
   231  //go:noinline
   232  func compareRowsFuncOfIndexDescending(columnIndex int16, typ Type) func(Row, Row) int {
   233  	return func(row1, row2 Row) int { return -typ.Compare(row1[columnIndex], row2[columnIndex]) }
   234  }
   235  
   236  //go:noinline
   237  func compareRowsFuncOfColumnIndexes(leafColumns []leafColumn, sortingColumns []SortingColumn) func(Row, Row) int {
   238  	compareFuncs := make([]func(Row, Row) int, len(sortingColumns))
   239  
   240  	for sortingIndex, sortingColumn := range sortingColumns {
   241  		leaf := leafColumns[sortingIndex]
   242  		typ := leaf.node.Type()
   243  
   244  		if sortingColumn.Descending() {
   245  			compareFuncs[sortingIndex] = compareRowsFuncOfIndexDescending(leaf.columnIndex, typ)
   246  		} else {
   247  			compareFuncs[sortingIndex] = compareRowsFuncOfIndexAscending(leaf.columnIndex, typ)
   248  		}
   249  	}
   250  
   251  	switch len(compareFuncs) {
   252  	case 0:
   253  		return compareRowsUnordered
   254  	case 1:
   255  		return compareFuncs[0]
   256  	default:
   257  		return compareRowsFuncOfIndexColumns(compareFuncs)
   258  	}
   259  }
   260  
   261  //go:noinline
   262  func compareRowsFuncOfColumnValues(leafColumns []leafColumn, sortingColumns []SortingColumn) func(Row, Row) int {
   263  	highestColumnIndex := int16(0)
   264  	columnIndexes := make([]int16, len(sortingColumns))
   265  	compareFuncs := make([]func(Value, Value) int, len(sortingColumns))
   266  
   267  	for sortingIndex, sortingColumn := range sortingColumns {
   268  		leaf := leafColumns[sortingIndex]
   269  		compare := leaf.node.Type().Compare
   270  
   271  		if sortingColumn.Descending() {
   272  			compare = CompareDescending(compare)
   273  		}
   274  
   275  		if leaf.maxDefinitionLevel > 0 {
   276  			if sortingColumn.NullsFirst() {
   277  				compare = CompareNullsFirst(compare)
   278  			} else {
   279  				compare = CompareNullsLast(compare)
   280  			}
   281  		}
   282  
   283  		columnIndexes[sortingIndex] = leaf.columnIndex
   284  		compareFuncs[sortingIndex] = compare
   285  
   286  		if leaf.columnIndex > highestColumnIndex {
   287  			highestColumnIndex = leaf.columnIndex
   288  		}
   289  	}
   290  
   291  	return func(row1, row2 Row) int {
   292  		columns1 := make([][2]int32, 0, 64)
   293  		columns2 := make([][2]int32, 0, 64)
   294  
   295  		i1 := 0
   296  		i2 := 0
   297  
   298  		for columnIndex := int16(0); columnIndex <= highestColumnIndex; columnIndex++ {
   299  			j1 := i1 + 1
   300  			j2 := i2 + 1
   301  
   302  			for j1 < len(row1) && row1[j1].columnIndex == ^columnIndex {
   303  				j1++
   304  			}
   305  
   306  			for j2 < len(row2) && row2[j2].columnIndex == ^columnIndex {
   307  				j2++
   308  			}
   309  
   310  			columns1 = append(columns1, [2]int32{int32(i1), int32(j1)})
   311  			columns2 = append(columns2, [2]int32{int32(i2), int32(j2)})
   312  			i1 = j1
   313  			i2 = j2
   314  		}
   315  
   316  		for i, compare := range compareFuncs {
   317  			columnIndex := columnIndexes[i]
   318  			offsets1 := columns1[columnIndex]
   319  			offsets2 := columns2[columnIndex]
   320  			values1 := row1[offsets1[0]:offsets1[1]:offsets1[1]]
   321  			values2 := row2[offsets2[0]:offsets2[1]:offsets2[1]]
   322  			i1 := 0
   323  			i2 := 0
   324  
   325  			for i1 < len(values1) && i2 < len(values2) {
   326  				if cmp := compare(values1[i1], values2[i2]); cmp != 0 {
   327  					return cmp
   328  				}
   329  				i1++
   330  				i2++
   331  			}
   332  
   333  			if i1 < len(values1) {
   334  				return +1
   335  			}
   336  			if i2 < len(values2) {
   337  				return -1
   338  			}
   339  		}
   340  		return 0
   341  	}
   342  }