github.com/decred/politeia@v1.4.0/politeiawww/sessions/mysql/mysql_test.go (about) 1 // Copyright (c) 2022 The Decred developers 2 // Use of this source code is governed by an ISC 3 // license that can be found in the LICENSE file. 4 5 package mysql 6 7 import ( 8 "database/sql" 9 "database/sql/driver" 10 "encoding/json" 11 "errors" 12 "fmt" 13 "testing" 14 15 "github.com/DATA-DOG/go-sqlmock" 16 "github.com/decred/politeia/politeiawww/sessions" 17 ) 18 19 // newTestMySQL returns a mysql context that has been setup for testing along 20 // with the sql mocking context and a cleanup function. Invocation of the 21 // cleanup function should be deferred by the caller. 22 func newTestMySQL(t *testing.T) (*mysql, sqlmock.Sqlmock, func()) { 23 t.Helper() 24 25 // sqlmock defaults to using the expected SQL string as a regular 26 // expression to match incoming query strings. The QueryMatcherEqual 27 // overrides this default behavior and does a full case sensitive 28 // match. 29 opts := sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual) 30 db, mock, err := sqlmock.New(opts) 31 if err != nil { 32 t.Fatal(err) 33 } 34 cleanup := func() { 35 defer db.Close() 36 } 37 m := &mysql{ 38 db: db, 39 sessionMaxAge: 1, 40 opts: &Opts{ 41 TableName: defaultTableName, 42 OpTimeout: defaultOpTimeout, 43 }, 44 } 45 46 return m, mock, cleanup 47 } 48 49 func TestSave(t *testing.T) { 50 m, mock, cleanup := newTestMySQL(t) 51 defer cleanup() 52 53 // Setup the test data 54 var ( 55 sessionID = "test-session-id" 56 es = sessions.EncodedSession{ 57 Values: "test-values", 58 } 59 ) 60 esB, err := json.Marshal(es) 61 if err != nil { 62 t.Fatal(err) 63 } 64 65 q := `INSERT INTO %v 66 (id, encoded_session, created_at) VALUES (?, ?, ?) 67 ON DUPLICATE KEY UPDATE 68 encoded_session = VALUES(encoded_session)` 69 70 q = fmt.Sprintf(q, m.opts.TableName) 71 72 // Test the unexpected error path 73 unexpectedErr := errors.New("unexpected error") 74 mock.ExpectExec(q). 75 WithArgs(sessionID, esB, AnyInt64{}). 76 WillReturnError(unexpectedErr) 77 78 err = m.Save(sessionID, es) 79 if !errors.Is(err, unexpectedErr) { 80 t.Errorf("got err '%v', want '%v'", err, unexpectedErr) 81 } 82 83 // Test the success path 84 mock.ExpectExec(q). 85 WithArgs(sessionID, esB, AnyInt64{}). 86 WillReturnResult(sqlmock.NewResult(0, 1)) 87 88 err = m.Save(sessionID, es) 89 if err != nil { 90 t.Error(err) 91 } 92 } 93 94 func TestDel(t *testing.T) { 95 m, mock, cleanup := newTestMySQL(t) 96 defer cleanup() 97 98 // Setup the test data 99 var ( 100 q = fmt.Sprintf("DELETE FROM %v WHERE id = ?", m.opts.TableName) 101 102 sessionID = "test-session-id" 103 ) 104 105 // Test the unexpected error path 106 unexpectedErr := errors.New("unexpected error") 107 mock.ExpectExec(q). 108 WithArgs(sessionID). 109 WillReturnError(unexpectedErr) 110 111 err := m.Del(sessionID) 112 if !errors.Is(err, unexpectedErr) { 113 t.Errorf("got err '%v', want '%v'", err, unexpectedErr) 114 } 115 116 // Test the success path 117 mock.ExpectExec(q). 118 WithArgs(sessionID). 119 WillReturnResult(sqlmock.NewResult(0, 1)) 120 121 err = m.Del(sessionID) 122 if err != nil { 123 t.Error(err) 124 } 125 } 126 127 func TestGet(t *testing.T) { 128 m, mock, cleanup := newTestMySQL(t) 129 defer cleanup() 130 131 // Setup the test data 132 var ( 133 q = fmt.Sprintf("SELECT encoded_session FROM %v WHERE id = ?", 134 m.opts.TableName) 135 136 sessionID = "test-session-id" 137 es = sessions.EncodedSession{ 138 Values: "test-values", 139 } 140 ) 141 esB, err := json.Marshal(es) 142 if err != nil { 143 t.Fatal(err) 144 } 145 146 // Test the not found error path 147 mock.ExpectQuery(q). 148 WithArgs(sessionID). 149 WillReturnError(sql.ErrNoRows) 150 151 _, err = m.Get(sessionID) 152 if !errors.Is(err, sessions.ErrNotFound) { 153 t.Errorf("got err '%v', want '%v'", err, sessions.ErrNotFound) 154 } 155 156 // Test the unexpected error path 157 unexpectedErr := errors.New("unexpected error") 158 mock.ExpectQuery(q). 159 WithArgs(sessionID). 160 WillReturnError(unexpectedErr) 161 162 _, err = m.Get(sessionID) 163 if !errors.Is(err, unexpectedErr) { 164 t.Errorf("got err '%v', want '%v'", err, unexpectedErr) 165 } 166 167 // Test the success path 168 rows := sqlmock.NewRows([]string{"encoded_session"}).AddRow(esB) 169 mock.ExpectQuery(q). 170 WithArgs(sessionID). 171 WillReturnRows(rows) 172 173 r, err := m.Get(sessionID) 174 switch { 175 case err != nil: 176 t.Error(err) 177 case r == nil: 178 t.Errorf("got nil session, want %+v", es) 179 case r.Values != es.Values: 180 t.Errorf("got sesions values '%v', want '%v'", r.Values, es.Values) 181 } 182 } 183 184 // AnyInt64 can be passed in as a sqlmock prepared statement argument when the 185 // caller knows that the argument will be an int64, but does not know what the 186 // exact value of the int64 will be. 187 type AnyInt64 struct{} 188 189 // Match satisfies sqlmock Argument interface. 190 func (a AnyInt64) Match(v driver.Value) bool { 191 _, ok := v.(int64) 192 return ok 193 }