github.com/decred/politeia@v1.4.0/politeiad/backendv2/tstorebe/store/mysql/mysql_test.go (about)

     1  // Copyright (c) 2021-2022 The Decred developers
     2  // Use of this source code is governed by an ISC
     3  // license that can be found in the LICENSE file.
     4  
     5  package mysql
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"testing"
    11  
    12  	"github.com/DATA-DOG/go-sqlmock"
    13  	"github.com/decred/politeia/util"
    14  	"github.com/decred/politeia/util/unittest"
    15  )
    16  
    17  // newTestMySQL returns a new mysql structure that has been setup for testing.
    18  func newTestMySQL(t *testing.T) (*mysqlCtx, func()) {
    19  	t.Helper()
    20  
    21  	// Setup the mock sql database
    22  	opt := sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)
    23  	db, mock, err := sqlmock.New(opt)
    24  	if err != nil {
    25  		t.Fatal(err)
    26  	}
    27  	cleanup := func() {
    28  		defer db.Close()
    29  	}
    30  
    31  	// Setup the mysql struct
    32  	s := &mysqlCtx{
    33  		db:      db,
    34  		testing: true,
    35  		mock:    mock,
    36  	}
    37  
    38  	// Derive a test encryption key
    39  	password := "passwordsosikrit"
    40  	s.argon2idKey(password, util.NewArgon2Params())
    41  
    42  	return s, cleanup
    43  }
    44  
    45  func TestGet(t *testing.T) {
    46  	// Setup the mysql test struct
    47  	s, cleanup := newTestMySQL(t)
    48  	defer cleanup()
    49  
    50  	// Test the single query code path
    51  	t.Run("single query", func(t *testing.T) {
    52  		testGetSingleQuery(t, s)
    53  	})
    54  
    55  	// Test the multiple query code path
    56  	t.Run("multi query", func(t *testing.T) {
    57  		testGetMultiQuery(t, s)
    58  	})
    59  }
    60  
    61  // testGetSingleQuery tests the mysql Get() method when the number of records
    62  // being retrieved can be fit into a single MySQL SELECT statement.
    63  func testGetSingleQuery(t *testing.T, s *mysqlCtx) {
    64  	var (
    65  		// Test params
    66  		key1   = "key1"
    67  		key2   = "key2"
    68  		value1 = []byte("value1")
    69  		value2 = []byte("value2")
    70  
    71  		// rows contains the rows that will be returned
    72  		// from the mocked sql query.
    73  		rows = sqlmock.NewRows([]string{"k", "v"}).
    74  			AddRow(key1, value1).
    75  			AddRow(key2, value2)
    76  	)
    77  
    78  	// Setup the sql expectations
    79  	s.mock.ExpectQuery("SELECT k, v FROM kv WHERE k IN (?,?);").
    80  		WithArgs(key1, key2).
    81  		WillReturnRows(rows).
    82  		RowsWillBeClosed()
    83  
    84  	// Run the test
    85  	blobs, err := s.Get([]string{key1, key2})
    86  	if err != nil {
    87  		t.Error(err)
    88  	}
    89  
    90  	// Verify the sql expectations
    91  	err = s.mock.ExpectationsWereMet()
    92  	if err != nil {
    93  		t.Error(err)
    94  	}
    95  
    96  	// Verify the returned value
    97  	if len(blobs) != 2 {
    98  		t.Errorf("got %v blobs, want 2", len(blobs))
    99  	}
   100  	v1 := blobs[key1]
   101  	if !bytes.Equal(v1, value1) {
   102  		t.Errorf("got '%s' for value 1; want '%s'", v1, value1)
   103  	}
   104  	v2 := blobs[key2]
   105  	if !bytes.Equal(v2, value2) {
   106  		t.Errorf("got '%s' for value 2; want '%s'", v2, value2)
   107  	}
   108  }
   109  
   110  // testGetMultiQuery tests the mysql Get() method when the number of records
   111  // being retrieved cannot fit into a single MySQL SELECT statement and must
   112  // be broken up into multiple SELECT statements.
   113  func testGetMultiQuery(t *testing.T, s *mysqlCtx) {
   114  	// Prepare the test data. The maximum number of records
   115  	// that can be returned in a single SELECT statement is
   116  	// limited by the maxPlaceholders variable. We multiply
   117  	// this by 2 in order to ensure that multiple queries
   118  	// are required.
   119  	var (
   120  		keysCount = maxPlaceholders * 2
   121  		keys      = make([]string, 0, keysCount)
   122  
   123  		// These variables contain the rows that will be
   124  		// returned from each mocked sql query.
   125  		rows1 = sqlmock.NewRows([]string{"k", "v"})
   126  		rows2 = sqlmock.NewRows([]string{"k", "v"})
   127  	)
   128  	for i := 0; i < keysCount; i++ {
   129  		key := fmt.Sprintf("key%v", i)
   130  		value := []byte(fmt.Sprintf("value%v", i))
   131  		keys = append(keys, key)
   132  
   133  		if i < keysCount/2 {
   134  			// Add to the first query results
   135  			rows1.AddRow(key, value)
   136  		} else {
   137  			// Add to the second query results
   138  			rows2.AddRow(key, value)
   139  		}
   140  	}
   141  
   142  	// Setup the sql expectations for both queries
   143  	query := buildSelectQuery(keysCount / 2)
   144  	s.mock.ExpectQuery(query).
   145  		WillReturnRows(rows1).
   146  		RowsWillBeClosed()
   147  	s.mock.ExpectQuery(query).
   148  		WillReturnRows(rows2).
   149  		RowsWillBeClosed()
   150  
   151  	// Run the test
   152  	blobs, err := s.Get(keys)
   153  	if err != nil {
   154  		t.Errorf("multi query get failed; skipped printing " +
   155  			"the error for readability")
   156  	}
   157  
   158  	// Verify the sql expectations
   159  	err = s.mock.ExpectationsWereMet()
   160  	if err != nil {
   161  		t.Errorf("multi query sql expectations were not ; " +
   162  			"met; skipped printing the error for readability")
   163  	}
   164  
   165  	// Verify the returned values contain entries from both
   166  	// queries.
   167  	var (
   168  		idx1 = keysCount/2 - 1
   169  		idx2 = keysCount/2 + 2
   170  
   171  		key1 = fmt.Sprintf("key%v", idx1)
   172  		key2 = fmt.Sprintf("key%v", idx2)
   173  
   174  		value1 = []byte(fmt.Sprintf("value%v", idx1))
   175  		value2 = []byte(fmt.Sprintf("value%v", idx2))
   176  	)
   177  	if len(blobs) != keysCount {
   178  		t.Errorf("got %v blobs, want %v", len(blobs), keysCount)
   179  	}
   180  	v1 := blobs[key1]
   181  	if !bytes.Equal(v1, value1) {
   182  		t.Errorf("got '%s' for value 1; want '%s'", v1, value1)
   183  	}
   184  	v2 := blobs[key2]
   185  	if !bytes.Equal(v2, value2) {
   186  		t.Errorf("got '%s' for value 2; want '%s'", v2, value2)
   187  	}
   188  }
   189  
   190  func TestBuildSelectStatements(t *testing.T) {
   191  	var (
   192  		// sizeLimit is the max number of placeholders
   193  		// that the function will include in a single
   194  		// select statement.
   195  		sizeLimit = 2
   196  
   197  		// Test keys
   198  		key1 = "key1"
   199  		key2 = "key2"
   200  		key3 = "key3"
   201  		key4 = "key4"
   202  	)
   203  	var tests = []struct {
   204  		name       string
   205  		keys       []string
   206  		statements []selectStatement
   207  	}{
   208  		{
   209  			"one statement under the size limit",
   210  			[]string{key1},
   211  			[]selectStatement{
   212  				{
   213  					Query: buildSelectQuery(1),
   214  					Args:  []interface{}{key1},
   215  				},
   216  			},
   217  		},
   218  		{
   219  			"one statement at the size limit",
   220  			[]string{key1, key2},
   221  			[]selectStatement{
   222  				{
   223  					Query: buildSelectQuery(2),
   224  					Args:  []interface{}{key1, key2},
   225  				},
   226  			},
   227  		},
   228  		{
   229  			"second statement under the size limit",
   230  			[]string{key1, key2, key3},
   231  			[]selectStatement{
   232  				{
   233  					Query: buildSelectQuery(2),
   234  					Args:  []interface{}{key1, key2},
   235  				},
   236  				{
   237  					Query: buildSelectQuery(1),
   238  					Args:  []interface{}{key3},
   239  				},
   240  			},
   241  		},
   242  		{
   243  			"second statement at the size limit",
   244  			[]string{key1, key2, key3, key4},
   245  			[]selectStatement{
   246  				{
   247  					Query: buildSelectQuery(2),
   248  					Args:  []interface{}{key1, key2},
   249  				},
   250  				{
   251  					Query: buildSelectQuery(2),
   252  					Args:  []interface{}{key3, key4},
   253  				},
   254  			},
   255  		},
   256  	}
   257  	for _, tc := range tests {
   258  		t.Run(tc.name, func(t *testing.T) {
   259  			// Run the test
   260  			statements := buildSelectStatements(tc.keys, sizeLimit)
   261  
   262  			// Verify the output
   263  			diff := unittest.DeepEqual(statements, tc.statements)
   264  			if diff != "" {
   265  				t.Error(diff)
   266  			}
   267  		})
   268  	}
   269  }
   270  
   271  func TestBuildPlaceholders(t *testing.T) {
   272  	var tests = []struct {
   273  		placeholders int
   274  		output       string
   275  	}{
   276  		{0, "()"},
   277  		{1, "(?)"},
   278  		{3, "(?,?,?)"},
   279  	}
   280  	for _, tc := range tests {
   281  		name := fmt.Sprintf("%v placeholders", tc.placeholders)
   282  		t.Run(name, func(t *testing.T) {
   283  			output := buildPlaceholders(tc.placeholders)
   284  			if output != tc.output {
   285  				t.Errorf("got %v, want %v", output, tc.output)
   286  			}
   287  		})
   288  	}
   289  }