github.com/decred/politeia@v1.4.0/politeiawww/sessions/mysql/mysql_test.go (about)

     1  // Copyright (c) 2022 The Decred developers
     2  // Use of this source code is governed by an ISC
     3  // license that can be found in the LICENSE file.
     4  
     5  package mysql
     6  
     7  import (
     8  	"database/sql"
     9  	"database/sql/driver"
    10  	"encoding/json"
    11  	"errors"
    12  	"fmt"
    13  	"testing"
    14  
    15  	"github.com/DATA-DOG/go-sqlmock"
    16  	"github.com/decred/politeia/politeiawww/sessions"
    17  )
    18  
    19  // newTestMySQL returns a mysql context that has been setup for testing along
    20  // with the sql mocking context and a cleanup function. Invocation of the
    21  // cleanup function should be deferred by the caller.
    22  func newTestMySQL(t *testing.T) (*mysql, sqlmock.Sqlmock, func()) {
    23  	t.Helper()
    24  
    25  	// sqlmock defaults to using the expected SQL string as a regular
    26  	// expression to match incoming query strings. The QueryMatcherEqual
    27  	// overrides this default behavior and does a full case sensitive
    28  	// match.
    29  	opts := sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual)
    30  	db, mock, err := sqlmock.New(opts)
    31  	if err != nil {
    32  		t.Fatal(err)
    33  	}
    34  	cleanup := func() {
    35  		defer db.Close()
    36  	}
    37  	m := &mysql{
    38  		db:            db,
    39  		sessionMaxAge: 1,
    40  		opts: &Opts{
    41  			TableName: defaultTableName,
    42  			OpTimeout: defaultOpTimeout,
    43  		},
    44  	}
    45  
    46  	return m, mock, cleanup
    47  }
    48  
    49  func TestSave(t *testing.T) {
    50  	m, mock, cleanup := newTestMySQL(t)
    51  	defer cleanup()
    52  
    53  	// Setup the test data
    54  	var (
    55  		sessionID = "test-session-id"
    56  		es        = sessions.EncodedSession{
    57  			Values: "test-values",
    58  		}
    59  	)
    60  	esB, err := json.Marshal(es)
    61  	if err != nil {
    62  		t.Fatal(err)
    63  	}
    64  
    65  	q := `INSERT INTO %v
    66      (id, encoded_session, created_at) VALUES (?, ?, ?)
    67      ON DUPLICATE KEY UPDATE
    68      encoded_session = VALUES(encoded_session)`
    69  
    70  	q = fmt.Sprintf(q, m.opts.TableName)
    71  
    72  	// Test the unexpected error path
    73  	unexpectedErr := errors.New("unexpected error")
    74  	mock.ExpectExec(q).
    75  		WithArgs(sessionID, esB, AnyInt64{}).
    76  		WillReturnError(unexpectedErr)
    77  
    78  	err = m.Save(sessionID, es)
    79  	if !errors.Is(err, unexpectedErr) {
    80  		t.Errorf("got err '%v', want '%v'", err, unexpectedErr)
    81  	}
    82  
    83  	// Test the success path
    84  	mock.ExpectExec(q).
    85  		WithArgs(sessionID, esB, AnyInt64{}).
    86  		WillReturnResult(sqlmock.NewResult(0, 1))
    87  
    88  	err = m.Save(sessionID, es)
    89  	if err != nil {
    90  		t.Error(err)
    91  	}
    92  }
    93  
    94  func TestDel(t *testing.T) {
    95  	m, mock, cleanup := newTestMySQL(t)
    96  	defer cleanup()
    97  
    98  	// Setup the test data
    99  	var (
   100  		q = fmt.Sprintf("DELETE FROM %v WHERE id = ?", m.opts.TableName)
   101  
   102  		sessionID = "test-session-id"
   103  	)
   104  
   105  	// Test the unexpected error path
   106  	unexpectedErr := errors.New("unexpected error")
   107  	mock.ExpectExec(q).
   108  		WithArgs(sessionID).
   109  		WillReturnError(unexpectedErr)
   110  
   111  	err := m.Del(sessionID)
   112  	if !errors.Is(err, unexpectedErr) {
   113  		t.Errorf("got err '%v', want '%v'", err, unexpectedErr)
   114  	}
   115  
   116  	// Test the success path
   117  	mock.ExpectExec(q).
   118  		WithArgs(sessionID).
   119  		WillReturnResult(sqlmock.NewResult(0, 1))
   120  
   121  	err = m.Del(sessionID)
   122  	if err != nil {
   123  		t.Error(err)
   124  	}
   125  }
   126  
   127  func TestGet(t *testing.T) {
   128  	m, mock, cleanup := newTestMySQL(t)
   129  	defer cleanup()
   130  
   131  	// Setup the test data
   132  	var (
   133  		q = fmt.Sprintf("SELECT encoded_session FROM %v WHERE id = ?",
   134  			m.opts.TableName)
   135  
   136  		sessionID = "test-session-id"
   137  		es        = sessions.EncodedSession{
   138  			Values: "test-values",
   139  		}
   140  	)
   141  	esB, err := json.Marshal(es)
   142  	if err != nil {
   143  		t.Fatal(err)
   144  	}
   145  
   146  	// Test the not found error path
   147  	mock.ExpectQuery(q).
   148  		WithArgs(sessionID).
   149  		WillReturnError(sql.ErrNoRows)
   150  
   151  	_, err = m.Get(sessionID)
   152  	if !errors.Is(err, sessions.ErrNotFound) {
   153  		t.Errorf("got err '%v', want '%v'", err, sessions.ErrNotFound)
   154  	}
   155  
   156  	// Test the unexpected error path
   157  	unexpectedErr := errors.New("unexpected error")
   158  	mock.ExpectQuery(q).
   159  		WithArgs(sessionID).
   160  		WillReturnError(unexpectedErr)
   161  
   162  	_, err = m.Get(sessionID)
   163  	if !errors.Is(err, unexpectedErr) {
   164  		t.Errorf("got err '%v', want '%v'", err, unexpectedErr)
   165  	}
   166  
   167  	// Test the success path
   168  	rows := sqlmock.NewRows([]string{"encoded_session"}).AddRow(esB)
   169  	mock.ExpectQuery(q).
   170  		WithArgs(sessionID).
   171  		WillReturnRows(rows)
   172  
   173  	r, err := m.Get(sessionID)
   174  	switch {
   175  	case err != nil:
   176  		t.Error(err)
   177  	case r == nil:
   178  		t.Errorf("got nil session, want %+v", es)
   179  	case r.Values != es.Values:
   180  		t.Errorf("got sesions values '%v', want '%v'", r.Values, es.Values)
   181  	}
   182  }
   183  
   184  // AnyInt64 can be passed in as a sqlmock prepared statement argument when the
   185  // caller knows that the argument will be an int64, but does not know what the
   186  // exact value of the int64 will be.
   187  type AnyInt64 struct{}
   188  
   189  // Match satisfies sqlmock Argument interface.
   190  func (a AnyInt64) Match(v driver.Value) bool {
   191  	_, ok := v.(int64)
   192  	return ok
   193  }