github.com/kovansky/hugo@v0.92.3-0.20220224232819-63076e4ff19f/tpl/collections/where.go (about)

     1  // Copyright 2017 The Hugo Authors. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  // http://www.apache.org/licenses/LICENSE-2.0
     7  //
     8  // Unless required by applicable law or agreed to in writing, software
     9  // distributed under the License is distributed on an "AS IS" BASIS,
    10  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package collections
    15  
    16  import (
    17  	"errors"
    18  	"fmt"
    19  	"reflect"
    20  	"strings"
    21  
    22  	"github.com/gohugoio/hugo/common/maps"
    23  )
    24  
    25  // Where returns a filtered subset of a given data type.
    26  func (ns *Namespace) Where(seq, key interface{}, args ...interface{}) (interface{}, error) {
    27  	seqv, isNil := indirect(reflect.ValueOf(seq))
    28  	if isNil {
    29  		return nil, errors.New("can't iterate over a nil value of type " + reflect.ValueOf(seq).Type().String())
    30  	}
    31  
    32  	mv, op, err := parseWhereArgs(args...)
    33  	if err != nil {
    34  		return nil, err
    35  	}
    36  
    37  	var path []string
    38  	kv := reflect.ValueOf(key)
    39  	if kv.Kind() == reflect.String {
    40  		path = strings.Split(strings.Trim(kv.String(), "."), ".")
    41  	}
    42  
    43  	switch seqv.Kind() {
    44  	case reflect.Array, reflect.Slice:
    45  		return ns.checkWhereArray(seqv, kv, mv, path, op)
    46  	case reflect.Map:
    47  		return ns.checkWhereMap(seqv, kv, mv, path, op)
    48  	default:
    49  		return nil, fmt.Errorf("can't iterate over %v", seq)
    50  	}
    51  }
    52  
    53  func (ns *Namespace) checkCondition(v, mv reflect.Value, op string) (bool, error) {
    54  	v, vIsNil := indirect(v)
    55  	if !v.IsValid() {
    56  		vIsNil = true
    57  	}
    58  
    59  	mv, mvIsNil := indirect(mv)
    60  	if !mv.IsValid() {
    61  		mvIsNil = true
    62  	}
    63  	if vIsNil || mvIsNil {
    64  		switch op {
    65  		case "", "=", "==", "eq":
    66  			return vIsNil == mvIsNil, nil
    67  		case "!=", "<>", "ne":
    68  			return vIsNil != mvIsNil, nil
    69  		}
    70  		return false, nil
    71  	}
    72  
    73  	if v.Kind() == reflect.Bool && mv.Kind() == reflect.Bool {
    74  		switch op {
    75  		case "", "=", "==", "eq":
    76  			return v.Bool() == mv.Bool(), nil
    77  		case "!=", "<>", "ne":
    78  			return v.Bool() != mv.Bool(), nil
    79  		}
    80  		return false, nil
    81  	}
    82  
    83  	var ivp, imvp *int64
    84  	var fvp, fmvp *float64
    85  	var svp, smvp *string
    86  	var slv, slmv interface{}
    87  	var ima []int64
    88  	var fma []float64
    89  	var sma []string
    90  
    91  	if mv.Kind() == v.Kind() {
    92  		switch v.Kind() {
    93  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
    94  			iv := v.Int()
    95  			ivp = &iv
    96  			imv := mv.Int()
    97  			imvp = &imv
    98  		case reflect.String:
    99  			sv := v.String()
   100  			svp = &sv
   101  			smv := mv.String()
   102  			smvp = &smv
   103  		case reflect.Float64:
   104  			fv := v.Float()
   105  			fvp = &fv
   106  			fmv := mv.Float()
   107  			fmvp = &fmv
   108  		case reflect.Struct:
   109  			switch v.Type() {
   110  			case timeType:
   111  				iv := toTimeUnix(v)
   112  				ivp = &iv
   113  				imv := toTimeUnix(mv)
   114  				imvp = &imv
   115  			}
   116  		case reflect.Array, reflect.Slice:
   117  			slv = v.Interface()
   118  			slmv = mv.Interface()
   119  		}
   120  	} else if isNumber(v.Kind()) && isNumber(mv.Kind()) {
   121  		fv, err := toFloat(v)
   122  		if err != nil {
   123  			return false, err
   124  		}
   125  		fvp = &fv
   126  		fmv, err := toFloat(mv)
   127  		if err != nil {
   128  			return false, err
   129  		}
   130  		fmvp = &fmv
   131  	} else {
   132  		if mv.Kind() != reflect.Array && mv.Kind() != reflect.Slice {
   133  			return false, nil
   134  		}
   135  
   136  		if mv.Len() == 0 {
   137  			return false, nil
   138  		}
   139  
   140  		if v.Kind() != reflect.Interface && mv.Type().Elem().Kind() != reflect.Interface && mv.Type().Elem() != v.Type() && v.Kind() != reflect.Array && v.Kind() != reflect.Slice {
   141  			return false, nil
   142  		}
   143  		switch v.Kind() {
   144  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   145  			iv := v.Int()
   146  			ivp = &iv
   147  			for i := 0; i < mv.Len(); i++ {
   148  				if anInt, err := toInt(mv.Index(i)); err == nil {
   149  					ima = append(ima, anInt)
   150  				}
   151  			}
   152  		case reflect.String:
   153  			sv := v.String()
   154  			svp = &sv
   155  			for i := 0; i < mv.Len(); i++ {
   156  				if aString, err := toString(mv.Index(i)); err == nil {
   157  					sma = append(sma, aString)
   158  				}
   159  			}
   160  		case reflect.Float64:
   161  			fv := v.Float()
   162  			fvp = &fv
   163  			for i := 0; i < mv.Len(); i++ {
   164  				if aFloat, err := toFloat(mv.Index(i)); err == nil {
   165  					fma = append(fma, aFloat)
   166  				}
   167  			}
   168  		case reflect.Struct:
   169  			switch v.Type() {
   170  			case timeType:
   171  				iv := toTimeUnix(v)
   172  				ivp = &iv
   173  				for i := 0; i < mv.Len(); i++ {
   174  					ima = append(ima, toTimeUnix(mv.Index(i)))
   175  				}
   176  			}
   177  		case reflect.Array, reflect.Slice:
   178  			slv = v.Interface()
   179  			slmv = mv.Interface()
   180  		}
   181  	}
   182  
   183  	switch op {
   184  	case "", "=", "==", "eq":
   185  		switch {
   186  		case ivp != nil && imvp != nil:
   187  			return *ivp == *imvp, nil
   188  		case svp != nil && smvp != nil:
   189  			return *svp == *smvp, nil
   190  		case fvp != nil && fmvp != nil:
   191  			return *fvp == *fmvp, nil
   192  		}
   193  	case "!=", "<>", "ne":
   194  		switch {
   195  		case ivp != nil && imvp != nil:
   196  			return *ivp != *imvp, nil
   197  		case svp != nil && smvp != nil:
   198  			return *svp != *smvp, nil
   199  		case fvp != nil && fmvp != nil:
   200  			return *fvp != *fmvp, nil
   201  		}
   202  	case ">=", "ge":
   203  		switch {
   204  		case ivp != nil && imvp != nil:
   205  			return *ivp >= *imvp, nil
   206  		case svp != nil && smvp != nil:
   207  			return *svp >= *smvp, nil
   208  		case fvp != nil && fmvp != nil:
   209  			return *fvp >= *fmvp, nil
   210  		}
   211  	case ">", "gt":
   212  		switch {
   213  		case ivp != nil && imvp != nil:
   214  			return *ivp > *imvp, nil
   215  		case svp != nil && smvp != nil:
   216  			return *svp > *smvp, nil
   217  		case fvp != nil && fmvp != nil:
   218  			return *fvp > *fmvp, nil
   219  		}
   220  	case "<=", "le":
   221  		switch {
   222  		case ivp != nil && imvp != nil:
   223  			return *ivp <= *imvp, nil
   224  		case svp != nil && smvp != nil:
   225  			return *svp <= *smvp, nil
   226  		case fvp != nil && fmvp != nil:
   227  			return *fvp <= *fmvp, nil
   228  		}
   229  	case "<", "lt":
   230  		switch {
   231  		case ivp != nil && imvp != nil:
   232  			return *ivp < *imvp, nil
   233  		case svp != nil && smvp != nil:
   234  			return *svp < *smvp, nil
   235  		case fvp != nil && fmvp != nil:
   236  			return *fvp < *fmvp, nil
   237  		}
   238  	case "in", "not in":
   239  		var r bool
   240  		switch {
   241  		case ivp != nil && len(ima) > 0:
   242  			r, _ = ns.In(ima, *ivp)
   243  		case fvp != nil && len(fma) > 0:
   244  			r, _ = ns.In(fma, *fvp)
   245  		case svp != nil:
   246  			if len(sma) > 0 {
   247  				r, _ = ns.In(sma, *svp)
   248  			} else if smvp != nil {
   249  				r, _ = ns.In(*smvp, *svp)
   250  			}
   251  		default:
   252  			return false, nil
   253  		}
   254  		if op == "not in" {
   255  			return !r, nil
   256  		}
   257  		return r, nil
   258  	case "intersect":
   259  		r, err := ns.Intersect(slv, slmv)
   260  		if err != nil {
   261  			return false, err
   262  		}
   263  
   264  		if reflect.TypeOf(r).Kind() == reflect.Slice {
   265  			s := reflect.ValueOf(r)
   266  
   267  			if s.Len() > 0 {
   268  				return true, nil
   269  			}
   270  			return false, nil
   271  		}
   272  		return false, errors.New("invalid intersect values")
   273  	default:
   274  		return false, errors.New("no such operator")
   275  	}
   276  	return false, nil
   277  }
   278  
   279  func evaluateSubElem(obj reflect.Value, elemName string) (reflect.Value, error) {
   280  	if !obj.IsValid() {
   281  		return zero, errors.New("can't evaluate an invalid value")
   282  	}
   283  
   284  	typ := obj.Type()
   285  	obj, isNil := indirect(obj)
   286  
   287  	if obj.Kind() == reflect.Interface {
   288  		// If obj is an interface, we need to inspect the value it contains
   289  		// to see the full set of methods and fields.
   290  		// Indirect returns the value that it points to, which is what's needed
   291  		// below to be able to reflect on its fields.
   292  		obj = reflect.Indirect(obj.Elem())
   293  	}
   294  
   295  	// first, check whether obj has a method. In this case, obj is
   296  	// a struct or its pointer. If obj is a struct,
   297  	// to check all T and *T method, use obj pointer type Value
   298  	objPtr := obj
   299  	if objPtr.Kind() != reflect.Interface && objPtr.CanAddr() {
   300  		objPtr = objPtr.Addr()
   301  	}
   302  
   303  	mt, ok := objPtr.Type().MethodByName(elemName)
   304  	if ok {
   305  		switch {
   306  		case mt.PkgPath != "":
   307  			return zero, fmt.Errorf("%s is an unexported method of type %s", elemName, typ)
   308  		case mt.Type.NumIn() > 1:
   309  			return zero, fmt.Errorf("%s is a method of type %s but requires more than 1 parameter", elemName, typ)
   310  		case mt.Type.NumOut() == 0:
   311  			return zero, fmt.Errorf("%s is a method of type %s but returns no output", elemName, typ)
   312  		case mt.Type.NumOut() > 2:
   313  			return zero, fmt.Errorf("%s is a method of type %s but returns more than 2 outputs", elemName, typ)
   314  		case mt.Type.NumOut() == 1 && mt.Type.Out(0).Implements(errorType):
   315  			return zero, fmt.Errorf("%s is a method of type %s but only returns an error type", elemName, typ)
   316  		case mt.Type.NumOut() == 2 && !mt.Type.Out(1).Implements(errorType):
   317  			return zero, fmt.Errorf("%s is a method of type %s returning two values but the second value is not an error type", elemName, typ)
   318  		}
   319  		res := objPtr.Method(mt.Index).Call([]reflect.Value{})
   320  		if len(res) == 2 && !res[1].IsNil() {
   321  			return zero, fmt.Errorf("error at calling a method %s of type %s: %s", elemName, typ, res[1].Interface().(error))
   322  		}
   323  		return res[0], nil
   324  	}
   325  
   326  	// elemName isn't a method so next start to check whether it is
   327  	// a struct field or a map value. In both cases, it mustn't be
   328  	// a nil value
   329  	if isNil {
   330  		return zero, fmt.Errorf("can't evaluate a nil pointer of type %s by a struct field or map key name %s", typ, elemName)
   331  	}
   332  	switch obj.Kind() {
   333  	case reflect.Struct:
   334  		ft, ok := obj.Type().FieldByName(elemName)
   335  		if ok {
   336  			if ft.PkgPath != "" && !ft.Anonymous {
   337  				return zero, fmt.Errorf("%s is an unexported field of struct type %s", elemName, typ)
   338  			}
   339  			return obj.FieldByIndex(ft.Index), nil
   340  		}
   341  		return zero, fmt.Errorf("%s isn't a field of struct type %s", elemName, typ)
   342  	case reflect.Map:
   343  		kv := reflect.ValueOf(elemName)
   344  		if kv.Type().AssignableTo(obj.Type().Key()) {
   345  			return obj.MapIndex(kv), nil
   346  		}
   347  		return zero, fmt.Errorf("%s isn't a key of map type %s", elemName, typ)
   348  	}
   349  	return zero, fmt.Errorf("%s is neither a struct field, a method nor a map element of type %s", elemName, typ)
   350  }
   351  
   352  // parseWhereArgs parses the end arguments to the where function.  Return a
   353  // match value and an operator, if one is defined.
   354  func parseWhereArgs(args ...interface{}) (mv reflect.Value, op string, err error) {
   355  	switch len(args) {
   356  	case 1:
   357  		mv = reflect.ValueOf(args[0])
   358  	case 2:
   359  		var ok bool
   360  		if op, ok = args[0].(string); !ok {
   361  			err = errors.New("operator argument must be string type")
   362  			return
   363  		}
   364  		op = strings.TrimSpace(strings.ToLower(op))
   365  		mv = reflect.ValueOf(args[1])
   366  	default:
   367  		err = errors.New("can't evaluate the array by no match argument or more than or equal to two arguments")
   368  	}
   369  	return
   370  }
   371  
   372  // checkWhereArray handles the where-matching logic when the seqv value is an
   373  // Array or Slice.
   374  func (ns *Namespace) checkWhereArray(seqv, kv, mv reflect.Value, path []string, op string) (interface{}, error) {
   375  	rv := reflect.MakeSlice(seqv.Type(), 0, 0)
   376  
   377  	for i := 0; i < seqv.Len(); i++ {
   378  		var vvv reflect.Value
   379  		rvv := seqv.Index(i)
   380  
   381  		if kv.Kind() == reflect.String {
   382  			if params, ok := rvv.Interface().(maps.Params); ok {
   383  				vvv = reflect.ValueOf(params.Get(path...))
   384  			} else {
   385  				vvv = rvv
   386  				for i, elemName := range path {
   387  					var err error
   388  					vvv, err = evaluateSubElem(vvv, elemName)
   389  
   390  					if err != nil {
   391  						continue
   392  					}
   393  
   394  					if i < len(path)-1 && vvv.IsValid() {
   395  						if params, ok := vvv.Interface().(maps.Params); ok {
   396  							// The current path element is the map itself, .Params.
   397  							vvv = reflect.ValueOf(params.Get(path[i+1:]...))
   398  							break
   399  						}
   400  					}
   401  				}
   402  			}
   403  		} else {
   404  			vv, _ := indirect(rvv)
   405  			if vv.Kind() == reflect.Map && kv.Type().AssignableTo(vv.Type().Key()) {
   406  				vvv = vv.MapIndex(kv)
   407  			}
   408  		}
   409  
   410  		if ok, err := ns.checkCondition(vvv, mv, op); ok {
   411  			rv = reflect.Append(rv, rvv)
   412  		} else if err != nil {
   413  			return nil, err
   414  		}
   415  	}
   416  	return rv.Interface(), nil
   417  }
   418  
   419  // checkWhereMap handles the where-matching logic when the seqv value is a Map.
   420  func (ns *Namespace) checkWhereMap(seqv, kv, mv reflect.Value, path []string, op string) (interface{}, error) {
   421  	rv := reflect.MakeMap(seqv.Type())
   422  	keys := seqv.MapKeys()
   423  	for _, k := range keys {
   424  		elemv := seqv.MapIndex(k)
   425  		switch elemv.Kind() {
   426  		case reflect.Array, reflect.Slice:
   427  			r, err := ns.checkWhereArray(elemv, kv, mv, path, op)
   428  			if err != nil {
   429  				return nil, err
   430  			}
   431  
   432  			switch rr := reflect.ValueOf(r); rr.Kind() {
   433  			case reflect.Slice:
   434  				if rr.Len() > 0 {
   435  					rv.SetMapIndex(k, elemv)
   436  				}
   437  			}
   438  		case reflect.Interface:
   439  			elemvv, isNil := indirect(elemv)
   440  			if isNil {
   441  				continue
   442  			}
   443  
   444  			switch elemvv.Kind() {
   445  			case reflect.Array, reflect.Slice:
   446  				r, err := ns.checkWhereArray(elemvv, kv, mv, path, op)
   447  				if err != nil {
   448  					return nil, err
   449  				}
   450  
   451  				switch rr := reflect.ValueOf(r); rr.Kind() {
   452  				case reflect.Slice:
   453  					if rr.Len() > 0 {
   454  						rv.SetMapIndex(k, elemv)
   455  					}
   456  				}
   457  			}
   458  		}
   459  	}
   460  	return rv.Interface(), nil
   461  }
   462  
   463  // toFloat returns the float value if possible.
   464  func toFloat(v reflect.Value) (float64, error) {
   465  	switch v.Kind() {
   466  	case reflect.Float32, reflect.Float64:
   467  		return v.Float(), nil
   468  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   469  		return v.Convert(reflect.TypeOf(float64(0))).Float(), nil
   470  	case reflect.Interface:
   471  		return toFloat(v.Elem())
   472  	}
   473  	return -1, errors.New("unable to convert value to float")
   474  }
   475  
   476  // toInt returns the int value if possible, -1 if not.
   477  // TODO(bep) consolidate all these reflect funcs.
   478  func toInt(v reflect.Value) (int64, error) {
   479  	switch v.Kind() {
   480  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   481  		return v.Int(), nil
   482  	case reflect.Interface:
   483  		return toInt(v.Elem())
   484  	}
   485  	return -1, errors.New("unable to convert value to int")
   486  }
   487  
   488  func toUint(v reflect.Value) (uint64, error) {
   489  	switch v.Kind() {
   490  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   491  		return v.Uint(), nil
   492  	case reflect.Interface:
   493  		return toUint(v.Elem())
   494  	}
   495  	return 0, errors.New("unable to convert value to uint")
   496  }
   497  
   498  // toString returns the string value if possible, "" if not.
   499  func toString(v reflect.Value) (string, error) {
   500  	switch v.Kind() {
   501  	case reflect.String:
   502  		return v.String(), nil
   503  	case reflect.Interface:
   504  		return toString(v.Elem())
   505  	}
   506  	return "", errors.New("unable to convert value to string")
   507  }
   508  
   509  func toTimeUnix(v reflect.Value) int64 {
   510  	if v.Kind() == reflect.Interface {
   511  		return toTimeUnix(v.Elem())
   512  	}
   513  	if v.Type() != timeType {
   514  		panic("coding error: argument must be time.Time type reflect Value")
   515  	}
   516  	return v.MethodByName("Unix").Call([]reflect.Value{})[0].Int()
   517  }