github.com/m3db/m3@v1.5.0/src/dbnode/testdata/prototest/fixture.go (about)

     1  // Copyright (c) 2019 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package prototest
    22  
    23  import (
    24  	"bytes"
    25  	"io/ioutil"
    26  	"os"
    27  	"path/filepath"
    28  	"testing"
    29  	"time"
    30  
    31  	"github.com/jhump/protoreflect/desc"
    32  	"github.com/jhump/protoreflect/dynamic"
    33  	"github.com/m3db/m3/src/dbnode/namespace"
    34  	"github.com/stretchr/testify/require"
    35  )
    36  
    37  const (
    38  	protoStr = `syntax = "proto3";
    39  package mainpkg;
    40  
    41  message TestMessage {
    42    double latitude = 1;
    43    double longitude = 2;
    44    int64 epoch = 3;
    45    bytes deliveryID = 4;
    46    map<string, string> attributes = 5;
    47  }
    48  `
    49  )
    50  
    51  type TestMessage struct {
    52  	timestamp  time.Time
    53  	latitude   float64
    54  	longitude  float64
    55  	epoch      int64
    56  	deliveryID []byte
    57  	attributes map[string]string
    58  }
    59  
    60  func NewSchemaHistory() namespace.SchemaHistory {
    61  	tempDir, err := ioutil.TempDir("", "m3dbnode-prototest")
    62  	if err != nil {
    63  		panic(err)
    64  	}
    65  	defer os.RemoveAll(tempDir)
    66  
    67  	testProtoFile := filepath.Join(tempDir, "test.proto")
    68  	err = ioutil.WriteFile(testProtoFile, []byte(protoStr), 0666)
    69  	if err != nil {
    70  		panic(err)
    71  	}
    72  
    73  	schemaHis, err := namespace.LoadSchemaHistory(namespace.GenTestSchemaOptions(testProtoFile))
    74  	if err != nil {
    75  		panic(err)
    76  	}
    77  	return schemaHis
    78  }
    79  
    80  func NewMessageDescriptor(his namespace.SchemaHistory) *desc.MessageDescriptor {
    81  	schema, ok := his.GetLatest()
    82  	if !ok {
    83  		panic("schema history is empty")
    84  	}
    85  	return schema.Get().MessageDescriptor
    86  }
    87  
    88  func NewProtoTestMessages(md *desc.MessageDescriptor) []*dynamic.Message {
    89  	testFixtures := []TestMessage{
    90  		{
    91  			latitude:  0.1,
    92  			longitude: 1.1,
    93  			epoch:     -1,
    94  		},
    95  		{
    96  			latitude:   0.1,
    97  			longitude:  1.1,
    98  			epoch:      0,
    99  			deliveryID: []byte("123123123123"),
   100  			attributes: map[string]string{"key1": "val1"},
   101  		},
   102  		{
   103  			latitude:   0.2,
   104  			longitude:  2.2,
   105  			epoch:      1,
   106  			deliveryID: []byte("789789789789"),
   107  			attributes: map[string]string{"key1": "val1"},
   108  		},
   109  		{
   110  			latitude:   0.3,
   111  			longitude:  2.3,
   112  			epoch:      2,
   113  			deliveryID: []byte("123123123123"),
   114  		},
   115  		{
   116  			latitude:   0.4,
   117  			longitude:  2.4,
   118  			epoch:      3,
   119  			attributes: map[string]string{"key1": "val1"},
   120  		},
   121  		{
   122  			latitude:   0.5,
   123  			longitude:  2.5,
   124  			epoch:      4,
   125  			deliveryID: []byte("456456456456"),
   126  			attributes: map[string]string{
   127  				"key1": "val1",
   128  				"key2": "val2",
   129  			},
   130  		},
   131  		{
   132  			latitude:   0.6,
   133  			longitude:  2.6,
   134  			deliveryID: nil,
   135  		},
   136  		{
   137  			latitude:   0.5,
   138  			longitude:  2.5,
   139  			deliveryID: []byte("789789789789"),
   140  		},
   141  	}
   142  
   143  	msgs := make([]*dynamic.Message, len(testFixtures))
   144  	for i := 0; i < len(msgs); i++ {
   145  		newMessage := dynamic.NewMessage(md)
   146  		newMessage.SetFieldByName("latitude", testFixtures[i].latitude)
   147  		newMessage.SetFieldByName("longitude", testFixtures[i].longitude)
   148  		newMessage.SetFieldByName("deliveryID", testFixtures[i].deliveryID)
   149  		newMessage.SetFieldByName("epoch", testFixtures[i].epoch)
   150  		newMessage.SetFieldByName("attributes", testFixtures[i].attributes)
   151  		msgs[i] = newMessage
   152  	}
   153  
   154  	return msgs
   155  }
   156  
   157  func RequireEqual(t *testing.T, md *desc.MessageDescriptor, expected, actual []byte) {
   158  	expectedMsg := dynamic.NewMessage(md)
   159  	require.NoError(t, expectedMsg.Unmarshal(expected))
   160  	actualMsg := dynamic.NewMessage(md)
   161  	require.NoError(t, actualMsg.Unmarshal(actual))
   162  
   163  	require.Equal(t, expectedMsg.GetFieldByName("latitude"),
   164  		actualMsg.GetFieldByName("latitude"))
   165  	require.Equal(t, expectedMsg.GetFieldByName("longitude"),
   166  		actualMsg.GetFieldByName("longitude"))
   167  	require.Equal(t, expectedMsg.GetFieldByName("deliveryID"),
   168  		actualMsg.GetFieldByName("deliveryID"))
   169  	require.Equal(t, expectedMsg.GetFieldByName("epoch"),
   170  		actualMsg.GetFieldByName("epoch"))
   171  	requireAttributesEqual(t, expectedMsg.GetFieldByName("attributes").(map[interface{}]interface{}),
   172  		actualMsg.GetFieldByName("attributes").(map[interface{}]interface{}))
   173  }
   174  
   175  func requireAttributesEqual(t *testing.T, expected, actual map[interface{}]interface{}) {
   176  	require.Equal(t, len(expected), len(actual))
   177  	for k, v := range expected {
   178  		require.Equal(t, v, actual[k])
   179  	}
   180  }
   181  
   182  func ProtoEqual(md *desc.MessageDescriptor, expected, actual []byte) bool {
   183  	expectedMsg := dynamic.NewMessage(md)
   184  	if expectedMsg.Unmarshal(expected) != nil {
   185  		return false
   186  	}
   187  	actualMsg := dynamic.NewMessage(md)
   188  	if actualMsg.Unmarshal(actual) != nil {
   189  		return false
   190  	}
   191  
   192  	if expectedMsg.GetFieldByName("latitude") !=
   193  		actualMsg.GetFieldByName("latitude") {
   194  		return false
   195  	}
   196  	if expectedMsg.GetFieldByName("longitude") !=
   197  		actualMsg.GetFieldByName("longitude") {
   198  		return false
   199  	}
   200  	if !bytes.Equal(expectedMsg.GetFieldByName("deliveryID").([]byte),
   201  		actualMsg.GetFieldByName("deliveryID").([]byte)) {
   202  		return false
   203  	}
   204  	if expectedMsg.GetFieldByName("epoch") !=
   205  		actualMsg.GetFieldByName("epoch") {
   206  		return false
   207  	}
   208  	return attributesEqual(expectedMsg.GetFieldByName("attributes").(map[interface{}]interface{}),
   209  		actualMsg.GetFieldByName("attributes").(map[interface{}]interface{}))
   210  }
   211  
   212  func attributesEqual(expected, actual map[interface{}]interface{}) bool {
   213  	if len(expected) != len(actual) {
   214  		return false
   215  	}
   216  	for k, v := range expected {
   217  		if v.(string) != actual[k].(string) {
   218  			return false
   219  		}
   220  	}
   221  	return true
   222  }
   223  
// ProtoMessageIterator serves the marshaled form of a fixed list of messages
// one at a time, wrapping around to the first message after the last one.
type ProtoMessageIterator struct {
	// messages is the list Next cycles through.
	messages []*dynamic.Message
	// i counts the Next calls made since construction or the last Reset.
	i        int
}
   228  
   229  func NewProtoMessageIterator(messages []*dynamic.Message) *ProtoMessageIterator {
   230  	return &ProtoMessageIterator{messages: messages}
   231  }
   232  
   233  func (pmi *ProtoMessageIterator) Next() []byte {
   234  	n := pmi.messages[pmi.i%len(pmi.messages)]
   235  	pmi.i++
   236  	mbytes, err := n.Marshal()
   237  	if err != nil {
   238  		panic(err.Error())
   239  	}
   240  	return mbytes
   241  }
   242  
// Reset rewinds the iterator so that the next call to Next serves the first
// message again.
func (pmi *ProtoMessageIterator) Reset() {
	pmi.i = 0
}