github.com/grafviktor/keep-my-secret@v0.9.10-0.20230908165355-19f35cce90e5/internal/storage/sql/sql_storage.go (about)

     1  package sql
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"errors"
     7  	"fmt"
     8  
     9  	"github.com/mattn/go-sqlite3"
    10  
    11  	"github.com/grafviktor/keep-my-secret/internal/constant"
    12  	"github.com/grafviktor/keep-my-secret/internal/model"
    13  )
    14  
    15  type sqlStorage struct {
    16  	*sql.DB
    17  }
    18  
    19  func (ss sqlStorage) AddUser(ctx context.Context, u *model.User) (*model.User, error) {
    20  	_, err := ss.ExecContext(ctx, sqlInsertUser, u.Login, u.HashedPassword, "", u.DataKey)
    21  	if err != nil {
    22  		var sqliteErr sqlite3.Error
    23  		if errors.As(err, &sqliteErr) {
    24  			if sqliteErr.Code == sqlite3.ErrConstraint {
    25  				return nil, constant.ErrDuplicateRecord
    26  			}
    27  		}
    28  		return nil, err
    29  	}
    30  
    31  	return u, nil
    32  }
    33  
    34  func (ss sqlStorage) GetUser(ctx context.Context, login string) (*model.User, error) {
    35  	u := model.User{}
    36  	err := ss.QueryRowContext(ctx, sqlSelectUser, login).
    37  		Scan(&u.ID, &u.Login, &u.HashedPassword, &u.RestorePassword, &u.DataKey)
    38  
    39  	switch {
    40  	case errors.Is(err, sql.ErrNoRows):
    41  		return nil, constant.ErrNotFound
    42  	case err != nil:
    43  		return nil, err
    44  	}
    45  
    46  	return &u, nil
    47  }
    48  
    49  func (ss sqlStorage) SaveSecret(ctx context.Context, s *model.Secret, login string) (*model.Secret, error) {
    50  	var result sql.Result
    51  	var err error
    52  
    53  	if s.ID != 0 {
    54  		result, err = ss.ExecContext(
    55  			ctx,
    56  			sqlUpdateSecret,
    57  			s.Type,
    58  			s.Title,
    59  			s.Login,
    60  			s.Password,
    61  			s.Note,
    62  			s.FileName,
    63  			s.CardholderName,
    64  			s.CardNumber,
    65  			s.Expiration,
    66  			s.SecurityCode,
    67  			s.ID,
    68  			login,
    69  		)
    70  	} else {
    71  		result, err = ss.ExecContext(
    72  			ctx,
    73  			sqlInsertSecret,
    74  			s.Type,
    75  			s.Title,
    76  			s.Login,
    77  			s.Password,
    78  			s.Note,
    79  			s.File,
    80  			s.FileName,
    81  			s.CardholderName,
    82  			s.CardNumber,
    83  			s.Expiration,
    84  			s.SecurityCode,
    85  			login,
    86  		)
    87  	}
    88  
    89  	if err != nil {
    90  		return nil, err
    91  	}
    92  
    93  	insertedID, err := result.LastInsertId()
    94  	if err != nil {
    95  		return nil, err
    96  	}
    97  
    98  	s.ID = insertedID
    99  
   100  	return s, nil
   101  }
   102  
   103  func (ss sqlStorage) GetSecretsByUser(ctx context.Context, login string) (map[int]*model.Secret, error) {
   104  	result := make(map[int]*model.Secret)
   105  	rows, err := ss.QueryContext(ctx, sqlFindSecretsByUser, login)
   106  	if errors.Is(err, sql.ErrNoRows) {
   107  		// This case is valid, though user exists, she doesn't have any secrets
   108  		return result, nil
   109  	} else if err != nil {
   110  		return nil, err
   111  	}
   112  	defer rows.Close()
   113  
   114  	for rows.Next() {
   115  		var secret model.Secret
   116  
   117  		err = rows.Scan(
   118  			&secret.ID,
   119  			&secret.Type,
   120  			&secret.Title,
   121  			&secret.Login,
   122  			&secret.Password,
   123  			&secret.Note,
   124  			&secret.File,
   125  			&secret.FileName,
   126  			&secret.CardholderName,
   127  			&secret.CardNumber,
   128  			&secret.Expiration,
   129  			&secret.SecurityCode,
   130  		)
   131  		if err != nil {
   132  			return nil, err
   133  		}
   134  
   135  		result[int(secret.ID)] = &secret
   136  	}
   137  
   138  	if err = rows.Err(); err != nil {
   139  		return nil, err
   140  	}
   141  
   142  	return result, nil
   143  }
   144  
   145  func (ss sqlStorage) DeleteSecret(ctx context.Context, id, login string) error {
   146  	result, err := ss.ExecContext(ctx, sqlDeleteSecret, id, login)
   147  	if err != nil {
   148  		return err
   149  	}
   150  
   151  	rows, err := result.RowsAffected()
   152  	if err != nil {
   153  		return err
   154  	}
   155  
   156  	if rows != 1 {
   157  		return fmt.Errorf("expected to affect 1 row, affected %d", rows)
   158  	}
   159  
   160  	return nil
   161  }
   162  
   163  func (ss sqlStorage) GetSecret(ctx context.Context, secretID, login string) (*model.Secret, error) {
   164  	secret := model.Secret{}
   165  
   166  	err := ss.QueryRowContext(ctx, sqlGetSecretByID, secretID, login).Scan(
   167  		&secret.ID,
   168  		&secret.Type,
   169  		&secret.Title,
   170  		&secret.Login,
   171  		&secret.Password,
   172  		&secret.Note,
   173  		&secret.File,
   174  		&secret.FileName,
   175  		&secret.CardholderName,
   176  		&secret.CardNumber,
   177  		&secret.Expiration,
   178  		&secret.SecurityCode,
   179  	)
   180  
   181  	switch {
   182  	case errors.Is(err, sql.ErrNoRows):
   183  		return nil, constant.ErrNotFound
   184  	case err != nil:
   185  		return nil, err
   186  	}
   187  
   188  	return &secret, nil
   189  }
   190  
   191  func (ss sqlStorage) Close() error {
   192  	return ss.DB.Close()
   193  }
   194  
   195  func NewSQLStorage(ctx context.Context, dsn string) sqlStorage {
   196  	db, err := sql.Open("sqlite3", "./kms.db")
   197  	if err != nil {
   198  		panic(err)
   199  	}
   200  
   201  	_, err = db.Exec(sqlCreateUserTable)
   202  	if err != nil {
   203  		panic(err)
   204  	}
   205  
   206  	_, err = db.Exec(sqlCreateSecretTable)
   207  	if err != nil {
   208  		panic(err)
   209  	}
   210  
   211  	return sqlStorage{
   212  		DB: db,
   213  	}
   214  }