github.com/AESNooper/go/src@v0.0.0-20220218095104-b56a4ab1bbbb/internal/fuzz/encoding.go (about)

     1  // Copyright 2021 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package fuzz
     6  
import (
	"bytes"
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"strconv"
	"unicode/utf8"
)
    15  
    16  // encVersion1 will be the first line of a file with version 1 encoding.
    17  var encVersion1 = "go test fuzz v1"
    18  
    19  // marshalCorpusFile encodes an arbitrary number of arguments into the file format for the
    20  // corpus.
    21  func marshalCorpusFile(vals ...interface{}) []byte {
    22  	if len(vals) == 0 {
    23  		panic("must have at least one value to marshal")
    24  	}
    25  	b := bytes.NewBuffer([]byte(encVersion1 + "\n"))
    26  	// TODO(katiehockman): keep uint8 and int32 encoding where applicable,
    27  	// instead of changing to byte and rune respectively.
    28  	for _, val := range vals {
    29  		switch t := val.(type) {
    30  		case int, int8, int16, int64, uint, uint16, uint32, uint64, float32, float64, bool:
    31  			fmt.Fprintf(b, "%T(%v)\n", t, t)
    32  		case string:
    33  			fmt.Fprintf(b, "string(%q)\n", t)
    34  		case rune: // int32
    35  			fmt.Fprintf(b, "rune(%q)\n", t)
    36  		case byte: // uint8
    37  			fmt.Fprintf(b, "byte(%q)\n", t)
    38  		case []byte: // []uint8
    39  			fmt.Fprintf(b, "[]byte(%q)\n", t)
    40  		default:
    41  			panic(fmt.Sprintf("unsupported type: %T", t))
    42  		}
    43  	}
    44  	return b.Bytes()
    45  }
    46  
    47  // unmarshalCorpusFile decodes corpus bytes into their respective values.
    48  func unmarshalCorpusFile(b []byte) ([]interface{}, error) {
    49  	if len(b) == 0 {
    50  		return nil, fmt.Errorf("cannot unmarshal empty string")
    51  	}
    52  	lines := bytes.Split(b, []byte("\n"))
    53  	if len(lines) < 2 {
    54  		return nil, fmt.Errorf("must include version and at least one value")
    55  	}
    56  	if string(lines[0]) != encVersion1 {
    57  		return nil, fmt.Errorf("unknown encoding version: %s", lines[0])
    58  	}
    59  	var vals []interface{}
    60  	for _, line := range lines[1:] {
    61  		line = bytes.TrimSpace(line)
    62  		if len(line) == 0 {
    63  			continue
    64  		}
    65  		v, err := parseCorpusValue(line)
    66  		if err != nil {
    67  			return nil, fmt.Errorf("malformed line %q: %v", line, err)
    68  		}
    69  		vals = append(vals, v)
    70  	}
    71  	return vals, nil
    72  }
    73  
    74  func parseCorpusValue(line []byte) (interface{}, error) {
    75  	fs := token.NewFileSet()
    76  	expr, err := parser.ParseExprFrom(fs, "(test)", line, 0)
    77  	if err != nil {
    78  		return nil, err
    79  	}
    80  	call, ok := expr.(*ast.CallExpr)
    81  	if !ok {
    82  		return nil, fmt.Errorf("expected call expression")
    83  	}
    84  	if len(call.Args) != 1 {
    85  		return nil, fmt.Errorf("expected call expression with 1 argument; got %d", len(call.Args))
    86  	}
    87  	arg := call.Args[0]
    88  
    89  	if arrayType, ok := call.Fun.(*ast.ArrayType); ok {
    90  		if arrayType.Len != nil {
    91  			return nil, fmt.Errorf("expected []byte or primitive type")
    92  		}
    93  		elt, ok := arrayType.Elt.(*ast.Ident)
    94  		if !ok || elt.Name != "byte" {
    95  			return nil, fmt.Errorf("expected []byte")
    96  		}
    97  		lit, ok := arg.(*ast.BasicLit)
    98  		if !ok || lit.Kind != token.STRING {
    99  			return nil, fmt.Errorf("string literal required for type []byte")
   100  		}
   101  		s, err := strconv.Unquote(lit.Value)
   102  		if err != nil {
   103  			return nil, err
   104  		}
   105  		return []byte(s), nil
   106  	}
   107  
   108  	idType, ok := call.Fun.(*ast.Ident)
   109  	if !ok {
   110  		return nil, fmt.Errorf("expected []byte or primitive type")
   111  	}
   112  	if idType.Name == "bool" {
   113  		id, ok := arg.(*ast.Ident)
   114  		if !ok {
   115  			return nil, fmt.Errorf("malformed bool")
   116  		}
   117  		if id.Name == "true" {
   118  			return true, nil
   119  		} else if id.Name == "false" {
   120  			return false, nil
   121  		} else {
   122  			return nil, fmt.Errorf("true or false required for type bool")
   123  		}
   124  	}
   125  	var (
   126  		val  string
   127  		kind token.Token
   128  	)
   129  	if op, ok := arg.(*ast.UnaryExpr); ok {
   130  		// Special case for negative numbers.
   131  		lit, ok := op.X.(*ast.BasicLit)
   132  		if !ok || (lit.Kind != token.INT && lit.Kind != token.FLOAT) {
   133  			return nil, fmt.Errorf("expected operation on int or float type")
   134  		}
   135  		if op.Op != token.SUB {
   136  			return nil, fmt.Errorf("unsupported operation on int: %v", op.Op)
   137  		}
   138  		val = op.Op.String() + lit.Value // e.g. "-" + "124"
   139  		kind = lit.Kind
   140  	} else {
   141  		lit, ok := arg.(*ast.BasicLit)
   142  		if !ok {
   143  			return nil, fmt.Errorf("literal value required for primitive type")
   144  		}
   145  		val, kind = lit.Value, lit.Kind
   146  	}
   147  
   148  	switch typ := idType.Name; typ {
   149  	case "string":
   150  		if kind != token.STRING {
   151  			return nil, fmt.Errorf("string literal value required for type string")
   152  		}
   153  		return strconv.Unquote(val)
   154  	case "byte", "rune":
   155  		if kind != token.CHAR {
   156  			return nil, fmt.Errorf("character literal required for byte/rune types")
   157  		}
   158  		n := len(val)
   159  		if n < 2 {
   160  			return nil, fmt.Errorf("malformed character literal, missing single quotes")
   161  		}
   162  		code, _, _, err := strconv.UnquoteChar(val[1:n-1], '\'')
   163  		if err != nil {
   164  			return nil, err
   165  		}
   166  		if typ == "rune" {
   167  			return code, nil
   168  		}
   169  		if code >= 256 {
   170  			return nil, fmt.Errorf("can only encode single byte to a byte type")
   171  		}
   172  		return byte(code), nil
   173  	case "int", "int8", "int16", "int32", "int64":
   174  		if kind != token.INT {
   175  			return nil, fmt.Errorf("integer literal required for int types")
   176  		}
   177  		return parseInt(val, typ)
   178  	case "uint", "uint8", "uint16", "uint32", "uint64":
   179  		if kind != token.INT {
   180  			return nil, fmt.Errorf("integer literal required for uint types")
   181  		}
   182  		return parseUint(val, typ)
   183  	case "float32":
   184  		if kind != token.FLOAT && kind != token.INT {
   185  			return nil, fmt.Errorf("float or integer literal required for float32 type")
   186  		}
   187  		v, err := strconv.ParseFloat(val, 32)
   188  		return float32(v), err
   189  	case "float64":
   190  		if kind != token.FLOAT && kind != token.INT {
   191  			return nil, fmt.Errorf("float or integer literal required for float64 type")
   192  		}
   193  		return strconv.ParseFloat(val, 64)
   194  	default:
   195  		return nil, fmt.Errorf("expected []byte or primitive type")
   196  	}
   197  }
   198  
   199  // parseInt returns an integer of value val and type typ.
   200  func parseInt(val, typ string) (interface{}, error) {
   201  	switch typ {
   202  	case "int":
   203  		return strconv.Atoi(val)
   204  	case "int8":
   205  		i, err := strconv.ParseInt(val, 10, 8)
   206  		return int8(i), err
   207  	case "int16":
   208  		i, err := strconv.ParseInt(val, 10, 16)
   209  		return int16(i), err
   210  	case "int32":
   211  		i, err := strconv.ParseInt(val, 10, 32)
   212  		return int32(i), err
   213  	case "int64":
   214  		return strconv.ParseInt(val, 10, 64)
   215  	default:
   216  		panic("unreachable")
   217  	}
   218  }
   219  
   220  // parseInt returns an unsigned integer of value val and type typ.
   221  func parseUint(val, typ string) (interface{}, error) {
   222  	switch typ {
   223  	case "uint":
   224  		i, err := strconv.ParseUint(val, 10, 0)
   225  		return uint(i), err
   226  	case "uint8":
   227  		i, err := strconv.ParseUint(val, 10, 8)
   228  		return uint8(i), err
   229  	case "uint16":
   230  		i, err := strconv.ParseUint(val, 10, 16)
   231  		return uint16(i), err
   232  	case "uint32":
   233  		i, err := strconv.ParseUint(val, 10, 32)
   234  		return uint32(i), err
   235  	case "uint64":
   236  		return strconv.ParseUint(val, 10, 64)
   237  	default:
   238  		panic("unreachable")
   239  	}
   240  }