github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/testutils/sort.go (about)

     1  // Copyright 2016 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package testutils
    12  
    13  import (
    14  	"fmt"
    15  	"reflect"
    16  	"sort"
    17  )
    18  
    19  var _ sort.Interface = structSorter{}
    20  
    21  // structSorter implements sort.Interface for a slice of structs, making heavy use of
    22  // reflection.
    23  type structSorter struct {
    24  	v          reflect.Value
    25  	fieldNames []string
    26  }
    27  
    28  // Len returns the length of the underlying slice.
    29  func (ss structSorter) Len() int {
    30  	return ss.v.Len()
    31  }
    32  
    33  // Less returns true iff if the sort fields at index i are less than the sort
    34  // fields at index j.
    35  func (ss structSorter) Less(i, j int) bool {
    36  	v1 := reflect.Indirect(ss.v.Index(i))
    37  	v2 := reflect.Indirect(ss.v.Index(j))
    38  	return ss.fieldIsLess(v1, v2, 0)
    39  }
    40  
    41  func (ss structSorter) fieldIsLess(v1, v2 reflect.Value, fieldNum int) bool {
    42  	fieldName := ss.fieldNames[fieldNum]
    43  	lastField := len(ss.fieldNames) == fieldNum+1
    44  
    45  	// Grab the appropriate field from both structs.
    46  	//
    47  	// TODO(cdo): This can be optimized by moving this next block of tests into
    48  	// SortStructs, caching the index of the field, and using the more efficient
    49  	// reflect.Value.FieldByIndex().
    50  	f1 := v1.FieldByName(fieldName)
    51  	if !f1.IsValid() {
    52  		panic(fmt.Sprintf("couldn't get field %s", fieldName))
    53  	}
    54  	f2 := v2.FieldByName(fieldName)
    55  	if !f2.IsValid() {
    56  		panic(fmt.Sprintf("couldn't get field %s", fieldName))
    57  	}
    58  
    59  	// Do the appropriate < comparison based on the type of the fields.
    60  	switch f1.Kind() {
    61  	case reflect.String:
    62  		if !lastField && f1.String() == f2.String() {
    63  			return ss.fieldIsLess(v1, v2, fieldNum+1)
    64  		}
    65  		return f1.String() < f2.String()
    66  
    67  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
    68  		if !lastField && f1.Int() == f2.Int() {
    69  			return ss.fieldIsLess(v1, v2, fieldNum+1)
    70  		}
    71  		return f1.Int() < f2.Int()
    72  
    73  	case reflect.Uint, reflect.Uintptr, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
    74  		if !lastField && f1.Uint() == f2.Uint() {
    75  			return ss.fieldIsLess(v1, v2, fieldNum+1)
    76  		}
    77  		return f1.Uint() < f2.Uint()
    78  
    79  	case reflect.Float32, reflect.Float64:
    80  		if !lastField && f1.Float() == f2.Float() {
    81  			return ss.fieldIsLess(v1, v2, fieldNum+1)
    82  		}
    83  		return f1.Float() < f2.Float()
    84  
    85  	case reflect.Bool:
    86  		if !lastField && f1.Bool() == f2.Bool() {
    87  			return ss.fieldIsLess(v1, v2, fieldNum+1)
    88  		}
    89  		return !f1.Bool() && f2.Bool()
    90  	}
    91  
    92  	panic(fmt.Sprintf("can't handle sort key type %d", uint(f1.Kind())))
    93  }
    94  
    95  // Swap swaps the elements at the provided indices.
    96  func (ss structSorter) Swap(i, j int) {
    97  	// Store the temp value in a new reflect.Value. Then, do a standard swap of the two slice
    98  	// elements.
    99  	t := reflect.ValueOf(ss.v.Index(i).Interface())
   100  	ss.v.Index(i).Set(ss.v.Index(j))
   101  	ss.v.Index(j).Set(t)
   102  }
   103  
   104  // SortStructs sorts the given slice of structs using the given fields as the ordered sort keys.
   105  func SortStructs(s interface{}, fieldNames ...string) {
   106  	// Verify that we've gotten a slice of structs or pointers to structs.
   107  	structs := reflect.ValueOf(s)
   108  	if structs.Kind() != reflect.Slice {
   109  		panic(fmt.Sprintf("expected slice, got %T", s))
   110  	}
   111  	elemType := structs.Type().Elem()
   112  	if elemType.Kind() == reflect.Ptr {
   113  		elemType = elemType.Elem()
   114  	}
   115  	if elemType.Kind() != reflect.Struct {
   116  		panic(fmt.Sprintf("%s is not a struct or pointer to struct", structs.Elem()))
   117  	}
   118  
   119  	sort.Sort(structSorter{structs, fieldNames})
   120  }