github.com/m3db/m3@v1.5.1-0.20231129193456-75a402aa583b/src/dbnode/encoding/proto/round_trip_test.go (about)

     1  // Copyright (c) 2019 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package proto
    22  
    23  import (
    24  	"errors"
    25  	"fmt"
    26  	"testing"
    27  	"time"
    28  
    29  	"github.com/m3db/m3/src/dbnode/encoding"
    30  	"github.com/m3db/m3/src/dbnode/namespace"
    31  	"github.com/m3db/m3/src/dbnode/ts"
    32  	"github.com/m3db/m3/src/dbnode/x/xio"
    33  	"github.com/m3db/m3/src/x/pool"
    34  	xtime "github.com/m3db/m3/src/x/time"
    35  
    36  	"github.com/jhump/protoreflect/desc"
    37  	"github.com/jhump/protoreflect/desc/protoparse"
    38  	"github.com/jhump/protoreflect/dynamic"
    39  	"github.com/stretchr/testify/require"
    40  )
    41  
    42  var (
    43  	testVLSchema  = newVLMessageDescriptor()
    44  	testVL2Schema = newVL2MessageDescriptor()
    45  	bytesPool     = pool.NewCheckedBytesPool(nil, nil, func(s []pool.Bucket) pool.BytesPool {
    46  		return pool.NewBytesPool(s, nil)
    47  	})
    48  	testEncodingOptions = encoding.NewOptions().
    49  				SetDefaultTimeUnit(xtime.Second).
    50  				SetBytesPool(bytesPool)
    51  )
    52  
// init initializes the package-level bytesPool when the test package is
// loaded, before any test function runs.
func init() {
	bytesPool.Init()
}
    56  
    57  // TestRoundTrip is intentionally simple to facilitate fast and easy debugging of changes
    58  // as well as to serve as a basic sanity test. However, the bulk of the confidence in this
    59  // code's correctness comes from the `TestRoundtripProp` test which is much more exhaustive.
    60  func TestRoundTrip(t *testing.T) {
    61  	testCases := []struct {
    62  		timestamp  xtime.UnixNano
    63  		unit       xtime.Unit
    64  		latitude   float64
    65  		longitude  float64
    66  		epoch      int64
    67  		deliveryID []byte
    68  		attributes map[string]string
    69  	}{
    70  		{
    71  			unit:      xtime.Second,
    72  			latitude:  0.1,
    73  			longitude: 1.1,
    74  			epoch:     -1,
    75  		},
    76  		{
    77  			unit:       xtime.Nanosecond,
    78  			latitude:   0.1,
    79  			longitude:  1.1,
    80  			epoch:      0,
    81  			deliveryID: []byte("123123123123"),
    82  			attributes: map[string]string{"key1": "val1"},
    83  		},
    84  		{
    85  			unit:       xtime.Nanosecond,
    86  			latitude:   0.2,
    87  			longitude:  2.2,
    88  			epoch:      1,
    89  			deliveryID: []byte("789789789789"),
    90  			attributes: map[string]string{"key1": "val1"},
    91  		},
    92  		{
    93  			unit:       xtime.Millisecond,
    94  			latitude:   0.3,
    95  			longitude:  2.3,
    96  			epoch:      2,
    97  			deliveryID: []byte("123123123123"),
    98  		},
    99  		{
   100  			unit:       xtime.Second,
   101  			latitude:   0.4,
   102  			longitude:  2.4,
   103  			epoch:      3,
   104  			attributes: map[string]string{"key1": "val1"},
   105  		},
   106  		{
   107  			unit:       xtime.Second,
   108  			latitude:   0.5,
   109  			longitude:  2.5,
   110  			epoch:      4,
   111  			deliveryID: []byte("456456456456"),
   112  			attributes: map[string]string{
   113  				"key1": "val1",
   114  				"key2": "val2",
   115  			},
   116  		},
   117  		{
   118  			unit:       xtime.Millisecond,
   119  			latitude:   0.6,
   120  			longitude:  2.6,
   121  			deliveryID: nil,
   122  		},
   123  		{
   124  			unit:       xtime.Nanosecond,
   125  			latitude:   0.5,
   126  			longitude:  2.5,
   127  			deliveryID: []byte("789789789789"),
   128  		},
   129  	}
   130  
   131  	curr := xtime.Now().Truncate(2 * time.Minute)
   132  	enc := newTestEncoder(curr)
   133  	enc.SetSchema(namespace.GetTestSchemaDescr(testVLSchema))
   134  
   135  	for i, tc := range testCases {
   136  		vl := newVL(
   137  			tc.latitude, tc.longitude, tc.epoch, tc.deliveryID, tc.attributes)
   138  		marshalledVL, err := vl.Marshal()
   139  		require.NoError(t, err)
   140  
   141  		duration, err := xtime.DurationFromUnit(tc.unit)
   142  		require.NoError(t, err)
   143  		currTime := curr.Add(time.Duration(i) * duration)
   144  		testCases[i].timestamp = currTime
   145  		// Encoder should ignore value so we set it to make sure it gets ignored.
   146  		err = enc.Encode(ts.Datapoint{TimestampNanos: currTime, Value: float64(i)},
   147  			tc.unit, marshalledVL)
   148  		require.NoError(t, err)
   149  
   150  		lastEncoded, err := enc.LastEncoded()
   151  		require.NoError(t, err)
   152  		require.Equal(t, currTime, lastEncoded.TimestampNanos)
   153  		require.Equal(t, currTime, lastEncoded.TimestampNanos)
   154  		require.Equal(t, float64(0), lastEncoded.Value)
   155  	}
   156  
   157  	// Add some sanity to make sure that the compression (especially string compression)
   158  	// is working properly.
   159  	numExpectedBytes := 281
   160  	require.Equal(t, numExpectedBytes, enc.Stats().CompressedBytes)
   161  
   162  	rawBytes, err := enc.Bytes()
   163  	require.NoError(t, err)
   164  	require.Equal(t, numExpectedBytes, len(rawBytes))
   165  
   166  	r := xio.NewBytesReader64(rawBytes)
   167  	iter := NewIterator(r, namespace.GetTestSchemaDescr(testVLSchema), testEncodingOptions)
   168  
   169  	i := 0
   170  	for iter.Next() {
   171  		var (
   172  			tc                   = testCases[i]
   173  			dp, unit, annotation = iter.Current()
   174  		)
   175  		m := dynamic.NewMessage(testVLSchema)
   176  		require.NoError(t, m.Unmarshal(annotation))
   177  
   178  		require.Equal(t, unit, testCases[i].unit)
   179  		require.Equal(t,
   180  			tc.timestamp, dp.TimestampNanos,
   181  			fmt.Sprintf("expected: %s, got: %s", tc.timestamp.String(), dp.TimestampNanos))
   182  		// Value is meaningless for proto so should always be zero
   183  		// regardless of whats written.
   184  		require.Equal(t, float64(0), dp.Value)
   185  		require.Equal(t, tc.unit, unit)
   186  		require.Equal(t, tc.latitude, m.GetFieldByName("latitude"))
   187  		require.Equal(t, tc.longitude, m.GetFieldByName("longitude"))
   188  		require.Equal(t, tc.epoch, m.GetFieldByName("epoch"))
   189  		require.Equal(t, tc.deliveryID, m.GetFieldByName("deliveryID"))
   190  		assertAttributesEqual(t, tc.attributes, m.GetFieldByName("attributes").(map[interface{}]interface{}))
   191  		i++
   192  	}
   193  	require.NoError(t, iter.Err())
   194  	require.Equal(t, len(testCases), i)
   195  }
   196  
   197  func TestRoundTripMidStreamSchemaChanges(t *testing.T) {
   198  	enc := newTestEncoder(xtime.Now().Truncate(time.Second))
   199  	enc.SetSchema(namespace.GetTestSchemaDescr(testVLSchema))
   200  
   201  	attrs := map[string]string{"key1": "val1"}
   202  	vl1Write := newVL(26.0, 27.0, 10, []byte("some_delivery_id"), attrs)
   203  	marshalledVL, err := vl1Write.Marshal()
   204  	require.NoError(t, err)
   205  
   206  	vl1WriteTime := xtime.Now().Truncate(time.Second)
   207  	err = enc.Encode(ts.Datapoint{TimestampNanos: vl1WriteTime},
   208  		xtime.Second, marshalledVL)
   209  	require.NoError(t, err)
   210  
   211  	vl2Write := newVL2(28.0, 29.0, attrs, "some_new_custom_field", map[int]int{1: 2})
   212  	marshalledVL, err = vl2Write.Marshal()
   213  	require.NoError(t, err)
   214  
   215  	vl2WriteTime := vl1WriteTime.Add(time.Second)
   216  	err = enc.Encode(ts.Datapoint{TimestampNanos: vl2WriteTime},
   217  		xtime.Second, marshalledVL)
   218  	require.EqualError(t,
   219  		err,
   220  		"proto encoder: error unmarshalling message: encountered unknown field with field number: 6")
   221  
   222  	enc.SetSchema(namespace.GetTestSchemaDescr(testVL2Schema))
   223  	err = enc.Encode(ts.Datapoint{TimestampNanos: vl2WriteTime},
   224  		xtime.Second, marshalledVL)
   225  	require.NoError(t, err)
   226  
   227  	rawBytes, err := enc.Bytes()
   228  	require.NoError(t, err)
   229  
   230  	// Try reading the stream just using the vl1 schema.
   231  	r := xio.NewBytesReader64(rawBytes)
   232  	iter := NewIterator(r, namespace.GetTestSchemaDescr(testVLSchema), testEncodingOptions)
   233  
   234  	require.True(t, iter.Next(), "iter err: %v", iter.Err())
   235  	dp, unit, annotation := iter.Current()
   236  	m := dynamic.NewMessage(testVLSchema)
   237  	require.NoError(t, m.Unmarshal(annotation))
   238  	require.Equal(t, xtime.Second, unit)
   239  	require.Equal(t, vl1WriteTime, dp.TimestampNanos)
   240  	require.Equal(t, 5, len(m.GetKnownFields()))
   241  	require.Equal(t, vl1Write.GetFieldByName("latitude"), m.GetFieldByName("latitude"))
   242  	require.Equal(t, vl1Write.GetFieldByName("longitude"), m.GetFieldByName("longitude"))
   243  	require.Equal(t, vl1Write.GetFieldByName("epoch"), m.GetFieldByName("epoch"))
   244  	require.Equal(t, vl1Write.GetFieldByName("deliveryID"), m.GetFieldByName("deliveryID"))
   245  	require.Equal(t, vl1Write.GetFieldByName("attributes"), m.GetFieldByName("attributes"))
   246  
   247  	require.True(t, iter.Next(), "iter err: %v", iter.Err())
   248  	dp, unit, annotation = iter.Current()
   249  	m = dynamic.NewMessage(testVLSchema)
   250  	require.NoError(t, m.Unmarshal(annotation))
   251  	require.Equal(t, xtime.Second, unit)
   252  	require.Equal(t, vl2WriteTime, dp.TimestampNanos)
   253  	require.Equal(t, 5, len(m.GetKnownFields()))
   254  	require.Equal(t, vl2Write.GetFieldByName("latitude"), m.GetFieldByName("latitude"))
   255  	require.Equal(t, vl2Write.GetFieldByName("longitude"), m.GetFieldByName("longitude"))
   256  	require.Equal(t, vl1Write.GetFieldByName("attributes"), m.GetFieldByName("attributes"))
   257  	// vl2 doesn't contain these fields so they should have default values when they're
   258  	// decoded with a vl1 schema.
   259  	require.Equal(t, int64(0), m.GetFieldByName("epoch"))
   260  	require.Equal(t, []byte(nil), m.GetFieldByName("deliveryID"))
   261  	require.Equal(t, vl2Write.GetFieldByName("attributes"), m.GetFieldByName("attributes"))
   262  
   263  	require.False(t, iter.Next())
   264  	require.NoError(t, iter.Err())
   265  
   266  	// Try reading the stream just using the vl2 schema.
   267  	r = xio.NewBytesReader64(rawBytes)
   268  	iter = NewIterator(r, namespace.GetTestSchemaDescr(testVL2Schema), testEncodingOptions)
   269  
   270  	require.True(t, iter.Next(), "iter err: %v", iter.Err())
   271  	dp, unit, annotation = iter.Current()
   272  	m = dynamic.NewMessage(testVL2Schema)
   273  	require.NoError(t, m.Unmarshal(annotation))
   274  	require.Equal(t, xtime.Second, unit)
   275  	require.Equal(t, vl1WriteTime, dp.TimestampNanos)
   276  	require.Equal(t, 5, len(m.GetKnownFields()))
   277  	require.Equal(t, vl1Write.GetFieldByName("latitude"), m.GetFieldByName("latitude"))
   278  	require.Equal(t, vl1Write.GetFieldByName("longitude"), m.GetFieldByName("longitude"))
   279  	require.Equal(t, vl1Write.GetFieldByName("attributes"), m.GetFieldByName("attributes"))
   280  	// This field does not exist in VL1 so it should have a default value when decoding
   281  	// with a VL2 schema.
   282  	require.Equal(t, "", m.GetFieldByName("new_custom_field"))
   283  
   284  	// These fields don't exist in the vl2 schema so they should not be in the returned message.
   285  	_, err = m.TryGetFieldByName("epoch")
   286  	require.Error(t, err)
   287  	_, err = m.TryGetFieldByName("deliveryID")
   288  	require.Error(t, err)
   289  
   290  	require.True(t, iter.Next(), "iter err: %v", iter.Err())
   291  	dp, unit, annotation = iter.Current()
   292  	m = dynamic.NewMessage(testVL2Schema)
   293  	require.NoError(t, m.Unmarshal(annotation))
   294  	require.Equal(t, xtime.Second, unit)
   295  	require.Equal(t, vl2WriteTime, dp.TimestampNanos)
   296  	require.Equal(t, 5, len(m.GetKnownFields()))
   297  	require.Equal(t, vl2Write.GetFieldByName("latitude"), m.GetFieldByName("latitude"))
   298  	require.Equal(t, vl2Write.GetFieldByName("longitude"), m.GetFieldByName("longitude"))
   299  	require.Equal(t, vl2Write.GetFieldByName("new_custom_field"), m.GetFieldByName("new_custom_field"))
   300  	require.Equal(t, vl2Write.GetFieldByName("attributes"), m.GetFieldByName("attributes"))
   301  
   302  	// These fields don't exist in the vl2 schema so they should not be in the returned message.
   303  	_, err = m.TryGetFieldByName("epoch")
   304  	require.Error(t, err)
   305  	_, err = m.TryGetFieldByName("deliveryID")
   306  	require.Error(t, err)
   307  
   308  	require.False(t, iter.Next())
   309  	require.NoError(t, iter.Err())
   310  }
   311  
   312  func newTestEncoder(t xtime.UnixNano) *Encoder {
   313  	e := NewEncoder(t, testEncodingOptions)
   314  	e.Reset(t, 0, nil)
   315  
   316  	return e
   317  }
   318  
   319  func newVL(
   320  	lat, long float64,
   321  	epoch int64,
   322  	deliveryID []byte,
   323  	attributes map[string]string,
   324  ) *dynamic.Message {
   325  	newMessage := dynamic.NewMessage(testVLSchema)
   326  	newMessage.SetFieldByName("latitude", lat)
   327  	newMessage.SetFieldByName("longitude", long)
   328  	newMessage.SetFieldByName("deliveryID", deliveryID)
   329  	newMessage.SetFieldByName("epoch", epoch)
   330  	newMessage.SetFieldByName("attributes", attributes)
   331  
   332  	return newMessage
   333  }
   334  
   335  func newVL2(
   336  	lat, long float64,
   337  	attributes map[string]string,
   338  	newCustomField string,
   339  	newProtoField map[int]int,
   340  ) *dynamic.Message {
   341  	newMessage := dynamic.NewMessage(testVL2Schema)
   342  
   343  	newMessage.SetFieldByName("latitude", lat)
   344  	newMessage.SetFieldByName("longitude", long)
   345  	newMessage.SetFieldByName("attributes", attributes)
   346  	newMessage.SetFieldByName("new_custom_field", newCustomField)
   347  	newMessage.SetFieldByName("new_proto_field", newProtoField)
   348  
   349  	return newMessage
   350  }
   351  
   352  func newVLMessageDescriptor() *desc.MessageDescriptor {
   353  	return newVLMessageDescriptorFromFile("./testdata/vehicle_location.proto")
   354  }
   355  
   356  func newVL2MessageDescriptor() *desc.MessageDescriptor {
   357  	return newVLMessageDescriptorFromFile("./testdata/vehicle_location_schema_change.proto")
   358  }
   359  
   360  func newVLMessageDescriptorFromFile(protoSchemaPath string) *desc.MessageDescriptor {
   361  	fds, err := protoparse.Parser{}.ParseFiles(protoSchemaPath)
   362  	if err != nil {
   363  		panic(err)
   364  	}
   365  
   366  	vlMessage := fds[0].FindMessage("VehicleLocation")
   367  	if vlMessage == nil {
   368  		panic(errors.New("could not find VehicleLocation message in first file"))
   369  	}
   370  
   371  	return vlMessage
   372  }
   373  
   374  func assertAttributesEqual(t *testing.T, expected map[string]string, actual map[interface{}]interface{}) {
   375  	require.Equal(t, len(expected), len(actual))
   376  	for k, v := range expected {
   377  		require.Equal(t, v, actual[k].(string))
   378  	}
   379  }