github.com/corona10/go@v0.0.0-20180224231303-7a218942be57/src/text/template/funcs.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package template
     6  
     7  import (
     8  	"bytes"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"net/url"
    13  	"reflect"
    14  	"strings"
    15  	"unicode"
    16  	"unicode/utf8"
    17  )
    18  
    19  // FuncMap is the type of the map defining the mapping from names to functions.
    20  // Each function must have either a single return value, or two return values of
    21  // which the second has type error. In that case, if the second (error)
    22  // return value evaluates to non-nil during execution, execution terminates and
    23  // Execute returns that error.
    24  //
    25  // When template execution invokes a function with an argument list, that list
    26  // must be assignable to the function's parameter types. Functions meant to
    27  // apply to arguments of arbitrary type can use parameters of type interface{} or
    28  // of type reflect.Value. Similarly, functions meant to return a result of arbitrary
    29  // type can return interface{} or reflect.Value.
    30  type FuncMap map[string]interface{}
    31  
    32  var builtins = FuncMap{
    33  	"and":      and,
    34  	"call":     call,
    35  	"html":     HTMLEscaper,
    36  	"index":    index,
    37  	"js":       JSEscaper,
    38  	"len":      length,
    39  	"not":      not,
    40  	"or":       or,
    41  	"print":    fmt.Sprint,
    42  	"printf":   fmt.Sprintf,
    43  	"println":  fmt.Sprintln,
    44  	"urlquery": URLQueryEscaper,
    45  
    46  	// Comparisons
    47  	"eq": eq, // ==
    48  	"ge": ge, // >=
    49  	"gt": gt, // >
    50  	"le": le, // <=
    51  	"lt": lt, // <
    52  	"ne": ne, // !=
    53  }
    54  
    55  var builtinFuncs = createValueFuncs(builtins)
    56  
    57  // createValueFuncs turns a FuncMap into a map[string]reflect.Value
    58  func createValueFuncs(funcMap FuncMap) map[string]reflect.Value {
    59  	m := make(map[string]reflect.Value)
    60  	addValueFuncs(m, funcMap)
    61  	return m
    62  }
    63  
    64  // addValueFuncs adds to values the functions in funcs, converting them to reflect.Values.
    65  func addValueFuncs(out map[string]reflect.Value, in FuncMap) {
    66  	for name, fn := range in {
    67  		if !goodName(name) {
    68  			panic(fmt.Errorf("function name %s is not a valid identifier", name))
    69  		}
    70  		v := reflect.ValueOf(fn)
    71  		if v.Kind() != reflect.Func {
    72  			panic("value for " + name + " not a function")
    73  		}
    74  		if !goodFunc(v.Type()) {
    75  			panic(fmt.Errorf("can't install method/function %q with %d results", name, v.Type().NumOut()))
    76  		}
    77  		out[name] = v
    78  	}
    79  }
    80  
    81  // addFuncs adds to values the functions in funcs. It does no checking of the input -
    82  // call addValueFuncs first.
    83  func addFuncs(out, in FuncMap) {
    84  	for name, fn := range in {
    85  		out[name] = fn
    86  	}
    87  }
    88  
    89  // goodFunc reports whether the function or method has the right result signature.
    90  func goodFunc(typ reflect.Type) bool {
    91  	// We allow functions with 1 result or 2 results where the second is an error.
    92  	switch {
    93  	case typ.NumOut() == 1:
    94  		return true
    95  	case typ.NumOut() == 2 && typ.Out(1) == errorType:
    96  		return true
    97  	}
    98  	return false
    99  }
   100  
   101  // goodName reports whether the function name is a valid identifier.
   102  func goodName(name string) bool {
   103  	if name == "" {
   104  		return false
   105  	}
   106  	for i, r := range name {
   107  		switch {
   108  		case r == '_':
   109  		case i == 0 && !unicode.IsLetter(r):
   110  			return false
   111  		case !unicode.IsLetter(r) && !unicode.IsDigit(r):
   112  			return false
   113  		}
   114  	}
   115  	return true
   116  }
   117  
   118  // findFunction looks for a function in the template, and global map.
   119  func findFunction(name string, tmpl *Template) (reflect.Value, bool) {
   120  	if tmpl != nil && tmpl.common != nil {
   121  		tmpl.muFuncs.RLock()
   122  		defer tmpl.muFuncs.RUnlock()
   123  		if fn := tmpl.execFuncs[name]; fn.IsValid() {
   124  			return fn, true
   125  		}
   126  	}
   127  	if fn := builtinFuncs[name]; fn.IsValid() {
   128  		return fn, true
   129  	}
   130  	return reflect.Value{}, false
   131  }
   132  
   133  // prepareArg checks if value can be used as an argument of type argType, and
   134  // converts an invalid value to appropriate zero if possible.
   135  func prepareArg(value reflect.Value, argType reflect.Type) (reflect.Value, error) {
   136  	if !value.IsValid() {
   137  		if !canBeNil(argType) {
   138  			return reflect.Value{}, fmt.Errorf("value is nil; should be of type %s", argType)
   139  		}
   140  		value = reflect.Zero(argType)
   141  	}
   142  	if value.Type().AssignableTo(argType) {
   143  		return value, nil
   144  	}
   145  	if intLike(value.Kind()) && intLike(argType.Kind()) && value.Type().ConvertibleTo(argType) {
   146  		value = value.Convert(argType)
   147  		return value, nil
   148  	}
   149  	return reflect.Value{}, fmt.Errorf("value has type %s; should be %s", value.Type(), argType)
   150  }
   151  
   152  func intLike(typ reflect.Kind) bool {
   153  	switch typ {
   154  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   155  		return true
   156  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   157  		return true
   158  	}
   159  	return false
   160  }
   161  
   162  // Indexing.
   163  
   164  // index returns the result of indexing its first argument by the following
   165  // arguments. Thus "index x 1 2 3" is, in Go syntax, x[1][2][3]. Each
   166  // indexed item must be a map, slice, or array.
   167  func index(item reflect.Value, indices ...reflect.Value) (reflect.Value, error) {
   168  	v := indirectInterface(item)
   169  	if !v.IsValid() {
   170  		return reflect.Value{}, fmt.Errorf("index of untyped nil")
   171  	}
   172  	for _, i := range indices {
   173  		index := indirectInterface(i)
   174  		var isNil bool
   175  		if v, isNil = indirect(v); isNil {
   176  			return reflect.Value{}, fmt.Errorf("index of nil pointer")
   177  		}
   178  		switch v.Kind() {
   179  		case reflect.Array, reflect.Slice, reflect.String:
   180  			var x int64
   181  			switch index.Kind() {
   182  			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   183  				x = index.Int()
   184  			case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   185  				x = int64(index.Uint())
   186  			case reflect.Invalid:
   187  				return reflect.Value{}, fmt.Errorf("cannot index slice/array with nil")
   188  			default:
   189  				return reflect.Value{}, fmt.Errorf("cannot index slice/array with type %s", index.Type())
   190  			}
   191  			if x < 0 || x >= int64(v.Len()) {
   192  				return reflect.Value{}, fmt.Errorf("index out of range: %d", x)
   193  			}
   194  			v = v.Index(int(x))
   195  		case reflect.Map:
   196  			index, err := prepareArg(index, v.Type().Key())
   197  			if err != nil {
   198  				return reflect.Value{}, err
   199  			}
   200  			if x := v.MapIndex(index); x.IsValid() {
   201  				v = x
   202  			} else {
   203  				v = reflect.Zero(v.Type().Elem())
   204  			}
   205  		case reflect.Invalid:
   206  			// the loop holds invariant: v.IsValid()
   207  			panic("unreachable")
   208  		default:
   209  			return reflect.Value{}, fmt.Errorf("can't index item of type %s", v.Type())
   210  		}
   211  	}
   212  	return v, nil
   213  }
   214  
   215  // Length
   216  
   217  // length returns the length of the item, with an error if it has no defined length.
   218  func length(item interface{}) (int, error) {
   219  	v := reflect.ValueOf(item)
   220  	if !v.IsValid() {
   221  		return 0, fmt.Errorf("len of untyped nil")
   222  	}
   223  	v, isNil := indirect(v)
   224  	if isNil {
   225  		return 0, fmt.Errorf("len of nil pointer")
   226  	}
   227  	switch v.Kind() {
   228  	case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String:
   229  		return v.Len(), nil
   230  	}
   231  	return 0, fmt.Errorf("len of type %s", v.Type())
   232  }
   233  
   234  // Function invocation
   235  
   236  // call returns the result of evaluating the first argument as a function.
   237  // The function must return 1 result, or 2 results, the second of which is an error.
   238  func call(fn reflect.Value, args ...reflect.Value) (reflect.Value, error) {
   239  	v := indirectInterface(fn)
   240  	if !v.IsValid() {
   241  		return reflect.Value{}, fmt.Errorf("call of nil")
   242  	}
   243  	typ := v.Type()
   244  	if typ.Kind() != reflect.Func {
   245  		return reflect.Value{}, fmt.Errorf("non-function of type %s", typ)
   246  	}
   247  	if !goodFunc(typ) {
   248  		return reflect.Value{}, fmt.Errorf("function called with %d args; should be 1 or 2", typ.NumOut())
   249  	}
   250  	numIn := typ.NumIn()
   251  	var dddType reflect.Type
   252  	if typ.IsVariadic() {
   253  		if len(args) < numIn-1 {
   254  			return reflect.Value{}, fmt.Errorf("wrong number of args: got %d want at least %d", len(args), numIn-1)
   255  		}
   256  		dddType = typ.In(numIn - 1).Elem()
   257  	} else {
   258  		if len(args) != numIn {
   259  			return reflect.Value{}, fmt.Errorf("wrong number of args: got %d want %d", len(args), numIn)
   260  		}
   261  	}
   262  	argv := make([]reflect.Value, len(args))
   263  	for i, arg := range args {
   264  		value := indirectInterface(arg)
   265  		// Compute the expected type. Clumsy because of variadics.
   266  		var argType reflect.Type
   267  		if !typ.IsVariadic() || i < numIn-1 {
   268  			argType = typ.In(i)
   269  		} else {
   270  			argType = dddType
   271  		}
   272  
   273  		var err error
   274  		if argv[i], err = prepareArg(value, argType); err != nil {
   275  			return reflect.Value{}, fmt.Errorf("arg %d: %s", i, err)
   276  		}
   277  	}
   278  	result := v.Call(argv)
   279  	if len(result) == 2 && !result[1].IsNil() {
   280  		return result[0], result[1].Interface().(error)
   281  	}
   282  	return result[0], nil
   283  }
   284  
   285  // Boolean logic.
   286  
   287  func truth(arg reflect.Value) bool {
   288  	t, _ := isTrue(indirectInterface(arg))
   289  	return t
   290  }
   291  
   292  // and computes the Boolean AND of its arguments, returning
   293  // the first false argument it encounters, or the last argument.
   294  func and(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
   295  	if !truth(arg0) {
   296  		return arg0
   297  	}
   298  	for i := range args {
   299  		arg0 = args[i]
   300  		if !truth(arg0) {
   301  			break
   302  		}
   303  	}
   304  	return arg0
   305  }
   306  
   307  // or computes the Boolean OR of its arguments, returning
   308  // the first true argument it encounters, or the last argument.
   309  func or(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
   310  	if truth(arg0) {
   311  		return arg0
   312  	}
   313  	for i := range args {
   314  		arg0 = args[i]
   315  		if truth(arg0) {
   316  			break
   317  		}
   318  	}
   319  	return arg0
   320  }
   321  
   322  // not returns the Boolean negation of its argument.
   323  func not(arg reflect.Value) bool {
   324  	return !truth(arg)
   325  }
   326  
   327  // Comparison.
   328  
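// Signed and unsigned integers may be compared with each other: eq and lt
// special-case mixed intKind/uintKind operands instead of reporting
// errBadComparison, and le, gt and ge are built on eq and lt.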
   330  
   331  var (
   332  	errBadComparisonType = errors.New("invalid type for comparison")
   333  	errBadComparison     = errors.New("incompatible types for comparison")
   334  	errNoComparison      = errors.New("missing argument for comparison")
   335  )
   336  
   337  type kind int
   338  
   339  const (
   340  	invalidKind kind = iota
   341  	boolKind
   342  	complexKind
   343  	intKind
   344  	floatKind
   345  	stringKind
   346  	uintKind
   347  )
   348  
   349  func basicKind(v reflect.Value) (kind, error) {
   350  	switch v.Kind() {
   351  	case reflect.Bool:
   352  		return boolKind, nil
   353  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   354  		return intKind, nil
   355  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   356  		return uintKind, nil
   357  	case reflect.Float32, reflect.Float64:
   358  		return floatKind, nil
   359  	case reflect.Complex64, reflect.Complex128:
   360  		return complexKind, nil
   361  	case reflect.String:
   362  		return stringKind, nil
   363  	}
   364  	return invalidKind, errBadComparisonType
   365  }
   366  
   367  // eq evaluates the comparison a == b || a == c || ...
   368  func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) {
   369  	v1 := indirectInterface(arg1)
   370  	k1, err := basicKind(v1)
   371  	if err != nil {
   372  		return false, err
   373  	}
   374  	if len(arg2) == 0 {
   375  		return false, errNoComparison
   376  	}
   377  	for _, arg := range arg2 {
   378  		v2 := indirectInterface(arg)
   379  		k2, err := basicKind(v2)
   380  		if err != nil {
   381  			return false, err
   382  		}
   383  		truth := false
   384  		if k1 != k2 {
   385  			// Special case: Can compare integer values regardless of type's sign.
   386  			switch {
   387  			case k1 == intKind && k2 == uintKind:
   388  				truth = v1.Int() >= 0 && uint64(v1.Int()) == v2.Uint()
   389  			case k1 == uintKind && k2 == intKind:
   390  				truth = v2.Int() >= 0 && v1.Uint() == uint64(v2.Int())
   391  			default:
   392  				return false, errBadComparison
   393  			}
   394  		} else {
   395  			switch k1 {
   396  			case boolKind:
   397  				truth = v1.Bool() == v2.Bool()
   398  			case complexKind:
   399  				truth = v1.Complex() == v2.Complex()
   400  			case floatKind:
   401  				truth = v1.Float() == v2.Float()
   402  			case intKind:
   403  				truth = v1.Int() == v2.Int()
   404  			case stringKind:
   405  				truth = v1.String() == v2.String()
   406  			case uintKind:
   407  				truth = v1.Uint() == v2.Uint()
   408  			default:
   409  				panic("invalid kind")
   410  			}
   411  		}
   412  		if truth {
   413  			return true, nil
   414  		}
   415  	}
   416  	return false, nil
   417  }
   418  
   419  // ne evaluates the comparison a != b.
   420  func ne(arg1, arg2 reflect.Value) (bool, error) {
   421  	// != is the inverse of ==.
   422  	equal, err := eq(arg1, arg2)
   423  	return !equal, err
   424  }
   425  
   426  // lt evaluates the comparison a < b.
   427  func lt(arg1, arg2 reflect.Value) (bool, error) {
   428  	v1 := indirectInterface(arg1)
   429  	k1, err := basicKind(v1)
   430  	if err != nil {
   431  		return false, err
   432  	}
   433  	v2 := indirectInterface(arg2)
   434  	k2, err := basicKind(v2)
   435  	if err != nil {
   436  		return false, err
   437  	}
   438  	truth := false
   439  	if k1 != k2 {
   440  		// Special case: Can compare integer values regardless of type's sign.
   441  		switch {
   442  		case k1 == intKind && k2 == uintKind:
   443  			truth = v1.Int() < 0 || uint64(v1.Int()) < v2.Uint()
   444  		case k1 == uintKind && k2 == intKind:
   445  			truth = v2.Int() >= 0 && v1.Uint() < uint64(v2.Int())
   446  		default:
   447  			return false, errBadComparison
   448  		}
   449  	} else {
   450  		switch k1 {
   451  		case boolKind, complexKind:
   452  			return false, errBadComparisonType
   453  		case floatKind:
   454  			truth = v1.Float() < v2.Float()
   455  		case intKind:
   456  			truth = v1.Int() < v2.Int()
   457  		case stringKind:
   458  			truth = v1.String() < v2.String()
   459  		case uintKind:
   460  			truth = v1.Uint() < v2.Uint()
   461  		default:
   462  			panic("invalid kind")
   463  		}
   464  	}
   465  	return truth, nil
   466  }
   467  
   468  // le evaluates the comparison <= b.
   469  func le(arg1, arg2 reflect.Value) (bool, error) {
   470  	// <= is < or ==.
   471  	lessThan, err := lt(arg1, arg2)
   472  	if lessThan || err != nil {
   473  		return lessThan, err
   474  	}
   475  	return eq(arg1, arg2)
   476  }
   477  
   478  // gt evaluates the comparison a > b.
   479  func gt(arg1, arg2 reflect.Value) (bool, error) {
   480  	// > is the inverse of <=.
   481  	lessOrEqual, err := le(arg1, arg2)
   482  	if err != nil {
   483  		return false, err
   484  	}
   485  	return !lessOrEqual, nil
   486  }
   487  
   488  // ge evaluates the comparison a >= b.
   489  func ge(arg1, arg2 reflect.Value) (bool, error) {
   490  	// >= is the inverse of <.
   491  	lessThan, err := lt(arg1, arg2)
   492  	if err != nil {
   493  		return false, err
   494  	}
   495  	return !lessThan, nil
   496  }
   497  
   498  // HTML escaping.
   499  
   500  var (
   501  	htmlQuot = []byte("&#34;") // shorter than "&quot;"
   502  	htmlApos = []byte("&#39;") // shorter than "&apos;" and apos was not in HTML until HTML5
   503  	htmlAmp  = []byte("&amp;")
   504  	htmlLt   = []byte("&lt;")
   505  	htmlGt   = []byte("&gt;")
   506  	htmlNull = []byte("\uFFFD")
   507  )
   508  
   509  // HTMLEscape writes to w the escaped HTML equivalent of the plain text data b.
   510  func HTMLEscape(w io.Writer, b []byte) {
   511  	last := 0
   512  	for i, c := range b {
   513  		var html []byte
   514  		switch c {
   515  		case '\000':
   516  			html = htmlNull
   517  		case '"':
   518  			html = htmlQuot
   519  		case '\'':
   520  			html = htmlApos
   521  		case '&':
   522  			html = htmlAmp
   523  		case '<':
   524  			html = htmlLt
   525  		case '>':
   526  			html = htmlGt
   527  		default:
   528  			continue
   529  		}
   530  		w.Write(b[last:i])
   531  		w.Write(html)
   532  		last = i + 1
   533  	}
   534  	w.Write(b[last:])
   535  }
   536  
   537  // HTMLEscapeString returns the escaped HTML equivalent of the plain text data s.
   538  func HTMLEscapeString(s string) string {
   539  	// Avoid allocation if we can.
   540  	if !strings.ContainsAny(s, "'\"&<>\000") {
   541  		return s
   542  	}
   543  	var b bytes.Buffer
   544  	HTMLEscape(&b, []byte(s))
   545  	return b.String()
   546  }
   547  
   548  // HTMLEscaper returns the escaped HTML equivalent of the textual
   549  // representation of its arguments.
   550  func HTMLEscaper(args ...interface{}) string {
   551  	return HTMLEscapeString(evalArgs(args))
   552  }
   553  
   554  // JavaScript escaping.
   555  
   556  var (
   557  	jsLowUni = []byte(`\u00`)
   558  	hex      = []byte("0123456789ABCDEF")
   559  
   560  	jsBackslash = []byte(`\\`)
   561  	jsApos      = []byte(`\'`)
   562  	jsQuot      = []byte(`\"`)
   563  	jsLt        = []byte(`\x3C`)
   564  	jsGt        = []byte(`\x3E`)
   565  )
   566  
   567  // JSEscape writes to w the escaped JavaScript equivalent of the plain text data b.
   568  func JSEscape(w io.Writer, b []byte) {
   569  	last := 0
   570  	for i := 0; i < len(b); i++ {
   571  		c := b[i]
   572  
   573  		if !jsIsSpecial(rune(c)) {
   574  			// fast path: nothing to do
   575  			continue
   576  		}
   577  		w.Write(b[last:i])
   578  
   579  		if c < utf8.RuneSelf {
   580  			// Quotes, slashes and angle brackets get quoted.
   581  			// Control characters get written as \u00XX.
   582  			switch c {
   583  			case '\\':
   584  				w.Write(jsBackslash)
   585  			case '\'':
   586  				w.Write(jsApos)
   587  			case '"':
   588  				w.Write(jsQuot)
   589  			case '<':
   590  				w.Write(jsLt)
   591  			case '>':
   592  				w.Write(jsGt)
   593  			default:
   594  				w.Write(jsLowUni)
   595  				t, b := c>>4, c&0x0f
   596  				w.Write(hex[t : t+1])
   597  				w.Write(hex[b : b+1])
   598  			}
   599  		} else {
   600  			// Unicode rune.
   601  			r, size := utf8.DecodeRune(b[i:])
   602  			if unicode.IsPrint(r) {
   603  				w.Write(b[i : i+size])
   604  			} else {
   605  				fmt.Fprintf(w, "\\u%04X", r)
   606  			}
   607  			i += size - 1
   608  		}
   609  		last = i + 1
   610  	}
   611  	w.Write(b[last:])
   612  }
   613  
   614  // JSEscapeString returns the escaped JavaScript equivalent of the plain text data s.
   615  func JSEscapeString(s string) string {
   616  	// Avoid allocation if we can.
   617  	if strings.IndexFunc(s, jsIsSpecial) < 0 {
   618  		return s
   619  	}
   620  	var b bytes.Buffer
   621  	JSEscape(&b, []byte(s))
   622  	return b.String()
   623  }
   624  
   625  func jsIsSpecial(r rune) bool {
   626  	switch r {
   627  	case '\\', '\'', '"', '<', '>':
   628  		return true
   629  	}
   630  	return r < ' ' || utf8.RuneSelf <= r
   631  }
   632  
   633  // JSEscaper returns the escaped JavaScript equivalent of the textual
   634  // representation of its arguments.
   635  func JSEscaper(args ...interface{}) string {
   636  	return JSEscapeString(evalArgs(args))
   637  }
   638  
   639  // URLQueryEscaper returns the escaped value of the textual representation of
   640  // its arguments in a form suitable for embedding in a URL query.
   641  func URLQueryEscaper(args ...interface{}) string {
   642  	return url.QueryEscape(evalArgs(args))
   643  }
   644  
   645  // evalArgs formats the list of arguments into a string. It is therefore equivalent to
   646  //	fmt.Sprint(args...)
   647  // except that each argument is indirected (if a pointer), as required,
   648  // using the same rules as the default string evaluation during template
   649  // execution.
   650  func evalArgs(args []interface{}) string {
   651  	ok := false
   652  	var s string
   653  	// Fast path for simple common case.
   654  	if len(args) == 1 {
   655  		s, ok = args[0].(string)
   656  	}
   657  	if !ok {
   658  		for i, arg := range args {
   659  			a, ok := printableValue(reflect.ValueOf(arg))
   660  			if ok {
   661  				args[i] = a
   662  			} // else let fmt do its thing
   663  		}
   664  		s = fmt.Sprint(args...)
   665  	}
   666  	return s
   667  }