github.com/shohhei1126/hugo@v0.42.2-0.20180623210752-3d5928889ad7/tpl/collections/where.go (about)

     1  // Copyright 2017 The Hugo Authors. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  // http://www.apache.org/licenses/LICENSE-2.0
     7  //
     8  // Unless required by applicable law or agreed to in writing, software
     9  // distributed under the License is distributed on an "AS IS" BASIS,
    10  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package collections
    15  
    16  import (
    17  	"errors"
    18  	"fmt"
    19  	"reflect"
    20  	"strings"
    21  )
    22  
    23  // Where returns a filtered subset of a given data type.
    24  func (ns *Namespace) Where(seq, key interface{}, args ...interface{}) (interface{}, error) {
    25  	seqv, isNil := indirect(reflect.ValueOf(seq))
    26  	if isNil {
    27  		return nil, errors.New("can't iterate over a nil value of type " + reflect.ValueOf(seq).Type().String())
    28  	}
    29  
    30  	mv, op, err := parseWhereArgs(args...)
    31  	if err != nil {
    32  		return nil, err
    33  	}
    34  
    35  	var path []string
    36  	kv := reflect.ValueOf(key)
    37  	if kv.Kind() == reflect.String {
    38  		path = strings.Split(strings.Trim(kv.String(), "."), ".")
    39  	}
    40  
    41  	switch seqv.Kind() {
    42  	case reflect.Array, reflect.Slice:
    43  		return ns.checkWhereArray(seqv, kv, mv, path, op)
    44  	case reflect.Map:
    45  		return ns.checkWhereMap(seqv, kv, mv, path, op)
    46  	default:
    47  		return nil, fmt.Errorf("can't iterate over %v", seq)
    48  	}
    49  }
    50  
    51  func (ns *Namespace) checkCondition(v, mv reflect.Value, op string) (bool, error) {
    52  	v, vIsNil := indirect(v)
    53  	if !v.IsValid() {
    54  		vIsNil = true
    55  	}
    56  
    57  	mv, mvIsNil := indirect(mv)
    58  	if !mv.IsValid() {
    59  		mvIsNil = true
    60  	}
    61  	if vIsNil || mvIsNil {
    62  		switch op {
    63  		case "", "=", "==", "eq":
    64  			return vIsNil == mvIsNil, nil
    65  		case "!=", "<>", "ne":
    66  			return vIsNil != mvIsNil, nil
    67  		}
    68  		return false, nil
    69  	}
    70  
    71  	if v.Kind() == reflect.Bool && mv.Kind() == reflect.Bool {
    72  		switch op {
    73  		case "", "=", "==", "eq":
    74  			return v.Bool() == mv.Bool(), nil
    75  		case "!=", "<>", "ne":
    76  			return v.Bool() != mv.Bool(), nil
    77  		}
    78  		return false, nil
    79  	}
    80  
    81  	var ivp, imvp *int64
    82  	var svp, smvp *string
    83  	var slv, slmv interface{}
    84  	var ima []int64
    85  	var sma []string
    86  	if mv.Type() == v.Type() {
    87  		switch v.Kind() {
    88  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
    89  			iv := v.Int()
    90  			ivp = &iv
    91  			imv := mv.Int()
    92  			imvp = &imv
    93  		case reflect.String:
    94  			sv := v.String()
    95  			svp = &sv
    96  			smv := mv.String()
    97  			smvp = &smv
    98  		case reflect.Struct:
    99  			switch v.Type() {
   100  			case timeType:
   101  				iv := toTimeUnix(v)
   102  				ivp = &iv
   103  				imv := toTimeUnix(mv)
   104  				imvp = &imv
   105  			}
   106  		case reflect.Array, reflect.Slice:
   107  			slv = v.Interface()
   108  			slmv = mv.Interface()
   109  		}
   110  	} else {
   111  		if mv.Kind() != reflect.Array && mv.Kind() != reflect.Slice {
   112  			return false, nil
   113  		}
   114  
   115  		if mv.Len() == 0 {
   116  			return false, nil
   117  		}
   118  
   119  		if v.Kind() != reflect.Interface && mv.Type().Elem().Kind() != reflect.Interface && mv.Type().Elem() != v.Type() && v.Kind() != reflect.Array && v.Kind() != reflect.Slice {
   120  			return false, nil
   121  		}
   122  		switch v.Kind() {
   123  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   124  			iv := v.Int()
   125  			ivp = &iv
   126  			for i := 0; i < mv.Len(); i++ {
   127  				if anInt, err := toInt(mv.Index(i)); err == nil {
   128  					ima = append(ima, anInt)
   129  				}
   130  			}
   131  		case reflect.String:
   132  			sv := v.String()
   133  			svp = &sv
   134  			for i := 0; i < mv.Len(); i++ {
   135  				if aString, err := toString(mv.Index(i)); err == nil {
   136  					sma = append(sma, aString)
   137  				}
   138  			}
   139  		case reflect.Struct:
   140  			switch v.Type() {
   141  			case timeType:
   142  				iv := toTimeUnix(v)
   143  				ivp = &iv
   144  				for i := 0; i < mv.Len(); i++ {
   145  					ima = append(ima, toTimeUnix(mv.Index(i)))
   146  				}
   147  			}
   148  		case reflect.Array, reflect.Slice:
   149  			slv = v.Interface()
   150  			slmv = mv.Interface()
   151  		}
   152  	}
   153  
   154  	switch op {
   155  	case "", "=", "==", "eq":
   156  		if ivp != nil && imvp != nil {
   157  			return *ivp == *imvp, nil
   158  		} else if svp != nil && smvp != nil {
   159  			return *svp == *smvp, nil
   160  		}
   161  	case "!=", "<>", "ne":
   162  		if ivp != nil && imvp != nil {
   163  			return *ivp != *imvp, nil
   164  		} else if svp != nil && smvp != nil {
   165  			return *svp != *smvp, nil
   166  		}
   167  	case ">=", "ge":
   168  		if ivp != nil && imvp != nil {
   169  			return *ivp >= *imvp, nil
   170  		} else if svp != nil && smvp != nil {
   171  			return *svp >= *smvp, nil
   172  		}
   173  	case ">", "gt":
   174  		if ivp != nil && imvp != nil {
   175  			return *ivp > *imvp, nil
   176  		} else if svp != nil && smvp != nil {
   177  			return *svp > *smvp, nil
   178  		}
   179  	case "<=", "le":
   180  		if ivp != nil && imvp != nil {
   181  			return *ivp <= *imvp, nil
   182  		} else if svp != nil && smvp != nil {
   183  			return *svp <= *smvp, nil
   184  		}
   185  	case "<", "lt":
   186  		if ivp != nil && imvp != nil {
   187  			return *ivp < *imvp, nil
   188  		} else if svp != nil && smvp != nil {
   189  			return *svp < *smvp, nil
   190  		}
   191  	case "in", "not in":
   192  		var r bool
   193  		if ivp != nil && len(ima) > 0 {
   194  			r = ns.In(ima, *ivp)
   195  		} else if svp != nil {
   196  			if len(sma) > 0 {
   197  				r = ns.In(sma, *svp)
   198  			} else if smvp != nil {
   199  				r = ns.In(*smvp, *svp)
   200  			}
   201  		} else {
   202  			return false, nil
   203  		}
   204  		if op == "not in" {
   205  			return !r, nil
   206  		}
   207  		return r, nil
   208  	case "intersect":
   209  		r, err := ns.Intersect(slv, slmv)
   210  		if err != nil {
   211  			return false, err
   212  		}
   213  
   214  		if reflect.TypeOf(r).Kind() == reflect.Slice {
   215  			s := reflect.ValueOf(r)
   216  
   217  			if s.Len() > 0 {
   218  				return true, nil
   219  			}
   220  			return false, nil
   221  		}
   222  		return false, errors.New("invalid intersect values")
   223  	default:
   224  		return false, errors.New("no such operator")
   225  	}
   226  	return false, nil
   227  }
   228  
   229  func evaluateSubElem(obj reflect.Value, elemName string) (reflect.Value, error) {
   230  	if !obj.IsValid() {
   231  		return zero, errors.New("can't evaluate an invalid value")
   232  	}
   233  	typ := obj.Type()
   234  	obj, isNil := indirect(obj)
   235  
   236  	// first, check whether obj has a method. In this case, obj is
   237  	// an interface, a struct or its pointer. If obj is a struct,
   238  	// to check all T and *T method, use obj pointer type Value
   239  	objPtr := obj
   240  	if objPtr.Kind() != reflect.Interface && objPtr.CanAddr() {
   241  		objPtr = objPtr.Addr()
   242  	}
   243  	mt, ok := objPtr.Type().MethodByName(elemName)
   244  	if ok {
   245  		if mt.PkgPath != "" {
   246  			return zero, fmt.Errorf("%s is an unexported method of type %s", elemName, typ)
   247  		}
   248  		// struct pointer has one receiver argument and interface doesn't have an argument
   249  		if mt.Type.NumIn() > 1 || mt.Type.NumOut() == 0 || mt.Type.NumOut() > 2 {
   250  			return zero, fmt.Errorf("%s is a method of type %s but doesn't satisfy requirements", elemName, typ)
   251  		}
   252  		if mt.Type.NumOut() == 1 && mt.Type.Out(0).Implements(errorType) {
   253  			return zero, fmt.Errorf("%s is a method of type %s but doesn't satisfy requirements", elemName, typ)
   254  		}
   255  		if mt.Type.NumOut() == 2 && !mt.Type.Out(1).Implements(errorType) {
   256  			return zero, fmt.Errorf("%s is a method of type %s but doesn't satisfy requirements", elemName, typ)
   257  		}
   258  		res := objPtr.Method(mt.Index).Call([]reflect.Value{})
   259  		if len(res) == 2 && !res[1].IsNil() {
   260  			return zero, fmt.Errorf("error at calling a method %s of type %s: %s", elemName, typ, res[1].Interface().(error))
   261  		}
   262  		return res[0], nil
   263  	}
   264  
   265  	// elemName isn't a method so next start to check whether it is
   266  	// a struct field or a map value. In both cases, it mustn't be
   267  	// a nil value
   268  	if isNil {
   269  		return zero, fmt.Errorf("can't evaluate a nil pointer of type %s by a struct field or map key name %s", typ, elemName)
   270  	}
   271  	switch obj.Kind() {
   272  	case reflect.Struct:
   273  		ft, ok := obj.Type().FieldByName(elemName)
   274  		if ok {
   275  			if ft.PkgPath != "" && !ft.Anonymous {
   276  				return zero, fmt.Errorf("%s is an unexported field of struct type %s", elemName, typ)
   277  			}
   278  			return obj.FieldByIndex(ft.Index), nil
   279  		}
   280  		return zero, fmt.Errorf("%s isn't a field of struct type %s", elemName, typ)
   281  	case reflect.Map:
   282  		kv := reflect.ValueOf(elemName)
   283  		if kv.Type().AssignableTo(obj.Type().Key()) {
   284  			return obj.MapIndex(kv), nil
   285  		}
   286  		return zero, fmt.Errorf("%s isn't a key of map type %s", elemName, typ)
   287  	}
   288  	return zero, fmt.Errorf("%s is neither a struct field, a method nor a map element of type %s", elemName, typ)
   289  }
   290  
   291  // parseWhereArgs parses the end arguments to the where function.  Return a
   292  // match value and an operator, if one is defined.
   293  func parseWhereArgs(args ...interface{}) (mv reflect.Value, op string, err error) {
   294  	switch len(args) {
   295  	case 1:
   296  		mv = reflect.ValueOf(args[0])
   297  	case 2:
   298  		var ok bool
   299  		if op, ok = args[0].(string); !ok {
   300  			err = errors.New("operator argument must be string type")
   301  			return
   302  		}
   303  		op = strings.TrimSpace(strings.ToLower(op))
   304  		mv = reflect.ValueOf(args[1])
   305  	default:
   306  		err = errors.New("can't evaluate the array by no match argument or more than or equal to two arguments")
   307  	}
   308  	return
   309  }
   310  
   311  // checkWhereArray handles the where-matching logic when the seqv value is an
   312  // Array or Slice.
   313  func (ns *Namespace) checkWhereArray(seqv, kv, mv reflect.Value, path []string, op string) (interface{}, error) {
   314  	rv := reflect.MakeSlice(seqv.Type(), 0, 0)
   315  	for i := 0; i < seqv.Len(); i++ {
   316  		var vvv reflect.Value
   317  		rvv := seqv.Index(i)
   318  		if kv.Kind() == reflect.String {
   319  			vvv = rvv
   320  			for _, elemName := range path {
   321  				var err error
   322  				vvv, err = evaluateSubElem(vvv, elemName)
   323  				if err != nil {
   324  					return nil, err
   325  				}
   326  			}
   327  		} else {
   328  			vv, _ := indirect(rvv)
   329  			if vv.Kind() == reflect.Map && kv.Type().AssignableTo(vv.Type().Key()) {
   330  				vvv = vv.MapIndex(kv)
   331  			}
   332  		}
   333  
   334  		if ok, err := ns.checkCondition(vvv, mv, op); ok {
   335  			rv = reflect.Append(rv, rvv)
   336  		} else if err != nil {
   337  			return nil, err
   338  		}
   339  	}
   340  	return rv.Interface(), nil
   341  }
   342  
   343  // checkWhereMap handles the where-matching logic when the seqv value is a Map.
   344  func (ns *Namespace) checkWhereMap(seqv, kv, mv reflect.Value, path []string, op string) (interface{}, error) {
   345  	rv := reflect.MakeMap(seqv.Type())
   346  	keys := seqv.MapKeys()
   347  	for _, k := range keys {
   348  		elemv := seqv.MapIndex(k)
   349  		switch elemv.Kind() {
   350  		case reflect.Array, reflect.Slice:
   351  			r, err := ns.checkWhereArray(elemv, kv, mv, path, op)
   352  			if err != nil {
   353  				return nil, err
   354  			}
   355  
   356  			switch rr := reflect.ValueOf(r); rr.Kind() {
   357  			case reflect.Slice:
   358  				if rr.Len() > 0 {
   359  					rv.SetMapIndex(k, elemv)
   360  				}
   361  			}
   362  		case reflect.Interface:
   363  			elemvv, isNil := indirect(elemv)
   364  			if isNil {
   365  				continue
   366  			}
   367  
   368  			switch elemvv.Kind() {
   369  			case reflect.Array, reflect.Slice:
   370  				r, err := ns.checkWhereArray(elemvv, kv, mv, path, op)
   371  				if err != nil {
   372  					return nil, err
   373  				}
   374  
   375  				switch rr := reflect.ValueOf(r); rr.Kind() {
   376  				case reflect.Slice:
   377  					if rr.Len() > 0 {
   378  						rv.SetMapIndex(k, elemv)
   379  					}
   380  				}
   381  			}
   382  		}
   383  	}
   384  	return rv.Interface(), nil
   385  }
   386  
   387  // toFloat returns the float value if possible.
   388  func toFloat(v reflect.Value) (float64, error) {
   389  	switch v.Kind() {
   390  	case reflect.Float32, reflect.Float64:
   391  		return v.Float(), nil
   392  	case reflect.Interface:
   393  		return toFloat(v.Elem())
   394  	}
   395  	return -1, errors.New("unable to convert value to float")
   396  }
   397  
   398  // toInt returns the int value if possible, -1 if not.
   399  // TODO(bep) consolidate all these reflect funcs.
   400  func toInt(v reflect.Value) (int64, error) {
   401  	switch v.Kind() {
   402  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   403  		return v.Int(), nil
   404  	case reflect.Interface:
   405  		return toInt(v.Elem())
   406  	}
   407  	return -1, errors.New("unable to convert value to int")
   408  }
   409  
   410  func toUint(v reflect.Value) (uint64, error) {
   411  	switch v.Kind() {
   412  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   413  		return v.Uint(), nil
   414  	case reflect.Interface:
   415  		return toUint(v.Elem())
   416  	}
   417  	return 0, errors.New("unable to convert value to uint")
   418  }
   419  
   420  // toString returns the string value if possible, "" if not.
   421  func toString(v reflect.Value) (string, error) {
   422  	switch v.Kind() {
   423  	case reflect.String:
   424  		return v.String(), nil
   425  	case reflect.Interface:
   426  		return toString(v.Elem())
   427  	}
   428  	return "", errors.New("unable to convert value to string")
   429  }
   430  
   431  func toTimeUnix(v reflect.Value) int64 {
   432  	if v.Kind() == reflect.Interface {
   433  		return toTimeUnix(v.Elem())
   434  	}
   435  	if v.Type() != timeType {
   436  		panic("coding error: argument must be time.Time type reflect Value")
   437  	}
   438  	return v.MethodByName("Unix").Call([]reflect.Value{})[0].Int()
   439  }