github.com/hamba/avro/v2@v2.22.1-0.20240518180522-aff3955acf7d/gen/gen_test.go (about)

     1  package gen_test
     2  
     3  import (
     4  	"bytes"
     5  	"flag"
     6  	"go/format"
     7  	"os"
     8  	"regexp"
     9  	"strings"
    10  	"testing"
    11  
    12  	"github.com/hamba/avro/v2"
    13  	"github.com/hamba/avro/v2/gen"
    14  	"github.com/stretchr/testify/assert"
    15  	"github.com/stretchr/testify/require"
    16  )
    17  
    18  var update = flag.Bool("update", false, "Update golden files")
    19  
    20  func TestStruct_InvalidSchemaYieldsErr(t *testing.T) {
    21  	err := gen.Struct(`asd`, &bytes.Buffer{}, gen.Config{})
    22  
    23  	assert.Error(t, err)
    24  }
    25  
    26  func TestStruct_NonRecordSchemasAreNotSupported(t *testing.T) {
    27  	err := gen.Struct(`{"type": "string"}`, &bytes.Buffer{}, gen.Config{})
    28  
    29  	require.Error(t, err)
    30  	assert.Contains(t, strings.ToLower(err.Error()), "only")
    31  	assert.Contains(t, strings.ToLower(err.Error()), "record schema")
    32  }
    33  
    34  func TestStruct_AvroStyleCannotBeOverridden(t *testing.T) {
    35  	schema := `{
    36    "type": "record",
    37    "name": "test",
    38    "fields": [
    39      { "name": "someString", "type": "string" }
    40    ]
    41  }`
    42  	gc := gen.Config{
    43  		PackageName: "Something",
    44  		Tags: map[string]gen.TagStyle{
    45  			"avro": gen.Kebab,
    46  		},
    47  	}
    48  
    49  	_, lines := generate(t, schema, gc)
    50  
    51  	for _, expected := range []string{
    52  		"package something",
    53  		"type Test struct {",
    54  		"SomeString string `avro:\"someString\"`",
    55  		"}",
    56  	} {
    57  		assert.Contains(t, lines, expected, "avro tags should not be configurable, they need to match the schema")
    58  	}
    59  }
    60  
    61  func TestStruct_HandlesGoInitialisms(t *testing.T) {
    62  	schema := `{
    63    "type": "record",
    64    "name": "httpRecord",
    65    "fields": [
    66      { "name": "someString", "type": "string" }
    67    ]
    68  }`
    69  	gc := gen.Config{
    70  		PackageName: "Something",
    71  	}
    72  
    73  	_, lines := generate(t, schema, gc)
    74  
    75  	assert.Contains(t, lines, "type HTTPRecord struct {")
    76  }
    77  
    78  func TestStruct_HandlesAdditionalInitialisms(t *testing.T) {
    79  	schema := `{
    80    "type": "record",
    81    "name": "CidOverHttpRecord",
    82    "fields": [
    83      { "name": "someString", "type": "string" }
    84    ]
    85  }`
    86  	gc := gen.Config{
    87  		PackageName: "Something",
    88  		Initialisms: []string{"CID"},
    89  	}
    90  
    91  	_, lines := generate(t, schema, gc)
    92  
    93  	assert.Contains(t, lines, "type CIDOverHTTPRecord struct {")
    94  }
    95  
    96  func TestStruct_HandlesStrictTypes(t *testing.T) {
    97  	schema := `{
    98    "type": "record",
    99    "name": "test",
   100    "fields": [
   101      { "name": "someString", "type": "int" }
   102    ]
   103  }`
   104  	gc := gen.Config{
   105  		PackageName: "Something",
   106  		StrictTypes: true,
   107  	}
   108  
   109  	_, lines := generate(t, schema, gc)
   110  
   111  	assert.Contains(t, lines, "SomeString int32 `avro:\"someString\"`")
   112  }
   113  
   114  func TestStruct_ConfigurableFieldTags(t *testing.T) {
   115  	schema := `{
   116    "type": "record",
   117    "name": "test",
   118    "fields": [
   119      { "name": "someSTRING", "type": "string" }
   120    ]
   121  }`
   122  
   123  	tests := []struct {
   124  		tagStyle    gen.TagStyle
   125  		expectedTag string
   126  	}{
   127  		{tagStyle: gen.Camel, expectedTag: "json:\"someString\""},
   128  		{tagStyle: gen.Snake, expectedTag: "json:\"some_string\""},
   129  		{tagStyle: gen.Kebab, expectedTag: "json:\"some-string\""},
   130  		{tagStyle: gen.UpperCamel, expectedTag: "json:\"SomeString\""},
   131  		{tagStyle: gen.Original, expectedTag: "json:\"someSTRING\""},
   132  		{tagStyle: gen.TagStyle(""), expectedTag: "json:\"someSTRING\""},
   133  	}
   134  
   135  	for _, test := range tests {
   136  		test := test
   137  		t.Run(string(test.tagStyle), func(t *testing.T) {
   138  			gc := gen.Config{
   139  				PackageName: "Something",
   140  				Tags: map[string]gen.TagStyle{
   141  					"json": test.tagStyle,
   142  				},
   143  			}
   144  			_, lines := generate(t, schema, gc)
   145  
   146  			for _, expected := range []string{
   147  				"package something",
   148  				"type Test struct {",
   149  				"SomeString string `avro:\"someSTRING\" " + test.expectedTag + "`",
   150  				"}",
   151  			} {
   152  				assert.Contains(t, lines, expected)
   153  			}
   154  		})
   155  	}
   156  }
   157  
   158  func TestStruct_GenFromRecordSchema(t *testing.T) {
   159  	schema, err := os.ReadFile("testdata/golden.avsc")
   160  	require.NoError(t, err)
   161  
   162  	gc := gen.Config{PackageName: "Something"}
   163  	file, _ := generate(t, string(schema), gc)
   164  
   165  	if *update {
   166  		err = os.WriteFile("testdata/golden.go", file, 0600)
   167  		require.NoError(t, err)
   168  	}
   169  
   170  	want, err := os.ReadFile("testdata/golden.go")
   171  	require.NoError(t, err)
   172  	assert.Equal(t, string(want), string(file))
   173  }
   174  
   175  func TestStruct_GenFromRecordSchemaWithFullName(t *testing.T) {
   176  	schema, err := os.ReadFile("testdata/golden.avsc")
   177  	require.NoError(t, err)
   178  
   179  	gc := gen.Config{PackageName: "Something", FullName: true}
   180  	file, _ := generate(t, string(schema), gc)
   181  
   182  	if *update {
   183  		err = os.WriteFile("testdata/golden_fullname.go", file, 0600)
   184  		require.NoError(t, err)
   185  	}
   186  
   187  	want, err := os.ReadFile("testdata/golden_fullname.go")
   188  	require.NoError(t, err)
   189  	assert.Equal(t, string(want), string(file))
   190  }
   191  
   192  func TestStruct_GenFromRecordSchemaWithEncoders(t *testing.T) {
   193  	schema, err := os.ReadFile("testdata/golden.avsc")
   194  	require.NoError(t, err)
   195  
   196  	gc := gen.Config{PackageName: "Something", Encoders: true}
   197  	file, _ := generate(t, string(schema), gc)
   198  
   199  	if *update {
   200  		err = os.WriteFile("testdata/golden_encoders.go", file, 0600)
   201  		require.NoError(t, err)
   202  	}
   203  
   204  	want, err := os.ReadFile("testdata/golden_encoders.go")
   205  	require.NoError(t, err)
   206  	assert.Equal(t, string(want), string(file))
   207  }
   208  
   209  func TestGenerator(t *testing.T) {
   210  	unionSchema, err := avro.ParseFiles("testdata/uniontype.avsc")
   211  	require.NoError(t, err)
   212  
   213  	mainSchema, err := avro.ParseFiles("testdata/main.avsc")
   214  	require.NoError(t, err)
   215  
   216  	g := gen.NewGenerator("something", map[string]gen.TagStyle{})
   217  	g.Parse(unionSchema)
   218  	g.Parse(mainSchema)
   219  
   220  	var buf bytes.Buffer
   221  	err = g.Write(&buf)
   222  	require.NoError(t, err)
   223  
   224  	formatted, err := format.Source(buf.Bytes())
   225  	require.NoError(t, err)
   226  
   227  	if *update {
   228  		err = os.WriteFile("testdata/golden_multiple.go", formatted, 0600)
   229  		require.NoError(t, err)
   230  	}
   231  
   232  	want, err := os.ReadFile("testdata/golden_multiple.go")
   233  	require.NoError(t, err)
   234  	assert.Equal(t, string(want), string(formatted))
   235  }
   236  
   237  // generate is a utility to run the generation and return the result as a tuple
   238  func generate(t *testing.T, schema string, gc gen.Config) ([]byte, []string) {
   239  	t.Helper()
   240  
   241  	buf := &bytes.Buffer{}
   242  	err := gen.Struct(schema, buf, gc)
   243  	require.NoError(t, err)
   244  
   245  	b := make([]byte, buf.Len())
   246  	copy(b, buf.Bytes())
   247  
   248  	return buf.Bytes(), removeSpaceAndEmptyLines(b)
   249  }
   250  
   251  func removeSpaceAndEmptyLines(goCode []byte) []string {
   252  	var lines []string
   253  	for _, lineBytes := range bytes.Split(goCode, []byte("\n")) {
   254  		if len(lineBytes) == 0 {
   255  			continue
   256  		}
   257  		trimmed := removeMoreThanOneConsecutiveSpaces(lineBytes)
   258  		lines = append(lines, trimmed)
   259  	}
   260  	return lines
   261  }
   262  
   263  // removeMoreThanOneConsecutiveSpaces replaces all sequences of more than one space, with a single one
   264  func removeMoreThanOneConsecutiveSpaces(lineBytes []byte) string {
   265  	lines := strings.TrimSpace(string(lineBytes))
   266  	return strings.Join(regexp.MustCompile("\\s+|\\t+").Split(lines, -1), " ")
   267  }