github.com/JimmyHuang454/JLS-go@v0.0.0-20230831150107-90d536585ba0/internal/fuzz/encoding.go (about)

     1  // Copyright 2021 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package fuzz
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"go/ast"
    11  	"go/parser"
    12  	"go/token"
    13  	"math"
    14  	"strconv"
    15  	"strings"
    16  	"unicode/utf8"
    17  )
    18  
    19  // encVersion1 will be the first line of a file with version 1 encoding.
    20  var encVersion1 = "go test fuzz v1"
    21  
    22  // marshalCorpusFile encodes an arbitrary number of arguments into the file format for the
    23  // corpus.
    24  func marshalCorpusFile(vals ...any) []byte {
    25  	if len(vals) == 0 {
    26  		panic("must have at least one value to marshal")
    27  	}
    28  	b := bytes.NewBuffer([]byte(encVersion1 + "\n"))
    29  	// TODO(katiehockman): keep uint8 and int32 encoding where applicable,
    30  	// instead of changing to byte and rune respectively.
    31  	for _, val := range vals {
    32  		switch t := val.(type) {
    33  		case int, int8, int16, int64, uint, uint16, uint32, uint64, bool:
    34  			fmt.Fprintf(b, "%T(%v)\n", t, t)
    35  		case float32:
    36  			if math.IsNaN(float64(t)) && math.Float32bits(t) != math.Float32bits(float32(math.NaN())) {
    37  				// We encode unusual NaNs as hex values, because that is how users are
    38  				// likely to encounter them in literature about floating-point encoding.
    39  				// This allows us to reproduce fuzz failures that depend on the specific
    40  				// NaN representation (for float32 there are about 2^24 possibilities!),
    41  				// not just the fact that the value is *a* NaN.
    42  				//
    43  				// Note that the specific value of float32(math.NaN()) can vary based on
    44  				// whether the architecture represents signaling NaNs using a low bit
    45  				// (as is common) or a high bit (as commonly implemented on MIPS
    46  				// hardware before around 2012). We believe that the increase in clarity
    47  				// from identifying "NaN" with math.NaN() is worth the slight ambiguity
    48  				// from a platform-dependent value.
    49  				fmt.Fprintf(b, "math.Float32frombits(0x%x)\n", math.Float32bits(t))
    50  			} else {
    51  				// We encode all other values — including the NaN value that is
    52  				// bitwise-identical to float32(math.Nan()) — using the default
    53  				// formatting, which is equivalent to strconv.FormatFloat with format
    54  				// 'g' and can be parsed by strconv.ParseFloat.
    55  				//
    56  				// For an ordinary floating-point number this format includes
    57  				// sufficiently many digits to reconstruct the exact value. For positive
    58  				// or negative infinity it is the string "+Inf" or "-Inf". For positive
    59  				// or negative zero it is "0" or "-0". For NaN, it is the string "NaN".
    60  				fmt.Fprintf(b, "%T(%v)\n", t, t)
    61  			}
    62  		case float64:
    63  			if math.IsNaN(t) && math.Float64bits(t) != math.Float64bits(math.NaN()) {
    64  				fmt.Fprintf(b, "math.Float64frombits(0x%x)\n", math.Float64bits(t))
    65  			} else {
    66  				fmt.Fprintf(b, "%T(%v)\n", t, t)
    67  			}
    68  		case string:
    69  			fmt.Fprintf(b, "string(%q)\n", t)
    70  		case rune: // int32
    71  			// Although rune and int32 are represented by the same type, only a subset
    72  			// of valid int32 values can be expressed as rune literals. Notably,
    73  			// negative numbers, surrogate halves, and values above unicode.MaxRune
    74  			// have no quoted representation.
    75  			//
    76  			// fmt with "%q" (and the corresponding functions in the strconv package)
    77  			// would quote out-of-range values to the Unicode replacement character
    78  			// instead of the original value (see https://go.dev/issue/51526), so
    79  			// they must be treated as int32 instead.
    80  			//
    81  			// We arbitrarily draw the line at UTF-8 validity, which biases toward the
    82  			// "rune" interpretation. (However, we accept either format as input.)
    83  			if utf8.ValidRune(t) {
    84  				fmt.Fprintf(b, "rune(%q)\n", t)
    85  			} else {
    86  				fmt.Fprintf(b, "int32(%v)\n", t)
    87  			}
    88  		case byte: // uint8
    89  			// For bytes, we arbitrarily prefer the character interpretation.
    90  			// (Every byte has a valid character encoding.)
    91  			fmt.Fprintf(b, "byte(%q)\n", t)
    92  		case []byte: // []uint8
    93  			fmt.Fprintf(b, "[]byte(%q)\n", t)
    94  		default:
    95  			panic(fmt.Sprintf("unsupported type: %T", t))
    96  		}
    97  	}
    98  	return b.Bytes()
    99  }
   100  
   101  // unmarshalCorpusFile decodes corpus bytes into their respective values.
   102  func unmarshalCorpusFile(b []byte) ([]any, error) {
   103  	if len(b) == 0 {
   104  		return nil, fmt.Errorf("cannot unmarshal empty string")
   105  	}
   106  	lines := bytes.Split(b, []byte("\n"))
   107  	if len(lines) < 2 {
   108  		return nil, fmt.Errorf("must include version and at least one value")
   109  	}
   110  	version := strings.TrimSuffix(string(lines[0]), "\r")
   111  	if version != encVersion1 {
   112  		return nil, fmt.Errorf("unknown encoding version: %s", version)
   113  	}
   114  	var vals []any
   115  	for _, line := range lines[1:] {
   116  		line = bytes.TrimSpace(line)
   117  		if len(line) == 0 {
   118  			continue
   119  		}
   120  		v, err := parseCorpusValue(line)
   121  		if err != nil {
   122  			return nil, fmt.Errorf("malformed line %q: %v", line, err)
   123  		}
   124  		vals = append(vals, v)
   125  	}
   126  	return vals, nil
   127  }
   128  
   129  func parseCorpusValue(line []byte) (any, error) {
   130  	fs := token.NewFileSet()
   131  	expr, err := parser.ParseExprFrom(fs, "(test)", line, 0)
   132  	if err != nil {
   133  		return nil, err
   134  	}
   135  	call, ok := expr.(*ast.CallExpr)
   136  	if !ok {
   137  		return nil, fmt.Errorf("expected call expression")
   138  	}
   139  	if len(call.Args) != 1 {
   140  		return nil, fmt.Errorf("expected call expression with 1 argument; got %d", len(call.Args))
   141  	}
   142  	arg := call.Args[0]
   143  
   144  	if arrayType, ok := call.Fun.(*ast.ArrayType); ok {
   145  		if arrayType.Len != nil {
   146  			return nil, fmt.Errorf("expected []byte or primitive type")
   147  		}
   148  		elt, ok := arrayType.Elt.(*ast.Ident)
   149  		if !ok || elt.Name != "byte" {
   150  			return nil, fmt.Errorf("expected []byte")
   151  		}
   152  		lit, ok := arg.(*ast.BasicLit)
   153  		if !ok || lit.Kind != token.STRING {
   154  			return nil, fmt.Errorf("string literal required for type []byte")
   155  		}
   156  		s, err := strconv.Unquote(lit.Value)
   157  		if err != nil {
   158  			return nil, err
   159  		}
   160  		return []byte(s), nil
   161  	}
   162  
   163  	var idType *ast.Ident
   164  	if selector, ok := call.Fun.(*ast.SelectorExpr); ok {
   165  		xIdent, ok := selector.X.(*ast.Ident)
   166  		if !ok || xIdent.Name != "math" {
   167  			return nil, fmt.Errorf("invalid selector type")
   168  		}
   169  		switch selector.Sel.Name {
   170  		case "Float64frombits":
   171  			idType = &ast.Ident{Name: "float64-bits"}
   172  		case "Float32frombits":
   173  			idType = &ast.Ident{Name: "float32-bits"}
   174  		default:
   175  			return nil, fmt.Errorf("invalid selector type")
   176  		}
   177  	} else {
   178  		idType, ok = call.Fun.(*ast.Ident)
   179  		if !ok {
   180  			return nil, fmt.Errorf("expected []byte or primitive type")
   181  		}
   182  		if idType.Name == "bool" {
   183  			id, ok := arg.(*ast.Ident)
   184  			if !ok {
   185  				return nil, fmt.Errorf("malformed bool")
   186  			}
   187  			if id.Name == "true" {
   188  				return true, nil
   189  			} else if id.Name == "false" {
   190  				return false, nil
   191  			} else {
   192  				return nil, fmt.Errorf("true or false required for type bool")
   193  			}
   194  		}
   195  	}
   196  
   197  	var (
   198  		val  string
   199  		kind token.Token
   200  	)
   201  	if op, ok := arg.(*ast.UnaryExpr); ok {
   202  		switch lit := op.X.(type) {
   203  		case *ast.BasicLit:
   204  			if op.Op != token.SUB {
   205  				return nil, fmt.Errorf("unsupported operation on int/float: %v", op.Op)
   206  			}
   207  			// Special case for negative numbers.
   208  			val = op.Op.String() + lit.Value // e.g. "-" + "124"
   209  			kind = lit.Kind
   210  		case *ast.Ident:
   211  			if lit.Name != "Inf" {
   212  				return nil, fmt.Errorf("expected operation on int or float type")
   213  			}
   214  			if op.Op == token.SUB {
   215  				val = "-Inf"
   216  			} else {
   217  				val = "+Inf"
   218  			}
   219  			kind = token.FLOAT
   220  		default:
   221  			return nil, fmt.Errorf("expected operation on int or float type")
   222  		}
   223  	} else {
   224  		switch lit := arg.(type) {
   225  		case *ast.BasicLit:
   226  			val, kind = lit.Value, lit.Kind
   227  		case *ast.Ident:
   228  			if lit.Name != "NaN" {
   229  				return nil, fmt.Errorf("literal value required for primitive type")
   230  			}
   231  			val, kind = "NaN", token.FLOAT
   232  		default:
   233  			return nil, fmt.Errorf("literal value required for primitive type")
   234  		}
   235  	}
   236  
   237  	switch typ := idType.Name; typ {
   238  	case "string":
   239  		if kind != token.STRING {
   240  			return nil, fmt.Errorf("string literal value required for type string")
   241  		}
   242  		return strconv.Unquote(val)
   243  	case "byte", "rune":
   244  		if kind == token.INT {
   245  			switch typ {
   246  			case "rune":
   247  				return parseInt(val, typ)
   248  			case "byte":
   249  				return parseUint(val, typ)
   250  			}
   251  		}
   252  		if kind != token.CHAR {
   253  			return nil, fmt.Errorf("character literal required for byte/rune types")
   254  		}
   255  		n := len(val)
   256  		if n < 2 {
   257  			return nil, fmt.Errorf("malformed character literal, missing single quotes")
   258  		}
   259  		code, _, _, err := strconv.UnquoteChar(val[1:n-1], '\'')
   260  		if err != nil {
   261  			return nil, err
   262  		}
   263  		if typ == "rune" {
   264  			return code, nil
   265  		}
   266  		if code >= 256 {
   267  			return nil, fmt.Errorf("can only encode single byte to a byte type")
   268  		}
   269  		return byte(code), nil
   270  	case "int", "int8", "int16", "int32", "int64":
   271  		if kind != token.INT {
   272  			return nil, fmt.Errorf("integer literal required for int types")
   273  		}
   274  		return parseInt(val, typ)
   275  	case "uint", "uint8", "uint16", "uint32", "uint64":
   276  		if kind != token.INT {
   277  			return nil, fmt.Errorf("integer literal required for uint types")
   278  		}
   279  		return parseUint(val, typ)
   280  	case "float32":
   281  		if kind != token.FLOAT && kind != token.INT {
   282  			return nil, fmt.Errorf("float or integer literal required for float32 type")
   283  		}
   284  		v, err := strconv.ParseFloat(val, 32)
   285  		return float32(v), err
   286  	case "float64":
   287  		if kind != token.FLOAT && kind != token.INT {
   288  			return nil, fmt.Errorf("float or integer literal required for float64 type")
   289  		}
   290  		return strconv.ParseFloat(val, 64)
   291  	case "float32-bits":
   292  		if kind != token.INT {
   293  			return nil, fmt.Errorf("integer literal required for math.Float32frombits type")
   294  		}
   295  		bits, err := parseUint(val, "uint32")
   296  		if err != nil {
   297  			return nil, err
   298  		}
   299  		return math.Float32frombits(bits.(uint32)), nil
   300  	case "float64-bits":
   301  		if kind != token.FLOAT && kind != token.INT {
   302  			return nil, fmt.Errorf("integer literal required for math.Float64frombits type")
   303  		}
   304  		bits, err := parseUint(val, "uint64")
   305  		if err != nil {
   306  			return nil, err
   307  		}
   308  		return math.Float64frombits(bits.(uint64)), nil
   309  	default:
   310  		return nil, fmt.Errorf("expected []byte or primitive type")
   311  	}
   312  }
   313  
   314  // parseInt returns an integer of value val and type typ.
   315  func parseInt(val, typ string) (any, error) {
   316  	switch typ {
   317  	case "int":
   318  		// The int type may be either 32 or 64 bits. If 32, the fuzz tests in the
   319  		// corpus may include 64-bit values produced by fuzzing runs on 64-bit
   320  		// architectures. When running those tests, we implicitly wrap the values to
   321  		// fit in a regular int. (The test case is still “interesting”, even if the
   322  		// specific values of its inputs are platform-dependent.)
   323  		i, err := strconv.ParseInt(val, 0, 64)
   324  		return int(i), err
   325  	case "int8":
   326  		i, err := strconv.ParseInt(val, 0, 8)
   327  		return int8(i), err
   328  	case "int16":
   329  		i, err := strconv.ParseInt(val, 0, 16)
   330  		return int16(i), err
   331  	case "int32", "rune":
   332  		i, err := strconv.ParseInt(val, 0, 32)
   333  		return int32(i), err
   334  	case "int64":
   335  		return strconv.ParseInt(val, 0, 64)
   336  	default:
   337  		panic("unreachable")
   338  	}
   339  }
   340  
   341  // parseUint returns an unsigned integer of value val and type typ.
   342  func parseUint(val, typ string) (any, error) {
   343  	switch typ {
   344  	case "uint":
   345  		i, err := strconv.ParseUint(val, 0, 64)
   346  		return uint(i), err
   347  	case "uint8", "byte":
   348  		i, err := strconv.ParseUint(val, 0, 8)
   349  		return uint8(i), err
   350  	case "uint16":
   351  		i, err := strconv.ParseUint(val, 0, 16)
   352  		return uint16(i), err
   353  	case "uint32":
   354  		i, err := strconv.ParseUint(val, 0, 32)
   355  		return uint32(i), err
   356  	case "uint64":
   357  		return strconv.ParseUint(val, 0, 64)
   358  	default:
   359  		panic("unreachable")
   360  	}
   361  }