github.com/minio/simdjson-go@v0.4.6-0.20231116094823-04d21cddf993/fuzz_test.go (about)

     1  //go:build go1.18
     2  // +build go1.18
     3  
     4  /*
     5   * MinIO Cloud Storage, (C) 2022 MinIO, Inc.
     6   *
     7   * Licensed under the Apache License, Version 2.0 (the "License");
     8   * you may not use this file except in compliance with the License.
     9   * You may obtain a copy of the License at
    10   *
    11   *     http://www.apache.org/licenses/LICENSE-2.0
    12   *
    13   * Unless required by applicable law or agreed to in writing, software
    14   * distributed under the License is distributed on an "AS IS" BASIS,
    15   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    16   * See the License for the specific language governing permissions and
    17   * limitations under the License.
    18   */
    19  
    20  package simdjson
    21  
    22  import (
    23  	"archive/tar"
    24  	"bytes"
    25  	"encoding/json"
    26  	"fmt"
    27  	"go/ast"
    28  	"go/parser"
    29  	"go/token"
    30  	"io"
    31  	"os"
    32  	"strconv"
    33  	"strings"
    34  	"testing"
    35  	"unicode/utf8"
    36  
    37  	"github.com/klauspost/compress/zstd"
    38  )
    39  
    40  func FuzzParse(f *testing.F) {
    41  	if !SupportedCPU() {
    42  		f.SkipNow()
    43  	}
    44  	addBytesFromTarZst(f, "testdata/fuzz/corpus.tar.zst", testing.Short())
    45  	addBytesFromTarZst(f, "testdata/fuzz/go-corpus.tar.zst", testing.Short())
    46  	f.Fuzz(func(t *testing.T, data []byte) {
    47  		var dst map[string]interface{}
    48  		var dstA []interface{}
    49  		pj, err := Parse(data, nil)
    50  		jErr := json.Unmarshal(data, &dst)
    51  		if err != nil {
    52  			if jErr == nil && dst != nil {
    53  				t.Logf("got error %v, but json.Unmarshal could unmarshal", err)
    54  			}
    55  			// Don't continue
    56  			t.Skip()
    57  			return
    58  		}
    59  		if jErr != nil {
    60  			if strings.Contains(jErr.Error(), "cannot unmarshal array into") {
    61  				jErr2 := json.Unmarshal(data, &dstA)
    62  				if jErr2 != nil {
    63  					t.Logf("no error reported, but json.Unmarshal (Array) reported: %v", jErr2)
    64  				}
    65  			} else {
    66  				t.Logf("no error reported, but json.Unmarshal reported: %v", jErr)
    67  			}
    68  		}
    69  		// Check if we can convert back
    70  		i := pj.Iter()
    71  		if i.PeekNextTag() != TagEnd {
    72  			_, err = i.MarshalJSON()
    73  			if err != nil {
    74  				switch {
    75  				// This is ok.
    76  				case strings.Contains(err.Error(), "INF or NaN number found"):
    77  				default:
    78  					t.Error(err)
    79  				}
    80  			}
    81  		}
    82  		// Do simple ND test.
    83  		d2 := append(make([]byte, 0, len(data)*3+2), data...)
    84  		d2 = append(d2, '\n')
    85  		d2 = append(d2, data...)
    86  		d2 = append(d2, '\n')
    87  		d2 = append(d2, data...)
    88  		_, _ = ParseND(data, nil)
    89  		return
    90  	})
    91  }
    92  
    93  // FuzzCorrect will check for correctness and compare output to stdlib.
    94  func FuzzCorrect(f *testing.F) {
    95  	if !SupportedCPU() {
    96  		f.SkipNow()
    97  	}
    98  	const (
    99  		// fail if simdjson doesn't report error, but json.Unmarshal does
   100  		failOnMissingError = true
   101  		// Run input through json.Unmarshal/json.Marshal first
   102  		filterRaw = true
   103  	)
   104  	addBytesFromTarZst(f, "testdata/fuzz/corpus.tar.zst", testing.Short())
   105  	addBytesFromTarZst(f, "testdata/fuzz/go-corpus.tar.zst", testing.Short())
   106  	f.Fuzz(func(t *testing.T, data []byte) {
   107  		var want map[string]interface{}
   108  		var wantA []interface{}
   109  		if !utf8.Valid(data) {
   110  			t.SkipNow()
   111  		}
   112  		if filterRaw {
   113  			var tmp interface{}
   114  			err := json.Unmarshal(data, &tmp)
   115  			if err != nil {
   116  				t.SkipNow()
   117  			}
   118  			data, err = json.Marshal(tmp)
   119  			if err != nil {
   120  				t.Fatal(err)
   121  			}
   122  			if tmp == nil {
   123  				t.SkipNow()
   124  			}
   125  		}
   126  		pj, err := Parse(data, nil)
   127  		jErr := json.Unmarshal(data, &want)
   128  		if err != nil {
   129  			if jErr == nil {
   130  				b, _ := json.Marshal(want)
   131  				t.Fatalf("got error %v, but json.Unmarshal could unmarshal to %#v js: %s", err, want, string(b))
   132  			}
   133  			// Don't continue
   134  			t.SkipNow()
   135  		}
   136  		if jErr != nil {
   137  			want = nil
   138  			if strings.Contains(jErr.Error(), "cannot unmarshal array into") {
   139  				jErr2 := json.Unmarshal(data, &wantA)
   140  				if jErr2 != nil {
   141  					if failOnMissingError {
   142  						t.Fatalf("no error reported, but json.Unmarshal (Array) reported: %v", jErr2)
   143  					}
   144  				}
   145  			} else {
   146  				if failOnMissingError {
   147  					t.Fatalf("no error reported, but json.Unmarshal reported: %v", jErr)
   148  				}
   149  				return
   150  			}
   151  		}
   152  		// Check if we can convert back
   153  		var got map[string]interface{}
   154  		var gotA []interface{}
   155  
   156  		i := pj.Iter()
   157  		if i.PeekNextTag() == TagEnd {
   158  			if len(want)+len(wantA) > 0 {
   159  				msg := fmt.Sprintf("stdlib returned data %#v, but nothing from simdjson (tap:%d, str:%d, err:%v)", want, len(pj.Tape), len(pj.Strings.B), err)
   160  				panic(msg)
   161  			}
   162  			t.SkipNow()
   163  		}
   164  
   165  		data, err = i.MarshalJSON()
   166  		if err != nil {
   167  			switch {
   168  			// This is ok.
   169  			case strings.Contains(err.Error(), "INF or NaN number found"):
   170  			default:
   171  				panic(err)
   172  			}
   173  		}
   174  		var wantB []byte
   175  		var gotB []byte
   176  		if want != nil {
   177  			// We should be able to unmarshal into msi
   178  			i := pj.Iter()
   179  			i.AdvanceInto()
   180  			for i.Type() != TypeNone {
   181  				switch i.Type() {
   182  				case TypeRoot:
   183  					i.Advance()
   184  				case TypeObject:
   185  					obj, err := i.Object(nil)
   186  					if err != nil {
   187  						panic(err)
   188  					}
   189  					got, err = obj.Map(got)
   190  					if err != nil {
   191  						panic(err)
   192  					}
   193  					i.Advance()
   194  				default:
   195  					allOfit := pj.Iter()
   196  					msg, _ := allOfit.MarshalJSON()
   197  					t.Fatalf("Unexpected type: %v, all: %s", i.Type(), string(msg))
   198  				}
   199  			}
   200  			gotB, err = json.Marshal(got)
   201  			if err != nil {
   202  				panic(err)
   203  			}
   204  			wantB, err = json.Marshal(want)
   205  			if err != nil {
   206  				panic(err)
   207  			}
   208  		}
   209  		if wantA != nil {
   210  			// We should be able to unmarshal into msi
   211  			i := pj.Iter()
   212  			i.AdvanceInto()
   213  			for i.Type() != TypeNone {
   214  				switch i.Type() {
   215  				case TypeRoot:
   216  					i.Advance()
   217  				case TypeArray:
   218  					arr, err := i.Array(nil)
   219  					if err != nil {
   220  						panic(err)
   221  					}
   222  					gotA, err = arr.Interface()
   223  					if err != nil {
   224  						panic(err)
   225  					}
   226  					i.Advance()
   227  				default:
   228  					t.Fatalf("Unexpected type: %v", i.Type())
   229  				}
   230  			}
   231  			gotB, err = json.Marshal(gotA)
   232  			if err != nil {
   233  				panic(err)
   234  			}
   235  			wantB, err = json.Marshal(wantA)
   236  			if err != nil {
   237  				panic(err)
   238  			}
   239  		}
   240  		if !bytes.Equal(gotB, wantB) {
   241  			if len(want)+len(got) == 0 {
   242  				t.SkipNow()
   243  			}
   244  			if bytes.Equal(bytes.ReplaceAll(wantB, []byte("-0"), []byte("0")), bytes.ReplaceAll(gotB, []byte("-0"), []byte("0"))) {
   245  				// let -0 == 0
   246  				return
   247  			}
   248  			allOfit := pj.Iter()
   249  			simdOut, _ := allOfit.MarshalJSON()
   250  
   251  			t.Fatalf("Marshal data mismatch:\nstdlib: %v\nsimdjson:%v\n\nsimdjson:%s", string(wantB), string(gotB), string(simdOut))
   252  		}
   253  
   254  		return
   255  	})
   256  }
   257  
   258  // FuzzCorrect will check for correctness and compare output to stdlib.
   259  func FuzzSerialize(f *testing.F) {
   260  	if !SupportedCPU() {
   261  		f.SkipNow()
   262  	}
   263  	addBytesFromTarZst(f, "testdata/fuzz/corpus.tar.zst", testing.Short())
   264  	addBytesFromTarZst(f, "testdata/fuzz/go-corpus.tar.zst", testing.Short())
   265  	f.Fuzz(func(t *testing.T, data []byte) {
   266  		// Create a tape from the input and ensure that the output of JSON matches.
   267  		pj, err := Parse(data, nil)
   268  		if err != nil {
   269  			pj, err = ParseND(data, pj)
   270  			if err != nil {
   271  				// Don't continue
   272  				t.SkipNow()
   273  			}
   274  		}
   275  		i := pj.Iter()
   276  		want, err := i.MarshalJSON()
   277  		if err != nil {
   278  			panic(err)
   279  		}
   280  		// Check if we can convert back
   281  		s := NewSerializer()
   282  		got := make([]byte, 0, len(want))
   283  		var dst []byte
   284  		var target *ParsedJson
   285  		for _, comp := range []CompressMode{CompressNone, CompressFast, CompressDefault, CompressBest} {
   286  			level := fmt.Sprintf("level-%d:", comp)
   287  			s.CompressMode(comp)
   288  			dst = s.Serialize(dst[:0], *pj)
   289  			target, err = s.Deserialize(dst, target)
   290  			if err != nil {
   291  				t.Error(level + err.Error())
   292  			}
   293  			i := target.Iter()
   294  			got, err = i.MarshalJSONBuffer(got[:0])
   295  			if err != nil {
   296  				t.Error(level + err.Error())
   297  			}
   298  			if !bytes.Equal(want, got) {
   299  				err := fmt.Sprintf("%s JSON mismatch:\nwant: %s\ngot :%s", level, string(want), string(got))
   300  				err += fmt.Sprintf("\ntap0:%x", pj.Tape)
   301  				err += fmt.Sprintf("\ntap1:%x", target.Tape)
   302  				t.Error(err)
   303  			}
   304  		}
   305  		return
   306  	})
   307  }
   308  func addBytesFromTarZst(f *testing.F, filename string, short bool) {
   309  	file, err := os.Open(filename)
   310  	if err != nil {
   311  		f.Fatal(err)
   312  	}
   313  	defer file.Close()
   314  	zr, err := zstd.NewReader(file)
   315  	if err != nil {
   316  		f.Fatal(err)
   317  	}
   318  	defer zr.Close()
   319  	tr := tar.NewReader(zr)
   320  	i := 0
   321  	for h, err := tr.Next(); err == nil; h, err = tr.Next() {
   322  		i++
   323  		if short && i%100 != 0 {
   324  			continue
   325  		}
   326  		b := make([]byte, h.Size)
   327  		_, err := io.ReadFull(tr, b)
   328  		if err != nil {
   329  			f.Fatal(err)
   330  		}
   331  		raw := true
   332  		if bytes.HasPrefix(b, []byte("go test fuzz")) {
   333  			raw = false
   334  		}
   335  		if raw {
   336  			f.Add(b)
   337  			continue
   338  		}
   339  		vals, err := unmarshalCorpusFile(b)
   340  		if err != nil {
   341  			f.Fatal(err)
   342  		}
   343  		for _, v := range vals {
   344  			f.Add(v)
   345  		}
   346  	}
   347  }
   348  
   349  // unmarshalCorpusFile decodes corpus bytes into their respective values.
   350  func unmarshalCorpusFile(b []byte) ([][]byte, error) {
   351  	if len(b) == 0 {
   352  		return nil, fmt.Errorf("cannot unmarshal empty string")
   353  	}
   354  	lines := bytes.Split(b, []byte("\n"))
   355  	if len(lines) < 2 {
   356  		return nil, fmt.Errorf("must include version and at least one value")
   357  	}
   358  	var vals = make([][]byte, 0, len(lines)-1)
   359  	for _, line := range lines[1:] {
   360  		line = bytes.TrimSpace(line)
   361  		if len(line) == 0 {
   362  			continue
   363  		}
   364  		v, err := parseCorpusValue(line)
   365  		if err != nil {
   366  			return nil, fmt.Errorf("malformed line %q: %v", line, err)
   367  		}
   368  		vals = append(vals, v)
   369  	}
   370  	return vals, nil
   371  }
   372  
   373  // parseCorpusValue
   374  func parseCorpusValue(line []byte) ([]byte, error) {
   375  	fs := token.NewFileSet()
   376  	expr, err := parser.ParseExprFrom(fs, "(test)", line, 0)
   377  	if err != nil {
   378  		return nil, err
   379  	}
   380  	call, ok := expr.(*ast.CallExpr)
   381  	if !ok {
   382  		return nil, fmt.Errorf("expected call expression")
   383  	}
   384  	if len(call.Args) != 1 {
   385  		return nil, fmt.Errorf("expected call expression with 1 argument; got %d", len(call.Args))
   386  	}
   387  	arg := call.Args[0]
   388  
   389  	if arrayType, ok := call.Fun.(*ast.ArrayType); ok {
   390  		if arrayType.Len != nil {
   391  			return nil, fmt.Errorf("expected []byte or primitive type")
   392  		}
   393  		elt, ok := arrayType.Elt.(*ast.Ident)
   394  		if !ok || elt.Name != "byte" {
   395  			return nil, fmt.Errorf("expected []byte")
   396  		}
   397  		lit, ok := arg.(*ast.BasicLit)
   398  		if !ok || lit.Kind != token.STRING {
   399  			return nil, fmt.Errorf("string literal required for type []byte")
   400  		}
   401  		s, err := strconv.Unquote(lit.Value)
   402  		if err != nil {
   403  			return nil, err
   404  		}
   405  		return []byte(s), nil
   406  	}
   407  	return nil, fmt.Errorf("expected []byte")
   408  }