github.com/Azareal/Gosora@v0.0.0-20210729070923-553e66b59003/common/mfa_store.go (about)

     1  package common
     2  
     3  import (
     4  	"database/sql"
     5  	"errors"
     6  	"strings"
     7  
     8  	qgen "github.com/Azareal/Gosora/query_gen"
     9  )
    10  
    11  var MFAstore MFAStore
    12  var ErrMFAScratchIndexOutOfBounds = errors.New("That MFA scratch index is out of bounds")
    13  
    14  type MFAItemStmts struct {
    15  	update *sql.Stmt
    16  	delete *sql.Stmt
    17  }
    18  
    19  var mfaItemStmts MFAItemStmts
    20  
    21  func init() {
    22  	DbInits.Add(func(acc *qgen.Accumulator) error {
    23  		mfaItemStmts = MFAItemStmts{
    24  			update: acc.Update("users_2fa_keys").Set("scratch1=?,scratch2=?,scratch3=?,scratch4=?,scratch5=?,scratch6=?,scratch7=?,scratch8=?").Where("uid=?").Prepare(),
    25  			delete: acc.Delete("users_2fa_keys").Where("uid=?").Prepare(),
    26  		}
    27  		return acc.FirstError()
    28  	})
    29  }
    30  
    31  type MFAItem struct {
    32  	UID     int
    33  	Secret  string
    34  	Scratch []string
    35  }
    36  
    37  func (i *MFAItem) BurnScratch(index int) error {
    38  	if index < 0 || len(i.Scratch) <= index {
    39  		return ErrMFAScratchIndexOutOfBounds
    40  	}
    41  	newScratch, err := mfaCreateScratch()
    42  	if err != nil {
    43  		return err
    44  	}
    45  	i.Scratch[index] = newScratch
    46  
    47  	_, err = mfaItemStmts.update.Exec(i.Scratch[0], i.Scratch[1], i.Scratch[2], i.Scratch[3], i.Scratch[4], i.Scratch[5], i.Scratch[6], i.Scratch[7], i.UID)
    48  	return err
    49  }
    50  
    51  func (i *MFAItem) Delete() error {
    52  	_, err := mfaItemStmts.delete.Exec(i.UID)
    53  	return err
    54  }
    55  
    56  func mfaCreateScratch() (string, error) {
    57  	code, err := GenerateStd32SafeString(8)
    58  	return strings.Replace(code, "=", "", -1), err
    59  }
    60  
    61  type MFAStore interface {
    62  	Get(id int) (*MFAItem, error)
    63  	Create(secret string, uid int) (err error)
    64  }
    65  
    66  type SQLMFAStore struct {
    67  	get    *sql.Stmt
    68  	create *sql.Stmt
    69  }
    70  
    71  func NewSQLMFAStore(acc *qgen.Accumulator) (*SQLMFAStore, error) {
    72  	return &SQLMFAStore{
    73  		get:    acc.Select("users_2fa_keys").Columns("secret,scratch1,scratch2,scratch3,scratch4,scratch5,scratch6,scratch7,scratch8").Where("uid=?").Prepare(),
    74  		create: acc.Insert("users_2fa_keys").Columns("uid,secret,scratch1,scratch2,scratch3,scratch4,scratch5,scratch6,scratch7,scratch8,createdAt").Fields("?,?,?,?,?,?,?,?,?,?,UTC_TIMESTAMP()").Prepare(),
    75  	}, acc.FirstError()
    76  }
    77  
    78  // TODO: Write a test for this
    79  func (s *SQLMFAStore) Get(id int) (*MFAItem, error) {
    80  	i := MFAItem{UID: id, Scratch: make([]string, 8)}
    81  	err := s.get.QueryRow(id).Scan(&i.Secret, &i.Scratch[0], &i.Scratch[1], &i.Scratch[2], &i.Scratch[3], &i.Scratch[4], &i.Scratch[5], &i.Scratch[6], &i.Scratch[7])
    82  	return &i, err
    83  
    84  }
    85  
    86  // TODO: Write a test for this
    87  func (s *SQLMFAStore) Create(secret string, uid int) (err error) {
    88  	params := make([]interface{}, 10)
    89  	params[0] = uid
    90  	params[1] = secret
    91  	for i := 2; i < len(params); i++ {
    92  		code, err := mfaCreateScratch()
    93  		if err != nil {
    94  			return err
    95  		}
    96  		params[i] = code
    97  	}
    98  
    99  	_, err = s.create.Exec(params...)
   100  	return err
   101  }