github.com/decred/politeia@v1.4.0/politeiawww/legacy/user/localdb/localdb_test.go (about)

     1  // Copyright (c) 2020 The Decred developers
     2  // Use of this source code is governed by an ISC
     3  // license that can be found in the LICENSE file.
     4  
     5  package localdb
     6  
     7  import (
     8  	"encoding/base32"
     9  	"errors"
    10  	"os"
    11  	"path/filepath"
    12  	"testing"
    13  
    14  	"github.com/decred/politeia/politeiawww/legacy/user"
    15  	"github.com/google/uuid"
    16  	"github.com/gorilla/securecookie"
    17  )
    18  
    19  func setupTestData(t *testing.T) (*localdb, string) {
    20  	t.Helper()
    21  
    22  	dataDir, err := os.MkdirTemp("", "politeiawww.user.localdb.test")
    23  	if err != nil {
    24  		t.Fatalf("tmp dir: %v", err)
    25  	}
    26  
    27  	db, err := New(filepath.Join(dataDir, "localdb"))
    28  	if err != nil {
    29  		t.Fatalf("setup database: %v", err)
    30  	}
    31  
    32  	return db, dataDir
    33  }
    34  
    35  func teardownTestData(t *testing.T, db *localdb, dataDir string) {
    36  	t.Helper()
    37  
    38  	err := db.Close()
    39  	if err != nil {
    40  		t.Fatalf("close db: %v", err)
    41  	}
    42  
    43  	err = os.RemoveAll(dataDir)
    44  	if err != nil {
    45  		t.Fatalf("remove tmp dir: %v", err)
    46  	}
    47  }
    48  
    49  func newSessionID() string {
    50  	return base32.StdEncoding.EncodeToString(securecookie.GenerateRandomKey(32))
    51  }
    52  
    53  func TestSessionSave(t *testing.T) {
    54  	db, dataDir := setupTestData(t)
    55  	defer teardownTestData(t, db, dataDir)
    56  
    57  	// Save session
    58  	s := user.Session{
    59  		ID:     newSessionID(),
    60  		UserID: uuid.New(),
    61  		Values: "v1",
    62  	}
    63  	err := db.SessionSave(s)
    64  	if err != nil {
    65  		t.Error(err)
    66  	}
    67  
    68  	// Verify session
    69  	b, err := db.userdb.Get([]byte(sessionPrefix+s.ID), nil)
    70  	if err != nil {
    71  		t.Error(err)
    72  	}
    73  	sessionInDB, err := user.DecodeSession(b)
    74  	if err != nil {
    75  		t.Error(err)
    76  	}
    77  	if *sessionInDB != s {
    78  		t.Errorf("got session %v, want %v", sessionInDB, s)
    79  	}
    80  
    81  	// Save a session that already exists
    82  	s.Values = "v2"
    83  	err = db.SessionSave(s)
    84  	if err != nil {
    85  		t.Error(err)
    86  	}
    87  
    88  	// Verify session was updated correctly
    89  	b, err = db.userdb.Get([]byte(sessionPrefix+s.ID), nil)
    90  	if err != nil {
    91  		t.Error(err)
    92  	}
    93  	sessionInDB, err = user.DecodeSession(b)
    94  	if err != nil {
    95  		t.Error(err)
    96  	}
    97  	if *sessionInDB != s {
    98  		t.Errorf("got session %v, want %v", sessionInDB, s)
    99  	}
   100  }
   101  
   102  func TestSessionGetByID(t *testing.T) {
   103  	db, dataDir := setupTestData(t)
   104  	defer teardownTestData(t, db, dataDir)
   105  
   106  	// Save session
   107  	s := user.Session{
   108  		ID:     newSessionID(),
   109  		UserID: uuid.New(),
   110  		Values: "",
   111  	}
   112  	err := db.SessionSave(s)
   113  	if err != nil {
   114  		t.Error(err)
   115  	}
   116  
   117  	// Get existing session
   118  	sessionInDB, err := db.SessionGetByID(s.ID)
   119  	if err != nil {
   120  		t.Error(err)
   121  	}
   122  	if *sessionInDB != s {
   123  		t.Errorf("got session %v, want %v", sessionInDB, s)
   124  	}
   125  
   126  	// Get session that does not exist
   127  	_, err = db.SessionGetByID(uuid.New().String())
   128  	if !errors.Is(err, user.ErrSessionNotFound) {
   129  		t.Errorf("got error '%v', want '%v'", err, user.ErrSessionNotFound)
   130  	}
   131  }
   132  
   133  func TestSessionDeleteByID(t *testing.T) {
   134  	db, dataDir := setupTestData(t)
   135  	defer teardownTestData(t, db, dataDir)
   136  
   137  	// Session 1
   138  	s1 := user.Session{
   139  		ID:     newSessionID(),
   140  		UserID: uuid.New(),
   141  		Values: "",
   142  	}
   143  
   144  	// Session 2
   145  	s2 := user.Session{
   146  		ID:     newSessionID(),
   147  		UserID: uuid.New(),
   148  		Values: "",
   149  	}
   150  
   151  	// Save sessions
   152  	err := db.SessionSave(s1)
   153  	if err != nil {
   154  		t.Fatal(err)
   155  	}
   156  	err = db.SessionSave(s2)
   157  	if err != nil {
   158  		t.Fatal(err)
   159  	}
   160  
   161  	// Delete one of the sessions
   162  	err = db.SessionDeleteByID(s1.ID)
   163  	if err != nil {
   164  		t.Error(err)
   165  	}
   166  
   167  	// Verify session was deleted
   168  	_, err = db.SessionGetByID(s1.ID)
   169  	if !errors.Is(err, user.ErrSessionNotFound) {
   170  		t.Errorf("error got '%v', want '%v'", err, user.ErrSessionNotFound)
   171  	}
   172  
   173  	// Verify the remaining session still exists
   174  	s2DB, err := db.SessionGetByID(s2.ID)
   175  	if err != nil {
   176  		t.Errorf("error got '%v', want nil", err)
   177  	}
   178  	if *s2DB != s2 {
   179  		t.Errorf("session got %v, want %v", s2DB, s2)
   180  	}
   181  }
   182  
   183  func TestIsUserRecord(t *testing.T) {
   184  	tests := []struct {
   185  		input string
   186  		want  bool
   187  	}{
   188  		{
   189  			input: UserVersionKey,
   190  			want:  false,
   191  		},
   192  		{
   193  			input: LastPaywallAddressIndex,
   194  			want:  false,
   195  		},
   196  		{
   197  			input: sessionPrefix + uuid.New().String(),
   198  			want:  false,
   199  		},
   200  	}
   201  
   202  	for _, test := range tests {
   203  		got := isUserRecord(test.input)
   204  		if got != test.want {
   205  			t.Errorf("isUserRecord(%v) got %v, want %v",
   206  				test.input, got, test.want)
   207  		}
   208  	}
   209  }