github.com/linchen2chris/hugo@v0.0.0-20230307053224-cec209389705/tpl/compare/compare.go (about)

     1  // Copyright 2017 The Hugo Authors. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  // http://www.apache.org/licenses/LICENSE-2.0
     7  //
     8  // Unless required by applicable law or agreed to in writing, software
     9  // distributed under the License is distributed on an "AS IS" BASIS,
    10  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  // Package compare provides template functions for comparing values.
    15  package compare
    16  
    17  import (
    18  	"fmt"
    19  	"math"
    20  	"reflect"
    21  	"strconv"
    22  	"time"
    23  
    24  	"github.com/gohugoio/hugo/compare"
    25  	"github.com/gohugoio/hugo/langs"
    26  
    27  	"github.com/gohugoio/hugo/common/hreflect"
    28  	"github.com/gohugoio/hugo/common/htime"
    29  	"github.com/gohugoio/hugo/common/types"
    30  )
    31  
    32  // New returns a new instance of the compare-namespaced template functions.
    33  func New(loc *time.Location, caseInsensitive bool) *Namespace {
    34  	return &Namespace{loc: loc, caseInsensitive: caseInsensitive}
    35  }
    36  
    37  // Namespace provides template functions for the "compare" namespace.
    38  type Namespace struct {
    39  	loc *time.Location
    40  	// Enable to do case insensitive string compares.
    41  	caseInsensitive bool
    42  }
    43  
    44  // Default checks whether a givenv is set and returns the default value defaultv if it
    45  // is not.  "Set" in this context means non-zero for numeric types and times;
    46  // non-zero length for strings, arrays, slices, and maps;
    47  // any boolean or struct value; or non-nil for any other types.
    48  func (*Namespace) Default(defaultv any, givenv ...any) (any, error) {
    49  	// given is variadic because the following construct will not pass a piped
    50  	// argument when the key is missing:  {{ index . "key" | default "foo" }}
    51  	// The Go template will complain that we got 1 argument when we expected 2.
    52  
    53  	if len(givenv) == 0 {
    54  		return defaultv, nil
    55  	}
    56  	if len(givenv) != 1 {
    57  		return nil, fmt.Errorf("wrong number of args for default: want 2 got %d", len(givenv)+1)
    58  	}
    59  
    60  	g := reflect.ValueOf(givenv[0])
    61  	if !g.IsValid() {
    62  		return defaultv, nil
    63  	}
    64  
    65  	set := false
    66  
    67  	switch g.Kind() {
    68  	case reflect.Bool:
    69  		set = true
    70  	case reflect.String, reflect.Array, reflect.Slice, reflect.Map:
    71  		set = g.Len() != 0
    72  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
    73  		set = g.Int() != 0
    74  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
    75  		set = g.Uint() != 0
    76  	case reflect.Float32, reflect.Float64:
    77  		set = g.Float() != 0
    78  	case reflect.Complex64, reflect.Complex128:
    79  		set = g.Complex() != 0
    80  	case reflect.Struct:
    81  		switch actual := givenv[0].(type) {
    82  		case time.Time:
    83  			set = !actual.IsZero()
    84  		default:
    85  			set = true
    86  		}
    87  	default:
    88  		set = !g.IsNil()
    89  	}
    90  
    91  	if set {
    92  		return givenv[0], nil
    93  	}
    94  
    95  	return defaultv, nil
    96  }
    97  
    98  // Eq returns the boolean truth of arg1 == arg2 || arg1 == arg3 || arg1 == arg4.
    99  func (n *Namespace) Eq(first any, others ...any) bool {
   100  	if n.caseInsensitive {
   101  		panic("caseInsensitive not implemented for Eq")
   102  	}
   103  	n.checkComparisonArgCount(1, others...)
   104  	normalize := func(v any) any {
   105  		if types.IsNil(v) {
   106  			return nil
   107  		}
   108  
   109  		if at, ok := v.(htime.AsTimeProvider); ok {
   110  			return at.AsTime(n.loc)
   111  		}
   112  
   113  		vv := reflect.ValueOf(v)
   114  		switch vv.Kind() {
   115  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   116  			return vv.Int()
   117  		case reflect.Float32, reflect.Float64:
   118  			return vv.Float()
   119  		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   120  			return vv.Uint()
   121  		case reflect.String:
   122  			return vv.String()
   123  		default:
   124  			return v
   125  		}
   126  	}
   127  
   128  	normFirst := normalize(first)
   129  	for _, other := range others {
   130  		if e, ok := first.(compare.Eqer); ok {
   131  			if e.Eq(other) {
   132  				return true
   133  			}
   134  			continue
   135  		}
   136  
   137  		if e, ok := other.(compare.Eqer); ok {
   138  			if e.Eq(first) {
   139  				return true
   140  			}
   141  			continue
   142  		}
   143  
   144  		other = normalize(other)
   145  		if reflect.DeepEqual(normFirst, other) {
   146  			return true
   147  		}
   148  	}
   149  
   150  	return false
   151  }
   152  
   153  // Ne returns the boolean truth of arg1 != arg2 && arg1 != arg3 && arg1 != arg4.
   154  func (n *Namespace) Ne(first any, others ...any) bool {
   155  	n.checkComparisonArgCount(1, others...)
   156  	for _, other := range others {
   157  		if n.Eq(first, other) {
   158  			return false
   159  		}
   160  	}
   161  	return true
   162  }
   163  
   164  // Ge returns the boolean truth of arg1 >= arg2 && arg1 >= arg3 && arg1 >= arg4.
   165  func (n *Namespace) Ge(first any, others ...any) bool {
   166  	n.checkComparisonArgCount(1, others...)
   167  	for _, other := range others {
   168  		left, right := n.compareGet(first, other)
   169  		if !(left >= right) {
   170  			return false
   171  		}
   172  	}
   173  	return true
   174  }
   175  
   176  // Gt returns the boolean truth of arg1 > arg2 && arg1 > arg3 && arg1 > arg4.
   177  func (n *Namespace) Gt(first any, others ...any) bool {
   178  	n.checkComparisonArgCount(1, others...)
   179  	for _, other := range others {
   180  		left, right := n.compareGet(first, other)
   181  		if !(left > right) {
   182  			return false
   183  		}
   184  	}
   185  	return true
   186  }
   187  
   188  // Le returns the boolean truth of arg1 <= arg2 && arg1 <= arg3 && arg1 <= arg4.
   189  func (n *Namespace) Le(first any, others ...any) bool {
   190  	n.checkComparisonArgCount(1, others...)
   191  	for _, other := range others {
   192  		left, right := n.compareGet(first, other)
   193  		if !(left <= right) {
   194  			return false
   195  		}
   196  	}
   197  	return true
   198  }
   199  
   200  // Lt returns the boolean truth of arg1 < arg2 && arg1 < arg3 && arg1 < arg4.
   201  // The provided collator will be used for string comparisons.
   202  // This is for internal use.
   203  func (n *Namespace) LtCollate(collator *langs.Collator, first any, others ...any) bool {
   204  	n.checkComparisonArgCount(1, others...)
   205  	for _, other := range others {
   206  		left, right := n.compareGetWithCollator(collator, first, other)
   207  		if !(left < right) {
   208  			return false
   209  		}
   210  	}
   211  	return true
   212  }
   213  
   214  // Lt returns the boolean truth of arg1 < arg2 && arg1 < arg3 && arg1 < arg4.
   215  func (n *Namespace) Lt(first any, others ...any) bool {
   216  	return n.LtCollate(nil, first, others...)
   217  }
   218  
   219  func (n *Namespace) checkComparisonArgCount(min int, others ...any) bool {
   220  	if len(others) < min {
   221  		panic("missing arguments for comparison")
   222  	}
   223  	return true
   224  }
   225  
   226  // Conditional can be used as a ternary operator.
   227  //
   228  // It returns v1 if cond is true, else v2.
   229  func (n *Namespace) Conditional(cond bool, v1, v2 any) any {
   230  	if cond {
   231  		return v1
   232  	}
   233  	return v2
   234  }
   235  
   236  func (ns *Namespace) compareGet(a any, b any) (float64, float64) {
   237  	return ns.compareGetWithCollator(nil, a, b)
   238  }
   239  
   240  func (ns *Namespace) compareGetWithCollator(collator *langs.Collator, a any, b any) (float64, float64) {
   241  	if ac, ok := a.(compare.Comparer); ok {
   242  		c := ac.Compare(b)
   243  		if c < 0 {
   244  			return 1, 0
   245  		} else if c == 0 {
   246  			return 0, 0
   247  		} else {
   248  			return 0, 1
   249  		}
   250  	}
   251  
   252  	if bc, ok := b.(compare.Comparer); ok {
   253  		c := bc.Compare(a)
   254  		if c < 0 {
   255  			return 0, 1
   256  		} else if c == 0 {
   257  			return 0, 0
   258  		} else {
   259  			return 1, 0
   260  		}
   261  	}
   262  
   263  	var left, right float64
   264  	var leftStr, rightStr *string
   265  	av := reflect.ValueOf(a)
   266  
   267  	switch av.Kind() {
   268  	case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice:
   269  		left = float64(av.Len())
   270  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   271  		left = float64(av.Int())
   272  	case reflect.Float32, reflect.Float64:
   273  		left = av.Float()
   274  	case reflect.String:
   275  		var err error
   276  		left, err = strconv.ParseFloat(av.String(), 64)
   277  		// Check if float is a special floating value and cast value as string.
   278  		if math.IsInf(left, 0) || math.IsNaN(left) || err != nil {
   279  			str := av.String()
   280  			leftStr = &str
   281  		}
   282  	case reflect.Struct:
   283  		if hreflect.IsTime(av.Type()) {
   284  			left = float64(ns.toTimeUnix(av))
   285  		}
   286  	case reflect.Bool:
   287  		left = 0
   288  		if av.Bool() {
   289  			left = 1
   290  		}
   291  	}
   292  
   293  	bv := reflect.ValueOf(b)
   294  
   295  	switch bv.Kind() {
   296  	case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice:
   297  		right = float64(bv.Len())
   298  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   299  		right = float64(bv.Int())
   300  	case reflect.Float32, reflect.Float64:
   301  		right = bv.Float()
   302  	case reflect.String:
   303  		var err error
   304  		right, err = strconv.ParseFloat(bv.String(), 64)
   305  		// Check if float is a special floating value and cast value as string.
   306  		if math.IsInf(right, 0) || math.IsNaN(right) || err != nil {
   307  			str := bv.String()
   308  			rightStr = &str
   309  		}
   310  	case reflect.Struct:
   311  		if hreflect.IsTime(bv.Type()) {
   312  			right = float64(ns.toTimeUnix(bv))
   313  		}
   314  	case reflect.Bool:
   315  		right = 0
   316  		if bv.Bool() {
   317  			right = 1
   318  		}
   319  	}
   320  
   321  	if (ns.caseInsensitive || collator != nil) && leftStr != nil && rightStr != nil {
   322  		var c int
   323  		if collator != nil {
   324  			c = collator.CompareStrings(*leftStr, *rightStr)
   325  		} else {
   326  			c = compare.Strings(*leftStr, *rightStr)
   327  		}
   328  		if c < 0 {
   329  			return 0, 1
   330  		} else if c > 0 {
   331  			return 1, 0
   332  		} else {
   333  			return 0, 0
   334  		}
   335  	}
   336  
   337  	switch {
   338  	case leftStr == nil || rightStr == nil:
   339  	case *leftStr < *rightStr:
   340  		return 0, 1
   341  	case *leftStr > *rightStr:
   342  		return 1, 0
   343  	default:
   344  		return 0, 0
   345  	}
   346  
   347  	return left, right
   348  }
   349  
   350  func (ns *Namespace) toTimeUnix(v reflect.Value) int64 {
   351  	t, ok := hreflect.AsTime(v, ns.loc)
   352  	if !ok {
   353  		panic("coding error: argument must be time.Time type reflect Value")
   354  	}
   355  	return t.Unix()
   356  }