github.com/anakojm/hugo-katex@v0.0.0-20231023141351-42d6f5de9c0b/tpl/collections/where.go (about)

     1  // Copyright 2017 The Hugo Authors. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  // http://www.apache.org/licenses/LICENSE-2.0
     7  //
     8  // Unless required by applicable law or agreed to in writing, software
     9  // distributed under the License is distributed on an "AS IS" BASIS,
    10  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package collections
    15  
    16  import (
    17  	"context"
    18  	"errors"
    19  	"fmt"
    20  	"reflect"
    21  	"strings"
    22  
    23  	"github.com/gohugoio/hugo/common/hreflect"
    24  	"github.com/gohugoio/hugo/common/hstrings"
    25  	"github.com/gohugoio/hugo/common/maps"
    26  )
    27  
    28  // Where returns a filtered subset of collection c.
    29  func (ns *Namespace) Where(ctx context.Context, c, key any, args ...any) (any, error) {
    30  	seqv, isNil := indirect(reflect.ValueOf(c))
    31  	if isNil {
    32  		return nil, errors.New("can't iterate over a nil value of type " + reflect.ValueOf(c).Type().String())
    33  	}
    34  
    35  	mv, op, err := parseWhereArgs(args...)
    36  	if err != nil {
    37  		return nil, err
    38  	}
    39  
    40  	ctxv := reflect.ValueOf(ctx)
    41  
    42  	var path []string
    43  	kv := reflect.ValueOf(key)
    44  	if kv.Kind() == reflect.String {
    45  		path = strings.Split(strings.Trim(kv.String(), "."), ".")
    46  	}
    47  
    48  	switch seqv.Kind() {
    49  	case reflect.Array, reflect.Slice:
    50  		return ns.checkWhereArray(ctxv, seqv, kv, mv, path, op)
    51  	case reflect.Map:
    52  		return ns.checkWhereMap(ctxv, seqv, kv, mv, path, op)
    53  	default:
    54  		return nil, fmt.Errorf("can't iterate over %v", c)
    55  	}
    56  }
    57  
    58  func (ns *Namespace) checkCondition(v, mv reflect.Value, op string) (bool, error) {
    59  	v, vIsNil := indirect(v)
    60  	if !v.IsValid() {
    61  		vIsNil = true
    62  	}
    63  
    64  	mv, mvIsNil := indirect(mv)
    65  	if !mv.IsValid() {
    66  		mvIsNil = true
    67  	}
    68  	if vIsNil || mvIsNil {
    69  		switch op {
    70  		case "", "=", "==", "eq":
    71  			return vIsNil == mvIsNil, nil
    72  		case "!=", "<>", "ne":
    73  			return vIsNil != mvIsNil, nil
    74  		}
    75  		return false, nil
    76  	}
    77  
    78  	if v.Kind() == reflect.Bool && mv.Kind() == reflect.Bool {
    79  		switch op {
    80  		case "", "=", "==", "eq":
    81  			return v.Bool() == mv.Bool(), nil
    82  		case "!=", "<>", "ne":
    83  			return v.Bool() != mv.Bool(), nil
    84  		}
    85  		return false, nil
    86  	}
    87  
    88  	var ivp, imvp *int64
    89  	var fvp, fmvp *float64
    90  	var svp, smvp *string
    91  	var slv, slmv any
    92  	var ima []int64
    93  	var fma []float64
    94  	var sma []string
    95  
    96  	if mv.Kind() == v.Kind() {
    97  		switch v.Kind() {
    98  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
    99  			iv := v.Int()
   100  			ivp = &iv
   101  			imv := mv.Int()
   102  			imvp = &imv
   103  		case reflect.String:
   104  			sv := v.String()
   105  			svp = &sv
   106  			smv := mv.String()
   107  			smvp = &smv
   108  		case reflect.Float64:
   109  			fv := v.Float()
   110  			fvp = &fv
   111  			fmv := mv.Float()
   112  			fmvp = &fmv
   113  		case reflect.Struct:
   114  			if hreflect.IsTime(v.Type()) {
   115  				iv := ns.toTimeUnix(v)
   116  				ivp = &iv
   117  				imv := ns.toTimeUnix(mv)
   118  				imvp = &imv
   119  			}
   120  		case reflect.Array, reflect.Slice:
   121  			slv = v.Interface()
   122  			slmv = mv.Interface()
   123  		}
   124  	} else if isNumber(v.Kind()) && isNumber(mv.Kind()) {
   125  		fv, err := toFloat(v)
   126  		if err != nil {
   127  			return false, err
   128  		}
   129  		fvp = &fv
   130  		fmv, err := toFloat(mv)
   131  		if err != nil {
   132  			return false, err
   133  		}
   134  		fmvp = &fmv
   135  	} else {
   136  		if mv.Kind() != reflect.Array && mv.Kind() != reflect.Slice {
   137  			return false, nil
   138  		}
   139  
   140  		if mv.Len() == 0 {
   141  			return false, nil
   142  		}
   143  
   144  		if v.Kind() != reflect.Interface && mv.Type().Elem().Kind() != reflect.Interface && mv.Type().Elem() != v.Type() && v.Kind() != reflect.Array && v.Kind() != reflect.Slice {
   145  			return false, nil
   146  		}
   147  		switch v.Kind() {
   148  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   149  			iv := v.Int()
   150  			ivp = &iv
   151  			for i := 0; i < mv.Len(); i++ {
   152  				if anInt, err := toInt(mv.Index(i)); err == nil {
   153  					ima = append(ima, anInt)
   154  				}
   155  			}
   156  		case reflect.String:
   157  			sv := v.String()
   158  			svp = &sv
   159  			for i := 0; i < mv.Len(); i++ {
   160  				if aString, err := toString(mv.Index(i)); err == nil {
   161  					sma = append(sma, aString)
   162  				}
   163  			}
   164  		case reflect.Float64:
   165  			fv := v.Float()
   166  			fvp = &fv
   167  			for i := 0; i < mv.Len(); i++ {
   168  				if aFloat, err := toFloat(mv.Index(i)); err == nil {
   169  					fma = append(fma, aFloat)
   170  				}
   171  			}
   172  		case reflect.Struct:
   173  			if hreflect.IsTime(v.Type()) {
   174  				iv := ns.toTimeUnix(v)
   175  				ivp = &iv
   176  				for i := 0; i < mv.Len(); i++ {
   177  					ima = append(ima, ns.toTimeUnix(mv.Index(i)))
   178  				}
   179  			}
   180  		case reflect.Array, reflect.Slice:
   181  			slv = v.Interface()
   182  			slmv = mv.Interface()
   183  		}
   184  	}
   185  
   186  	switch op {
   187  	case "", "=", "==", "eq":
   188  		switch {
   189  		case ivp != nil && imvp != nil:
   190  			return *ivp == *imvp, nil
   191  		case svp != nil && smvp != nil:
   192  			return *svp == *smvp, nil
   193  		case fvp != nil && fmvp != nil:
   194  			return *fvp == *fmvp, nil
   195  		}
   196  	case "!=", "<>", "ne":
   197  		switch {
   198  		case ivp != nil && imvp != nil:
   199  			return *ivp != *imvp, nil
   200  		case svp != nil && smvp != nil:
   201  			return *svp != *smvp, nil
   202  		case fvp != nil && fmvp != nil:
   203  			return *fvp != *fmvp, nil
   204  		}
   205  	case ">=", "ge":
   206  		switch {
   207  		case ivp != nil && imvp != nil:
   208  			return *ivp >= *imvp, nil
   209  		case svp != nil && smvp != nil:
   210  			return *svp >= *smvp, nil
   211  		case fvp != nil && fmvp != nil:
   212  			return *fvp >= *fmvp, nil
   213  		}
   214  	case ">", "gt":
   215  		switch {
   216  		case ivp != nil && imvp != nil:
   217  			return *ivp > *imvp, nil
   218  		case svp != nil && smvp != nil:
   219  			return *svp > *smvp, nil
   220  		case fvp != nil && fmvp != nil:
   221  			return *fvp > *fmvp, nil
   222  		}
   223  	case "<=", "le":
   224  		switch {
   225  		case ivp != nil && imvp != nil:
   226  			return *ivp <= *imvp, nil
   227  		case svp != nil && smvp != nil:
   228  			return *svp <= *smvp, nil
   229  		case fvp != nil && fmvp != nil:
   230  			return *fvp <= *fmvp, nil
   231  		}
   232  	case "<", "lt":
   233  		switch {
   234  		case ivp != nil && imvp != nil:
   235  			return *ivp < *imvp, nil
   236  		case svp != nil && smvp != nil:
   237  			return *svp < *smvp, nil
   238  		case fvp != nil && fmvp != nil:
   239  			return *fvp < *fmvp, nil
   240  		}
   241  	case "in", "not in":
   242  		var r bool
   243  		switch {
   244  		case ivp != nil && len(ima) > 0:
   245  			r, _ = ns.In(ima, *ivp)
   246  		case fvp != nil && len(fma) > 0:
   247  			r, _ = ns.In(fma, *fvp)
   248  		case svp != nil:
   249  			if len(sma) > 0 {
   250  				r, _ = ns.In(sma, *svp)
   251  			} else if smvp != nil {
   252  				r, _ = ns.In(*smvp, *svp)
   253  			}
   254  		default:
   255  			return false, nil
   256  		}
   257  		if op == "not in" {
   258  			return !r, nil
   259  		}
   260  		return r, nil
   261  	case "intersect":
   262  		r, err := ns.Intersect(slv, slmv)
   263  		if err != nil {
   264  			return false, err
   265  		}
   266  
   267  		if reflect.TypeOf(r).Kind() == reflect.Slice {
   268  			s := reflect.ValueOf(r)
   269  
   270  			if s.Len() > 0 {
   271  				return true, nil
   272  			}
   273  			return false, nil
   274  		}
   275  		return false, errors.New("invalid intersect values")
   276  	case "like":
   277  		if svp != nil && smvp != nil {
   278  			re, err := hstrings.GetOrCompileRegexp(*smvp)
   279  			if err != nil {
   280  				return false, err
   281  			}
   282  			if re.MatchString(*svp) {
   283  				return true, nil
   284  			}
   285  			return false, nil
   286  		}
   287  	default:
   288  		return false, errors.New("no such operator")
   289  	}
   290  	return false, nil
   291  }
   292  
   293  func evaluateSubElem(ctx, obj reflect.Value, elemName string) (reflect.Value, error) {
   294  	if !obj.IsValid() {
   295  		return zero, errors.New("can't evaluate an invalid value")
   296  	}
   297  
   298  	typ := obj.Type()
   299  	obj, isNil := indirect(obj)
   300  
   301  	if obj.Kind() == reflect.Interface {
   302  		// If obj is an interface, we need to inspect the value it contains
   303  		// to see the full set of methods and fields.
   304  		// Indirect returns the value that it points to, which is what's needed
   305  		// below to be able to reflect on its fields.
   306  		obj = reflect.Indirect(obj.Elem())
   307  	}
   308  
   309  	// first, check whether obj has a method. In this case, obj is
   310  	// a struct or its pointer. If obj is a struct,
   311  	// to check all T and *T method, use obj pointer type Value
   312  	objPtr := obj
   313  	if objPtr.Kind() != reflect.Interface && objPtr.CanAddr() {
   314  		objPtr = objPtr.Addr()
   315  	}
   316  
   317  	index := hreflect.GetMethodIndexByName(objPtr.Type(), elemName)
   318  	if index != -1 {
   319  		var args []reflect.Value
   320  		mt := objPtr.Type().Method(index)
   321  		num := mt.Type.NumIn()
   322  		maxNumIn := 1
   323  		if num > 1 && mt.Type.In(1).Implements(hreflect.ContextInterface) {
   324  			args = []reflect.Value{ctx}
   325  			maxNumIn = 2
   326  		}
   327  
   328  		switch {
   329  		case mt.PkgPath != "":
   330  			return zero, fmt.Errorf("%s is an unexported method of type %s", elemName, typ)
   331  		case mt.Type.NumIn() > maxNumIn:
   332  			return zero, fmt.Errorf("%s is a method of type %s but requires more than %d parameter", elemName, typ, maxNumIn)
   333  		case mt.Type.NumOut() == 0:
   334  			return zero, fmt.Errorf("%s is a method of type %s but returns no output", elemName, typ)
   335  		case mt.Type.NumOut() > 2:
   336  			return zero, fmt.Errorf("%s is a method of type %s but returns more than 2 outputs", elemName, typ)
   337  		case mt.Type.NumOut() == 1 && mt.Type.Out(0).Implements(errorType):
   338  			return zero, fmt.Errorf("%s is a method of type %s but only returns an error type", elemName, typ)
   339  		case mt.Type.NumOut() == 2 && !mt.Type.Out(1).Implements(errorType):
   340  			return zero, fmt.Errorf("%s is a method of type %s returning two values but the second value is not an error type", elemName, typ)
   341  		}
   342  		res := objPtr.Method(mt.Index).Call(args)
   343  		if len(res) == 2 && !res[1].IsNil() {
   344  			return zero, fmt.Errorf("error at calling a method %s of type %s: %s", elemName, typ, res[1].Interface().(error))
   345  		}
   346  		return res[0], nil
   347  	}
   348  
   349  	// elemName isn't a method so next start to check whether it is
   350  	// a struct field or a map value. In both cases, it mustn't be
   351  	// a nil value
   352  	if isNil {
   353  		return zero, fmt.Errorf("can't evaluate a nil pointer of type %s by a struct field or map key name %s", typ, elemName)
   354  	}
   355  	switch obj.Kind() {
   356  	case reflect.Struct:
   357  		ft, ok := obj.Type().FieldByName(elemName)
   358  		if ok {
   359  			if ft.PkgPath != "" && !ft.Anonymous {
   360  				return zero, fmt.Errorf("%s is an unexported field of struct type %s", elemName, typ)
   361  			}
   362  			return obj.FieldByIndex(ft.Index), nil
   363  		}
   364  		return zero, fmt.Errorf("%s isn't a field of struct type %s", elemName, typ)
   365  	case reflect.Map:
   366  		kv := reflect.ValueOf(elemName)
   367  		if kv.Type().AssignableTo(obj.Type().Key()) {
   368  			return obj.MapIndex(kv), nil
   369  		}
   370  		return zero, fmt.Errorf("%s isn't a key of map type %s", elemName, typ)
   371  	}
   372  	return zero, fmt.Errorf("%s is neither a struct field, a method nor a map element of type %s", elemName, typ)
   373  }
   374  
   375  // parseWhereArgs parses the end arguments to the where function.  Return a
   376  // match value and an operator, if one is defined.
   377  func parseWhereArgs(args ...any) (mv reflect.Value, op string, err error) {
   378  	switch len(args) {
   379  	case 1:
   380  		mv = reflect.ValueOf(args[0])
   381  	case 2:
   382  		var ok bool
   383  		if op, ok = args[0].(string); !ok {
   384  			err = errors.New("operator argument must be string type")
   385  			return
   386  		}
   387  		op = strings.TrimSpace(strings.ToLower(op))
   388  		mv = reflect.ValueOf(args[1])
   389  	default:
   390  		err = errors.New("can't evaluate the array by no match argument or more than or equal to two arguments")
   391  	}
   392  	return
   393  }
   394  
   395  // checkWhereArray handles the where-matching logic when the seqv value is an
   396  // Array or Slice.
   397  func (ns *Namespace) checkWhereArray(ctxv, seqv, kv, mv reflect.Value, path []string, op string) (any, error) {
   398  	rv := reflect.MakeSlice(seqv.Type(), 0, 0)
   399  
   400  	for i := 0; i < seqv.Len(); i++ {
   401  		var vvv reflect.Value
   402  		rvv := seqv.Index(i)
   403  
   404  		if kv.Kind() == reflect.String {
   405  			if params, ok := rvv.Interface().(maps.Params); ok {
   406  				vvv = reflect.ValueOf(params.GetNested(path...))
   407  			} else {
   408  				vvv = rvv
   409  				for i, elemName := range path {
   410  					var err error
   411  					vvv, err = evaluateSubElem(ctxv, vvv, elemName)
   412  
   413  					if err != nil {
   414  						continue
   415  					}
   416  
   417  					if i < len(path)-1 && vvv.IsValid() {
   418  						if params, ok := vvv.Interface().(maps.Params); ok {
   419  							// The current path element is the map itself, .Params.
   420  							vvv = reflect.ValueOf(params.GetNested(path[i+1:]...))
   421  							break
   422  						}
   423  					}
   424  				}
   425  			}
   426  		} else {
   427  			vv, _ := indirect(rvv)
   428  			if vv.Kind() == reflect.Map && kv.Type().AssignableTo(vv.Type().Key()) {
   429  				vvv = vv.MapIndex(kv)
   430  			}
   431  		}
   432  
   433  		if ok, err := ns.checkCondition(vvv, mv, op); ok {
   434  			rv = reflect.Append(rv, rvv)
   435  		} else if err != nil {
   436  			return nil, err
   437  		}
   438  	}
   439  	return rv.Interface(), nil
   440  }
   441  
   442  // checkWhereMap handles the where-matching logic when the seqv value is a Map.
   443  func (ns *Namespace) checkWhereMap(ctxv, seqv, kv, mv reflect.Value, path []string, op string) (any, error) {
   444  	rv := reflect.MakeMap(seqv.Type())
   445  	keys := seqv.MapKeys()
   446  	for _, k := range keys {
   447  		elemv := seqv.MapIndex(k)
   448  		switch elemv.Kind() {
   449  		case reflect.Array, reflect.Slice:
   450  			r, err := ns.checkWhereArray(ctxv, elemv, kv, mv, path, op)
   451  			if err != nil {
   452  				return nil, err
   453  			}
   454  
   455  			switch rr := reflect.ValueOf(r); rr.Kind() {
   456  			case reflect.Slice:
   457  				if rr.Len() > 0 {
   458  					rv.SetMapIndex(k, elemv)
   459  				}
   460  			}
   461  		case reflect.Interface:
   462  			elemvv, isNil := indirect(elemv)
   463  			if isNil {
   464  				continue
   465  			}
   466  
   467  			switch elemvv.Kind() {
   468  			case reflect.Array, reflect.Slice:
   469  				r, err := ns.checkWhereArray(ctxv, elemvv, kv, mv, path, op)
   470  				if err != nil {
   471  					return nil, err
   472  				}
   473  
   474  				switch rr := reflect.ValueOf(r); rr.Kind() {
   475  				case reflect.Slice:
   476  					if rr.Len() > 0 {
   477  						rv.SetMapIndex(k, elemv)
   478  					}
   479  				}
   480  			}
   481  		}
   482  	}
   483  	return rv.Interface(), nil
   484  }
   485  
   486  // toFloat returns the float value if possible.
   487  func toFloat(v reflect.Value) (float64, error) {
   488  	switch v.Kind() {
   489  	case reflect.Float32, reflect.Float64:
   490  		return v.Float(), nil
   491  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   492  		return v.Convert(reflect.TypeOf(float64(0))).Float(), nil
   493  	case reflect.Interface:
   494  		return toFloat(v.Elem())
   495  	}
   496  	return -1, errors.New("unable to convert value to float")
   497  }
   498  
   499  // toInt returns the int value if possible, -1 if not.
   500  // TODO(bep) consolidate all these reflect funcs.
   501  func toInt(v reflect.Value) (int64, error) {
   502  	switch v.Kind() {
   503  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   504  		return v.Int(), nil
   505  	case reflect.Interface:
   506  		return toInt(v.Elem())
   507  	}
   508  	return -1, errors.New("unable to convert value to int")
   509  }
   510  
   511  func toUint(v reflect.Value) (uint64, error) {
   512  	switch v.Kind() {
   513  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   514  		return v.Uint(), nil
   515  	case reflect.Interface:
   516  		return toUint(v.Elem())
   517  	}
   518  	return 0, errors.New("unable to convert value to uint")
   519  }
   520  
   521  // toString returns the string value if possible, "" if not.
   522  func toString(v reflect.Value) (string, error) {
   523  	switch v.Kind() {
   524  	case reflect.String:
   525  		return v.String(), nil
   526  	case reflect.Interface:
   527  		return toString(v.Elem())
   528  	}
   529  	return "", errors.New("unable to convert value to string")
   530  }
   531  
   532  func (ns *Namespace) toTimeUnix(v reflect.Value) int64 {
   533  	t, ok := hreflect.AsTime(v, ns.loc)
   534  	if !ok {
   535  		panic("coding error: argument must be time.Time type reflect Value")
   536  	}
   537  	return t.Unix()
   538  }