github.com/huandu/go@v0.0.0-20151114150818-04e615e41150/src/text/template/funcs.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package template
     6  
     7  import (
     8  	"bytes"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"net/url"
    13  	"reflect"
    14  	"strings"
    15  	"unicode"
    16  	"unicode/utf8"
    17  )
    18  
    19  // FuncMap is the type of the map defining the mapping from names to functions.
    20  // Each function must have either a single return value, or two return values of
    21  // which the second has type error. In that case, if the second (error)
    22  // return value evaluates to non-nil during execution, execution terminates and
    23  // Execute returns that error.
    24  type FuncMap map[string]interface{}
    25  
    26  var builtins = FuncMap{
    27  	"and":      and,
    28  	"call":     call,
    29  	"html":     HTMLEscaper,
    30  	"index":    index,
    31  	"js":       JSEscaper,
    32  	"len":      length,
    33  	"not":      not,
    34  	"or":       or,
    35  	"print":    fmt.Sprint,
    36  	"printf":   fmt.Sprintf,
    37  	"println":  fmt.Sprintln,
    38  	"urlquery": URLQueryEscaper,
    39  
    40  	// Comparisons
    41  	"eq": eq, // ==
    42  	"ge": ge, // >=
    43  	"gt": gt, // >
    44  	"le": le, // <=
    45  	"lt": lt, // <
    46  	"ne": ne, // !=
    47  }
    48  
    49  var builtinFuncs = createValueFuncs(builtins)
    50  
    51  // createValueFuncs turns a FuncMap into a map[string]reflect.Value
    52  func createValueFuncs(funcMap FuncMap) map[string]reflect.Value {
    53  	m := make(map[string]reflect.Value)
    54  	addValueFuncs(m, funcMap)
    55  	return m
    56  }
    57  
    58  // addValueFuncs adds to values the functions in funcs, converting them to reflect.Values.
    59  func addValueFuncs(out map[string]reflect.Value, in FuncMap) {
    60  	for name, fn := range in {
    61  		v := reflect.ValueOf(fn)
    62  		if v.Kind() != reflect.Func {
    63  			panic("value for " + name + " not a function")
    64  		}
    65  		if !goodFunc(v.Type()) {
    66  			panic(fmt.Errorf("can't install method/function %q with %d results", name, v.Type().NumOut()))
    67  		}
    68  		out[name] = v
    69  	}
    70  }
    71  
    72  // addFuncs adds to values the functions in funcs. It does no checking of the input -
    73  // call addValueFuncs first.
    74  func addFuncs(out, in FuncMap) {
    75  	for name, fn := range in {
    76  		out[name] = fn
    77  	}
    78  }
    79  
    80  // goodFunc checks that the function or method has the right result signature.
    81  func goodFunc(typ reflect.Type) bool {
    82  	// We allow functions with 1 result or 2 results where the second is an error.
    83  	switch {
    84  	case typ.NumOut() == 1:
    85  		return true
    86  	case typ.NumOut() == 2 && typ.Out(1) == errorType:
    87  		return true
    88  	}
    89  	return false
    90  }
    91  
    92  // findFunction looks for a function in the template, and global map.
    93  func findFunction(name string, tmpl *Template) (reflect.Value, bool) {
    94  	if tmpl != nil && tmpl.common != nil {
    95  		tmpl.muFuncs.RLock()
    96  		defer tmpl.muFuncs.RUnlock()
    97  		if fn := tmpl.execFuncs[name]; fn.IsValid() {
    98  			return fn, true
    99  		}
   100  	}
   101  	if fn := builtinFuncs[name]; fn.IsValid() {
   102  		return fn, true
   103  	}
   104  	return reflect.Value{}, false
   105  }
   106  
   107  // Indexing.
   108  
   109  // index returns the result of indexing its first argument by the following
   110  // arguments.  Thus "index x 1 2 3" is, in Go syntax, x[1][2][3]. Each
   111  // indexed item must be a map, slice, or array.
   112  func index(item interface{}, indices ...interface{}) (interface{}, error) {
   113  	v := reflect.ValueOf(item)
   114  	for _, i := range indices {
   115  		index := reflect.ValueOf(i)
   116  		var isNil bool
   117  		if v, isNil = indirect(v); isNil {
   118  			return nil, fmt.Errorf("index of nil pointer")
   119  		}
   120  		switch v.Kind() {
   121  		case reflect.Array, reflect.Slice, reflect.String:
   122  			var x int64
   123  			switch index.Kind() {
   124  			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   125  				x = index.Int()
   126  			case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   127  				x = int64(index.Uint())
   128  			default:
   129  				return nil, fmt.Errorf("cannot index slice/array with type %s", index.Type())
   130  			}
   131  			if x < 0 || x >= int64(v.Len()) {
   132  				return nil, fmt.Errorf("index out of range: %d", x)
   133  			}
   134  			v = v.Index(int(x))
   135  		case reflect.Map:
   136  			if !index.IsValid() {
   137  				index = reflect.Zero(v.Type().Key())
   138  			}
   139  			if !index.Type().AssignableTo(v.Type().Key()) {
   140  				return nil, fmt.Errorf("%s is not index type for %s", index.Type(), v.Type())
   141  			}
   142  			if x := v.MapIndex(index); x.IsValid() {
   143  				v = x
   144  			} else {
   145  				v = reflect.Zero(v.Type().Elem())
   146  			}
   147  		default:
   148  			return nil, fmt.Errorf("can't index item of type %s", v.Type())
   149  		}
   150  	}
   151  	return v.Interface(), nil
   152  }
   153  
   154  // Length
   155  
   156  // length returns the length of the item, with an error if it has no defined length.
   157  func length(item interface{}) (int, error) {
   158  	v, isNil := indirect(reflect.ValueOf(item))
   159  	if isNil {
   160  		return 0, fmt.Errorf("len of nil pointer")
   161  	}
   162  	switch v.Kind() {
   163  	case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String:
   164  		return v.Len(), nil
   165  	}
   166  	return 0, fmt.Errorf("len of type %s", v.Type())
   167  }
   168  
   169  // Function invocation
   170  
   171  // call returns the result of evaluating the first argument as a function.
   172  // The function must return 1 result, or 2 results, the second of which is an error.
   173  func call(fn interface{}, args ...interface{}) (interface{}, error) {
   174  	v := reflect.ValueOf(fn)
   175  	typ := v.Type()
   176  	if typ.Kind() != reflect.Func {
   177  		return nil, fmt.Errorf("non-function of type %s", typ)
   178  	}
   179  	if !goodFunc(typ) {
   180  		return nil, fmt.Errorf("function called with %d args; should be 1 or 2", typ.NumOut())
   181  	}
   182  	numIn := typ.NumIn()
   183  	var dddType reflect.Type
   184  	if typ.IsVariadic() {
   185  		if len(args) < numIn-1 {
   186  			return nil, fmt.Errorf("wrong number of args: got %d want at least %d", len(args), numIn-1)
   187  		}
   188  		dddType = typ.In(numIn - 1).Elem()
   189  	} else {
   190  		if len(args) != numIn {
   191  			return nil, fmt.Errorf("wrong number of args: got %d want %d", len(args), numIn)
   192  		}
   193  	}
   194  	argv := make([]reflect.Value, len(args))
   195  	for i, arg := range args {
   196  		value := reflect.ValueOf(arg)
   197  		// Compute the expected type. Clumsy because of variadics.
   198  		var argType reflect.Type
   199  		if !typ.IsVariadic() || i < numIn-1 {
   200  			argType = typ.In(i)
   201  		} else {
   202  			argType = dddType
   203  		}
   204  		if !value.IsValid() && canBeNil(argType) {
   205  			value = reflect.Zero(argType)
   206  		}
   207  		if !value.Type().AssignableTo(argType) {
   208  			return nil, fmt.Errorf("arg %d has type %s; should be %s", i, value.Type(), argType)
   209  		}
   210  		argv[i] = value
   211  	}
   212  	result := v.Call(argv)
   213  	if len(result) == 2 && !result[1].IsNil() {
   214  		return result[0].Interface(), result[1].Interface().(error)
   215  	}
   216  	return result[0].Interface(), nil
   217  }
   218  
   219  // Boolean logic.
   220  
   221  func truth(a interface{}) bool {
   222  	t, _ := isTrue(reflect.ValueOf(a))
   223  	return t
   224  }
   225  
   226  // and computes the Boolean AND of its arguments, returning
   227  // the first false argument it encounters, or the last argument.
   228  func and(arg0 interface{}, args ...interface{}) interface{} {
   229  	if !truth(arg0) {
   230  		return arg0
   231  	}
   232  	for i := range args {
   233  		arg0 = args[i]
   234  		if !truth(arg0) {
   235  			break
   236  		}
   237  	}
   238  	return arg0
   239  }
   240  
   241  // or computes the Boolean OR of its arguments, returning
   242  // the first true argument it encounters, or the last argument.
   243  func or(arg0 interface{}, args ...interface{}) interface{} {
   244  	if truth(arg0) {
   245  		return arg0
   246  	}
   247  	for i := range args {
   248  		arg0 = args[i]
   249  		if truth(arg0) {
   250  			break
   251  		}
   252  	}
   253  	return arg0
   254  }
   255  
   256  // not returns the Boolean negation of its argument.
   257  func not(arg interface{}) (truth bool) {
   258  	truth, _ = isTrue(reflect.ValueOf(arg))
   259  	return !truth
   260  }
   261  
   262  // Comparison.
   263  
   264  // TODO: Perhaps allow comparison between signed and unsigned integers.
   265  
   266  var (
   267  	errBadComparisonType = errors.New("invalid type for comparison")
   268  	errBadComparison     = errors.New("incompatible types for comparison")
   269  	errNoComparison      = errors.New("missing argument for comparison")
   270  )
   271  
   272  type kind int
   273  
   274  const (
   275  	invalidKind kind = iota
   276  	boolKind
   277  	complexKind
   278  	intKind
   279  	floatKind
   280  	integerKind
   281  	stringKind
   282  	uintKind
   283  )
   284  
   285  func basicKind(v reflect.Value) (kind, error) {
   286  	switch v.Kind() {
   287  	case reflect.Bool:
   288  		return boolKind, nil
   289  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   290  		return intKind, nil
   291  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   292  		return uintKind, nil
   293  	case reflect.Float32, reflect.Float64:
   294  		return floatKind, nil
   295  	case reflect.Complex64, reflect.Complex128:
   296  		return complexKind, nil
   297  	case reflect.String:
   298  		return stringKind, nil
   299  	}
   300  	return invalidKind, errBadComparisonType
   301  }
   302  
   303  // eq evaluates the comparison a == b || a == c || ...
   304  func eq(arg1 interface{}, arg2 ...interface{}) (bool, error) {
   305  	v1 := reflect.ValueOf(arg1)
   306  	k1, err := basicKind(v1)
   307  	if err != nil {
   308  		return false, err
   309  	}
   310  	if len(arg2) == 0 {
   311  		return false, errNoComparison
   312  	}
   313  	for _, arg := range arg2 {
   314  		v2 := reflect.ValueOf(arg)
   315  		k2, err := basicKind(v2)
   316  		if err != nil {
   317  			return false, err
   318  		}
   319  		truth := false
   320  		if k1 != k2 {
   321  			// Special case: Can compare integer values regardless of type's sign.
   322  			switch {
   323  			case k1 == intKind && k2 == uintKind:
   324  				truth = v1.Int() >= 0 && uint64(v1.Int()) == v2.Uint()
   325  			case k1 == uintKind && k2 == intKind:
   326  				truth = v2.Int() >= 0 && v1.Uint() == uint64(v2.Int())
   327  			default:
   328  				return false, errBadComparison
   329  			}
   330  		} else {
   331  			switch k1 {
   332  			case boolKind:
   333  				truth = v1.Bool() == v2.Bool()
   334  			case complexKind:
   335  				truth = v1.Complex() == v2.Complex()
   336  			case floatKind:
   337  				truth = v1.Float() == v2.Float()
   338  			case intKind:
   339  				truth = v1.Int() == v2.Int()
   340  			case stringKind:
   341  				truth = v1.String() == v2.String()
   342  			case uintKind:
   343  				truth = v1.Uint() == v2.Uint()
   344  			default:
   345  				panic("invalid kind")
   346  			}
   347  		}
   348  		if truth {
   349  			return true, nil
   350  		}
   351  	}
   352  	return false, nil
   353  }
   354  
   355  // ne evaluates the comparison a != b.
   356  func ne(arg1, arg2 interface{}) (bool, error) {
   357  	// != is the inverse of ==.
   358  	equal, err := eq(arg1, arg2)
   359  	return !equal, err
   360  }
   361  
   362  // lt evaluates the comparison a < b.
   363  func lt(arg1, arg2 interface{}) (bool, error) {
   364  	v1 := reflect.ValueOf(arg1)
   365  	k1, err := basicKind(v1)
   366  	if err != nil {
   367  		return false, err
   368  	}
   369  	v2 := reflect.ValueOf(arg2)
   370  	k2, err := basicKind(v2)
   371  	if err != nil {
   372  		return false, err
   373  	}
   374  	truth := false
   375  	if k1 != k2 {
   376  		// Special case: Can compare integer values regardless of type's sign.
   377  		switch {
   378  		case k1 == intKind && k2 == uintKind:
   379  			truth = v1.Int() < 0 || uint64(v1.Int()) < v2.Uint()
   380  		case k1 == uintKind && k2 == intKind:
   381  			truth = v2.Int() >= 0 && v1.Uint() < uint64(v2.Int())
   382  		default:
   383  			return false, errBadComparison
   384  		}
   385  	} else {
   386  		switch k1 {
   387  		case boolKind, complexKind:
   388  			return false, errBadComparisonType
   389  		case floatKind:
   390  			truth = v1.Float() < v2.Float()
   391  		case intKind:
   392  			truth = v1.Int() < v2.Int()
   393  		case stringKind:
   394  			truth = v1.String() < v2.String()
   395  		case uintKind:
   396  			truth = v1.Uint() < v2.Uint()
   397  		default:
   398  			panic("invalid kind")
   399  		}
   400  	}
   401  	return truth, nil
   402  }
   403  
   404  // le evaluates the comparison <= b.
   405  func le(arg1, arg2 interface{}) (bool, error) {
   406  	// <= is < or ==.
   407  	lessThan, err := lt(arg1, arg2)
   408  	if lessThan || err != nil {
   409  		return lessThan, err
   410  	}
   411  	return eq(arg1, arg2)
   412  }
   413  
   414  // gt evaluates the comparison a > b.
   415  func gt(arg1, arg2 interface{}) (bool, error) {
   416  	// > is the inverse of <=.
   417  	lessOrEqual, err := le(arg1, arg2)
   418  	if err != nil {
   419  		return false, err
   420  	}
   421  	return !lessOrEqual, nil
   422  }
   423  
   424  // ge evaluates the comparison a >= b.
   425  func ge(arg1, arg2 interface{}) (bool, error) {
   426  	// >= is the inverse of <.
   427  	lessThan, err := lt(arg1, arg2)
   428  	if err != nil {
   429  		return false, err
   430  	}
   431  	return !lessThan, nil
   432  }
   433  
   434  // HTML escaping.
   435  
   436  var (
   437  	htmlQuot = []byte("&#34;") // shorter than "&quot;"
   438  	htmlApos = []byte("&#39;") // shorter than "&apos;" and apos was not in HTML until HTML5
   439  	htmlAmp  = []byte("&amp;")
   440  	htmlLt   = []byte("&lt;")
   441  	htmlGt   = []byte("&gt;")
   442  )
   443  
   444  // HTMLEscape writes to w the escaped HTML equivalent of the plain text data b.
   445  func HTMLEscape(w io.Writer, b []byte) {
   446  	last := 0
   447  	for i, c := range b {
   448  		var html []byte
   449  		switch c {
   450  		case '"':
   451  			html = htmlQuot
   452  		case '\'':
   453  			html = htmlApos
   454  		case '&':
   455  			html = htmlAmp
   456  		case '<':
   457  			html = htmlLt
   458  		case '>':
   459  			html = htmlGt
   460  		default:
   461  			continue
   462  		}
   463  		w.Write(b[last:i])
   464  		w.Write(html)
   465  		last = i + 1
   466  	}
   467  	w.Write(b[last:])
   468  }
   469  
   470  // HTMLEscapeString returns the escaped HTML equivalent of the plain text data s.
   471  func HTMLEscapeString(s string) string {
   472  	// Avoid allocation if we can.
   473  	if strings.IndexAny(s, `'"&<>`) < 0 {
   474  		return s
   475  	}
   476  	var b bytes.Buffer
   477  	HTMLEscape(&b, []byte(s))
   478  	return b.String()
   479  }
   480  
   481  // HTMLEscaper returns the escaped HTML equivalent of the textual
   482  // representation of its arguments.
   483  func HTMLEscaper(args ...interface{}) string {
   484  	return HTMLEscapeString(evalArgs(args))
   485  }
   486  
   487  // JavaScript escaping.
   488  
   489  var (
   490  	jsLowUni = []byte(`\u00`)
   491  	hex      = []byte("0123456789ABCDEF")
   492  
   493  	jsBackslash = []byte(`\\`)
   494  	jsApos      = []byte(`\'`)
   495  	jsQuot      = []byte(`\"`)
   496  	jsLt        = []byte(`\x3C`)
   497  	jsGt        = []byte(`\x3E`)
   498  )
   499  
   500  // JSEscape writes to w the escaped JavaScript equivalent of the plain text data b.
   501  func JSEscape(w io.Writer, b []byte) {
   502  	last := 0
   503  	for i := 0; i < len(b); i++ {
   504  		c := b[i]
   505  
   506  		if !jsIsSpecial(rune(c)) {
   507  			// fast path: nothing to do
   508  			continue
   509  		}
   510  		w.Write(b[last:i])
   511  
   512  		if c < utf8.RuneSelf {
   513  			// Quotes, slashes and angle brackets get quoted.
   514  			// Control characters get written as \u00XX.
   515  			switch c {
   516  			case '\\':
   517  				w.Write(jsBackslash)
   518  			case '\'':
   519  				w.Write(jsApos)
   520  			case '"':
   521  				w.Write(jsQuot)
   522  			case '<':
   523  				w.Write(jsLt)
   524  			case '>':
   525  				w.Write(jsGt)
   526  			default:
   527  				w.Write(jsLowUni)
   528  				t, b := c>>4, c&0x0f
   529  				w.Write(hex[t : t+1])
   530  				w.Write(hex[b : b+1])
   531  			}
   532  		} else {
   533  			// Unicode rune.
   534  			r, size := utf8.DecodeRune(b[i:])
   535  			if unicode.IsPrint(r) {
   536  				w.Write(b[i : i+size])
   537  			} else {
   538  				fmt.Fprintf(w, "\\u%04X", r)
   539  			}
   540  			i += size - 1
   541  		}
   542  		last = i + 1
   543  	}
   544  	w.Write(b[last:])
   545  }
   546  
   547  // JSEscapeString returns the escaped JavaScript equivalent of the plain text data s.
   548  func JSEscapeString(s string) string {
   549  	// Avoid allocation if we can.
   550  	if strings.IndexFunc(s, jsIsSpecial) < 0 {
   551  		return s
   552  	}
   553  	var b bytes.Buffer
   554  	JSEscape(&b, []byte(s))
   555  	return b.String()
   556  }
   557  
   558  func jsIsSpecial(r rune) bool {
   559  	switch r {
   560  	case '\\', '\'', '"', '<', '>':
   561  		return true
   562  	}
   563  	return r < ' ' || utf8.RuneSelf <= r
   564  }
   565  
   566  // JSEscaper returns the escaped JavaScript equivalent of the textual
   567  // representation of its arguments.
   568  func JSEscaper(args ...interface{}) string {
   569  	return JSEscapeString(evalArgs(args))
   570  }
   571  
   572  // URLQueryEscaper returns the escaped value of the textual representation of
   573  // its arguments in a form suitable for embedding in a URL query.
   574  func URLQueryEscaper(args ...interface{}) string {
   575  	return url.QueryEscape(evalArgs(args))
   576  }
   577  
   578  // evalArgs formats the list of arguments into a string. It is therefore equivalent to
   579  //	fmt.Sprint(args...)
   580  // except that each argument is indirected (if a pointer), as required,
   581  // using the same rules as the default string evaluation during template
   582  // execution.
   583  func evalArgs(args []interface{}) string {
   584  	ok := false
   585  	var s string
   586  	// Fast path for simple common case.
   587  	if len(args) == 1 {
   588  		s, ok = args[0].(string)
   589  	}
   590  	if !ok {
   591  		for i, arg := range args {
   592  			a, ok := printableValue(reflect.ValueOf(arg))
   593  			if ok {
   594  				args[i] = a
   595  			} // else let fmt do its thing
   596  		}
   597  		s = fmt.Sprint(args...)
   598  	}
   599  	return s
   600  }