github.com/linchen2chris/hugo@v0.0.0-20230307053224-cec209389705/tpl/collections/where.go (about)

     1  // Copyright 2017 The Hugo Authors. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  // http://www.apache.org/licenses/LICENSE-2.0
     7  //
     8  // Unless required by applicable law or agreed to in writing, software
     9  // distributed under the License is distributed on an "AS IS" BASIS,
    10  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package collections
    15  
    16  import (
    17  	"errors"
    18  	"fmt"
    19  	"reflect"
    20  	"strings"
    21  
    22  	"github.com/gohugoio/hugo/common/hreflect"
    23  	"github.com/gohugoio/hugo/common/maps"
    24  )
    25  
    26  // Where returns a filtered subset of collection c.
    27  func (ns *Namespace) Where(c, key any, args ...any) (any, error) {
    28  	seqv, isNil := indirect(reflect.ValueOf(c))
    29  	if isNil {
    30  		return nil, errors.New("can't iterate over a nil value of type " + reflect.ValueOf(c).Type().String())
    31  	}
    32  
    33  	mv, op, err := parseWhereArgs(args...)
    34  	if err != nil {
    35  		return nil, err
    36  	}
    37  
    38  	var path []string
    39  	kv := reflect.ValueOf(key)
    40  	if kv.Kind() == reflect.String {
    41  		path = strings.Split(strings.Trim(kv.String(), "."), ".")
    42  	}
    43  
    44  	switch seqv.Kind() {
    45  	case reflect.Array, reflect.Slice:
    46  		return ns.checkWhereArray(seqv, kv, mv, path, op)
    47  	case reflect.Map:
    48  		return ns.checkWhereMap(seqv, kv, mv, path, op)
    49  	default:
    50  		return nil, fmt.Errorf("can't iterate over %v", c)
    51  	}
    52  }
    53  
    54  func (ns *Namespace) checkCondition(v, mv reflect.Value, op string) (bool, error) {
    55  	v, vIsNil := indirect(v)
    56  	if !v.IsValid() {
    57  		vIsNil = true
    58  	}
    59  
    60  	mv, mvIsNil := indirect(mv)
    61  	if !mv.IsValid() {
    62  		mvIsNil = true
    63  	}
    64  	if vIsNil || mvIsNil {
    65  		switch op {
    66  		case "", "=", "==", "eq":
    67  			return vIsNil == mvIsNil, nil
    68  		case "!=", "<>", "ne":
    69  			return vIsNil != mvIsNil, nil
    70  		}
    71  		return false, nil
    72  	}
    73  
    74  	if v.Kind() == reflect.Bool && mv.Kind() == reflect.Bool {
    75  		switch op {
    76  		case "", "=", "==", "eq":
    77  			return v.Bool() == mv.Bool(), nil
    78  		case "!=", "<>", "ne":
    79  			return v.Bool() != mv.Bool(), nil
    80  		}
    81  		return false, nil
    82  	}
    83  
    84  	var ivp, imvp *int64
    85  	var fvp, fmvp *float64
    86  	var svp, smvp *string
    87  	var slv, slmv any
    88  	var ima []int64
    89  	var fma []float64
    90  	var sma []string
    91  
    92  	if mv.Kind() == v.Kind() {
    93  		switch v.Kind() {
    94  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
    95  			iv := v.Int()
    96  			ivp = &iv
    97  			imv := mv.Int()
    98  			imvp = &imv
    99  		case reflect.String:
   100  			sv := v.String()
   101  			svp = &sv
   102  			smv := mv.String()
   103  			smvp = &smv
   104  		case reflect.Float64:
   105  			fv := v.Float()
   106  			fvp = &fv
   107  			fmv := mv.Float()
   108  			fmvp = &fmv
   109  		case reflect.Struct:
   110  			if hreflect.IsTime(v.Type()) {
   111  				iv := ns.toTimeUnix(v)
   112  				ivp = &iv
   113  				imv := ns.toTimeUnix(mv)
   114  				imvp = &imv
   115  			}
   116  		case reflect.Array, reflect.Slice:
   117  			slv = v.Interface()
   118  			slmv = mv.Interface()
   119  		}
   120  	} else if isNumber(v.Kind()) && isNumber(mv.Kind()) {
   121  		fv, err := toFloat(v)
   122  		if err != nil {
   123  			return false, err
   124  		}
   125  		fvp = &fv
   126  		fmv, err := toFloat(mv)
   127  		if err != nil {
   128  			return false, err
   129  		}
   130  		fmvp = &fmv
   131  	} else {
   132  		if mv.Kind() != reflect.Array && mv.Kind() != reflect.Slice {
   133  			return false, nil
   134  		}
   135  
   136  		if mv.Len() == 0 {
   137  			return false, nil
   138  		}
   139  
   140  		if v.Kind() != reflect.Interface && mv.Type().Elem().Kind() != reflect.Interface && mv.Type().Elem() != v.Type() && v.Kind() != reflect.Array && v.Kind() != reflect.Slice {
   141  			return false, nil
   142  		}
   143  		switch v.Kind() {
   144  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   145  			iv := v.Int()
   146  			ivp = &iv
   147  			for i := 0; i < mv.Len(); i++ {
   148  				if anInt, err := toInt(mv.Index(i)); err == nil {
   149  					ima = append(ima, anInt)
   150  				}
   151  			}
   152  		case reflect.String:
   153  			sv := v.String()
   154  			svp = &sv
   155  			for i := 0; i < mv.Len(); i++ {
   156  				if aString, err := toString(mv.Index(i)); err == nil {
   157  					sma = append(sma, aString)
   158  				}
   159  			}
   160  		case reflect.Float64:
   161  			fv := v.Float()
   162  			fvp = &fv
   163  			for i := 0; i < mv.Len(); i++ {
   164  				if aFloat, err := toFloat(mv.Index(i)); err == nil {
   165  					fma = append(fma, aFloat)
   166  				}
   167  			}
   168  		case reflect.Struct:
   169  			if hreflect.IsTime(v.Type()) {
   170  				iv := ns.toTimeUnix(v)
   171  				ivp = &iv
   172  				for i := 0; i < mv.Len(); i++ {
   173  					ima = append(ima, ns.toTimeUnix(mv.Index(i)))
   174  				}
   175  			}
   176  		case reflect.Array, reflect.Slice:
   177  			slv = v.Interface()
   178  			slmv = mv.Interface()
   179  		}
   180  	}
   181  
   182  	switch op {
   183  	case "", "=", "==", "eq":
   184  		switch {
   185  		case ivp != nil && imvp != nil:
   186  			return *ivp == *imvp, nil
   187  		case svp != nil && smvp != nil:
   188  			return *svp == *smvp, nil
   189  		case fvp != nil && fmvp != nil:
   190  			return *fvp == *fmvp, nil
   191  		}
   192  	case "!=", "<>", "ne":
   193  		switch {
   194  		case ivp != nil && imvp != nil:
   195  			return *ivp != *imvp, nil
   196  		case svp != nil && smvp != nil:
   197  			return *svp != *smvp, nil
   198  		case fvp != nil && fmvp != nil:
   199  			return *fvp != *fmvp, nil
   200  		}
   201  	case ">=", "ge":
   202  		switch {
   203  		case ivp != nil && imvp != nil:
   204  			return *ivp >= *imvp, nil
   205  		case svp != nil && smvp != nil:
   206  			return *svp >= *smvp, nil
   207  		case fvp != nil && fmvp != nil:
   208  			return *fvp >= *fmvp, nil
   209  		}
   210  	case ">", "gt":
   211  		switch {
   212  		case ivp != nil && imvp != nil:
   213  			return *ivp > *imvp, nil
   214  		case svp != nil && smvp != nil:
   215  			return *svp > *smvp, nil
   216  		case fvp != nil && fmvp != nil:
   217  			return *fvp > *fmvp, nil
   218  		}
   219  	case "<=", "le":
   220  		switch {
   221  		case ivp != nil && imvp != nil:
   222  			return *ivp <= *imvp, nil
   223  		case svp != nil && smvp != nil:
   224  			return *svp <= *smvp, nil
   225  		case fvp != nil && fmvp != nil:
   226  			return *fvp <= *fmvp, nil
   227  		}
   228  	case "<", "lt":
   229  		switch {
   230  		case ivp != nil && imvp != nil:
   231  			return *ivp < *imvp, nil
   232  		case svp != nil && smvp != nil:
   233  			return *svp < *smvp, nil
   234  		case fvp != nil && fmvp != nil:
   235  			return *fvp < *fmvp, nil
   236  		}
   237  	case "in", "not in":
   238  		var r bool
   239  		switch {
   240  		case ivp != nil && len(ima) > 0:
   241  			r, _ = ns.In(ima, *ivp)
   242  		case fvp != nil && len(fma) > 0:
   243  			r, _ = ns.In(fma, *fvp)
   244  		case svp != nil:
   245  			if len(sma) > 0 {
   246  				r, _ = ns.In(sma, *svp)
   247  			} else if smvp != nil {
   248  				r, _ = ns.In(*smvp, *svp)
   249  			}
   250  		default:
   251  			return false, nil
   252  		}
   253  		if op == "not in" {
   254  			return !r, nil
   255  		}
   256  		return r, nil
   257  	case "intersect":
   258  		r, err := ns.Intersect(slv, slmv)
   259  		if err != nil {
   260  			return false, err
   261  		}
   262  
   263  		if reflect.TypeOf(r).Kind() == reflect.Slice {
   264  			s := reflect.ValueOf(r)
   265  
   266  			if s.Len() > 0 {
   267  				return true, nil
   268  			}
   269  			return false, nil
   270  		}
   271  		return false, errors.New("invalid intersect values")
   272  	default:
   273  		return false, errors.New("no such operator")
   274  	}
   275  	return false, nil
   276  }
   277  
   278  func evaluateSubElem(obj reflect.Value, elemName string) (reflect.Value, error) {
   279  	if !obj.IsValid() {
   280  		return zero, errors.New("can't evaluate an invalid value")
   281  	}
   282  
   283  	typ := obj.Type()
   284  	obj, isNil := indirect(obj)
   285  
   286  	if obj.Kind() == reflect.Interface {
   287  		// If obj is an interface, we need to inspect the value it contains
   288  		// to see the full set of methods and fields.
   289  		// Indirect returns the value that it points to, which is what's needed
   290  		// below to be able to reflect on its fields.
   291  		obj = reflect.Indirect(obj.Elem())
   292  	}
   293  
   294  	// first, check whether obj has a method. In this case, obj is
   295  	// a struct or its pointer. If obj is a struct,
   296  	// to check all T and *T method, use obj pointer type Value
   297  	objPtr := obj
   298  	if objPtr.Kind() != reflect.Interface && objPtr.CanAddr() {
   299  		objPtr = objPtr.Addr()
   300  	}
   301  
   302  	index := hreflect.GetMethodIndexByName(objPtr.Type(), elemName)
   303  	if index != -1 {
   304  		mt := objPtr.Type().Method(index)
   305  		switch {
   306  		case mt.PkgPath != "":
   307  			return zero, fmt.Errorf("%s is an unexported method of type %s", elemName, typ)
   308  		case mt.Type.NumIn() > 1:
   309  			return zero, fmt.Errorf("%s is a method of type %s but requires more than 1 parameter", elemName, typ)
   310  		case mt.Type.NumOut() == 0:
   311  			return zero, fmt.Errorf("%s is a method of type %s but returns no output", elemName, typ)
   312  		case mt.Type.NumOut() > 2:
   313  			return zero, fmt.Errorf("%s is a method of type %s but returns more than 2 outputs", elemName, typ)
   314  		case mt.Type.NumOut() == 1 && mt.Type.Out(0).Implements(errorType):
   315  			return zero, fmt.Errorf("%s is a method of type %s but only returns an error type", elemName, typ)
   316  		case mt.Type.NumOut() == 2 && !mt.Type.Out(1).Implements(errorType):
   317  			return zero, fmt.Errorf("%s is a method of type %s returning two values but the second value is not an error type", elemName, typ)
   318  		}
   319  		res := objPtr.Method(mt.Index).Call([]reflect.Value{})
   320  		if len(res) == 2 && !res[1].IsNil() {
   321  			return zero, fmt.Errorf("error at calling a method %s of type %s: %s", elemName, typ, res[1].Interface().(error))
   322  		}
   323  		return res[0], nil
   324  	}
   325  
   326  	// elemName isn't a method so next start to check whether it is
   327  	// a struct field or a map value. In both cases, it mustn't be
   328  	// a nil value
   329  	if isNil {
   330  		return zero, fmt.Errorf("can't evaluate a nil pointer of type %s by a struct field or map key name %s", typ, elemName)
   331  	}
   332  	switch obj.Kind() {
   333  	case reflect.Struct:
   334  		ft, ok := obj.Type().FieldByName(elemName)
   335  		if ok {
   336  			if ft.PkgPath != "" && !ft.Anonymous {
   337  				return zero, fmt.Errorf("%s is an unexported field of struct type %s", elemName, typ)
   338  			}
   339  			return obj.FieldByIndex(ft.Index), nil
   340  		}
   341  		return zero, fmt.Errorf("%s isn't a field of struct type %s", elemName, typ)
   342  	case reflect.Map:
   343  		kv := reflect.ValueOf(elemName)
   344  		if kv.Type().AssignableTo(obj.Type().Key()) {
   345  			return obj.MapIndex(kv), nil
   346  		}
   347  		return zero, fmt.Errorf("%s isn't a key of map type %s", elemName, typ)
   348  	}
   349  	return zero, fmt.Errorf("%s is neither a struct field, a method nor a map element of type %s", elemName, typ)
   350  }
   351  
   352  // parseWhereArgs parses the end arguments to the where function.  Return a
   353  // match value and an operator, if one is defined.
   354  func parseWhereArgs(args ...any) (mv reflect.Value, op string, err error) {
   355  	switch len(args) {
   356  	case 1:
   357  		mv = reflect.ValueOf(args[0])
   358  	case 2:
   359  		var ok bool
   360  		if op, ok = args[0].(string); !ok {
   361  			err = errors.New("operator argument must be string type")
   362  			return
   363  		}
   364  		op = strings.TrimSpace(strings.ToLower(op))
   365  		mv = reflect.ValueOf(args[1])
   366  	default:
   367  		err = errors.New("can't evaluate the array by no match argument or more than or equal to two arguments")
   368  	}
   369  	return
   370  }
   371  
   372  // checkWhereArray handles the where-matching logic when the seqv value is an
   373  // Array or Slice.
   374  func (ns *Namespace) checkWhereArray(seqv, kv, mv reflect.Value, path []string, op string) (any, error) {
   375  	rv := reflect.MakeSlice(seqv.Type(), 0, 0)
   376  
   377  	for i := 0; i < seqv.Len(); i++ {
   378  		var vvv reflect.Value
   379  		rvv := seqv.Index(i)
   380  
   381  		if kv.Kind() == reflect.String {
   382  			if params, ok := rvv.Interface().(maps.Params); ok {
   383  				vvv = reflect.ValueOf(params.Get(path...))
   384  			} else {
   385  				vvv = rvv
   386  				for i, elemName := range path {
   387  					var err error
   388  					vvv, err = evaluateSubElem(vvv, elemName)
   389  
   390  					if err != nil {
   391  						continue
   392  					}
   393  
   394  					if i < len(path)-1 && vvv.IsValid() {
   395  						if params, ok := vvv.Interface().(maps.Params); ok {
   396  							// The current path element is the map itself, .Params.
   397  							vvv = reflect.ValueOf(params.Get(path[i+1:]...))
   398  							break
   399  						}
   400  					}
   401  				}
   402  			}
   403  		} else {
   404  			vv, _ := indirect(rvv)
   405  			if vv.Kind() == reflect.Map && kv.Type().AssignableTo(vv.Type().Key()) {
   406  				vvv = vv.MapIndex(kv)
   407  			}
   408  		}
   409  
   410  		if ok, err := ns.checkCondition(vvv, mv, op); ok {
   411  			rv = reflect.Append(rv, rvv)
   412  		} else if err != nil {
   413  			return nil, err
   414  		}
   415  	}
   416  	return rv.Interface(), nil
   417  }
   418  
   419  // checkWhereMap handles the where-matching logic when the seqv value is a Map.
   420  func (ns *Namespace) checkWhereMap(seqv, kv, mv reflect.Value, path []string, op string) (any, error) {
   421  	rv := reflect.MakeMap(seqv.Type())
   422  	keys := seqv.MapKeys()
   423  	for _, k := range keys {
   424  		elemv := seqv.MapIndex(k)
   425  		switch elemv.Kind() {
   426  		case reflect.Array, reflect.Slice:
   427  			r, err := ns.checkWhereArray(elemv, kv, mv, path, op)
   428  			if err != nil {
   429  				return nil, err
   430  			}
   431  
   432  			switch rr := reflect.ValueOf(r); rr.Kind() {
   433  			case reflect.Slice:
   434  				if rr.Len() > 0 {
   435  					rv.SetMapIndex(k, elemv)
   436  				}
   437  			}
   438  		case reflect.Interface:
   439  			elemvv, isNil := indirect(elemv)
   440  			if isNil {
   441  				continue
   442  			}
   443  
   444  			switch elemvv.Kind() {
   445  			case reflect.Array, reflect.Slice:
   446  				r, err := ns.checkWhereArray(elemvv, kv, mv, path, op)
   447  				if err != nil {
   448  					return nil, err
   449  				}
   450  
   451  				switch rr := reflect.ValueOf(r); rr.Kind() {
   452  				case reflect.Slice:
   453  					if rr.Len() > 0 {
   454  						rv.SetMapIndex(k, elemv)
   455  					}
   456  				}
   457  			}
   458  		}
   459  	}
   460  	return rv.Interface(), nil
   461  }
   462  
   463  // toFloat returns the float value if possible.
   464  func toFloat(v reflect.Value) (float64, error) {
   465  	switch v.Kind() {
   466  	case reflect.Float32, reflect.Float64:
   467  		return v.Float(), nil
   468  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   469  		return v.Convert(reflect.TypeOf(float64(0))).Float(), nil
   470  	case reflect.Interface:
   471  		return toFloat(v.Elem())
   472  	}
   473  	return -1, errors.New("unable to convert value to float")
   474  }
   475  
   476  // toInt returns the int value if possible, -1 if not.
   477  // TODO(bep) consolidate all these reflect funcs.
   478  func toInt(v reflect.Value) (int64, error) {
   479  	switch v.Kind() {
   480  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   481  		return v.Int(), nil
   482  	case reflect.Interface:
   483  		return toInt(v.Elem())
   484  	}
   485  	return -1, errors.New("unable to convert value to int")
   486  }
   487  
   488  func toUint(v reflect.Value) (uint64, error) {
   489  	switch v.Kind() {
   490  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   491  		return v.Uint(), nil
   492  	case reflect.Interface:
   493  		return toUint(v.Elem())
   494  	}
   495  	return 0, errors.New("unable to convert value to uint")
   496  }
   497  
   498  // toString returns the string value if possible, "" if not.
   499  func toString(v reflect.Value) (string, error) {
   500  	switch v.Kind() {
   501  	case reflect.String:
   502  		return v.String(), nil
   503  	case reflect.Interface:
   504  		return toString(v.Elem())
   505  	}
   506  	return "", errors.New("unable to convert value to string")
   507  }
   508  
   509  func (ns *Namespace) toTimeUnix(v reflect.Value) int64 {
   510  	t, ok := hreflect.AsTime(v, ns.loc)
   511  	if !ok {
   512  		panic("coding error: argument must be time.Time type reflect Value")
   513  	}
   514  	return t.Unix()
   515  }