github.com/Azareal/Gosora@v0.0.0-20210729070923-553e66b59003/common/mfa_store.go (about) 1 package common 2 3 import ( 4 "database/sql" 5 "errors" 6 "strings" 7 8 qgen "github.com/Azareal/Gosora/query_gen" 9 ) 10 11 var MFAstore MFAStore 12 var ErrMFAScratchIndexOutOfBounds = errors.New("That MFA scratch index is out of bounds") 13 14 type MFAItemStmts struct { 15 update *sql.Stmt 16 delete *sql.Stmt 17 } 18 19 var mfaItemStmts MFAItemStmts 20 21 func init() { 22 DbInits.Add(func(acc *qgen.Accumulator) error { 23 mfaItemStmts = MFAItemStmts{ 24 update: acc.Update("users_2fa_keys").Set("scratch1=?,scratch2=?,scratch3=?,scratch4=?,scratch5=?,scratch6=?,scratch7=?,scratch8=?").Where("uid=?").Prepare(), 25 delete: acc.Delete("users_2fa_keys").Where("uid=?").Prepare(), 26 } 27 return acc.FirstError() 28 }) 29 } 30 31 type MFAItem struct { 32 UID int 33 Secret string 34 Scratch []string 35 } 36 37 func (i *MFAItem) BurnScratch(index int) error { 38 if index < 0 || len(i.Scratch) <= index { 39 return ErrMFAScratchIndexOutOfBounds 40 } 41 newScratch, err := mfaCreateScratch() 42 if err != nil { 43 return err 44 } 45 i.Scratch[index] = newScratch 46 47 _, err = mfaItemStmts.update.Exec(i.Scratch[0], i.Scratch[1], i.Scratch[2], i.Scratch[3], i.Scratch[4], i.Scratch[5], i.Scratch[6], i.Scratch[7], i.UID) 48 return err 49 } 50 51 func (i *MFAItem) Delete() error { 52 _, err := mfaItemStmts.delete.Exec(i.UID) 53 return err 54 } 55 56 func mfaCreateScratch() (string, error) { 57 code, err := GenerateStd32SafeString(8) 58 return strings.Replace(code, "=", "", -1), err 59 } 60 61 type MFAStore interface { 62 Get(id int) (*MFAItem, error) 63 Create(secret string, uid int) (err error) 64 } 65 66 type SQLMFAStore struct { 67 get *sql.Stmt 68 create *sql.Stmt 69 } 70 71 func NewSQLMFAStore(acc *qgen.Accumulator) (*SQLMFAStore, error) { 72 return &SQLMFAStore{ 73 get: acc.Select("users_2fa_keys").Columns("secret,scratch1,scratch2,scratch3,scratch4,scratch5,scratch6,scratch7,scratch8").Where("uid=?").Prepare(), 74 create: acc.Insert("users_2fa_keys").Columns("uid,secret,scratch1,scratch2,scratch3,scratch4,scratch5,scratch6,scratch7,scratch8,createdAt").Fields("?,?,?,?,?,?,?,?,?,?,UTC_TIMESTAMP()").Prepare(), 75 }, acc.FirstError() 76 } 77 78 // TODO: Write a test for this 79 func (s *SQLMFAStore) Get(id int) (*MFAItem, error) { 80 i := MFAItem{UID: id, Scratch: make([]string, 8)} 81 err := s.get.QueryRow(id).Scan(&i.Secret, &i.Scratch[0], &i.Scratch[1], &i.Scratch[2], &i.Scratch[3], &i.Scratch[4], &i.Scratch[5], &i.Scratch[6], &i.Scratch[7]) 82 return &i, err 83 84 } 85 86 // TODO: Write a test for this 87 func (s *SQLMFAStore) Create(secret string, uid int) (err error) { 88 params := make([]interface{}, 10) 89 params[0] = uid 90 params[1] = secret 91 for i := 2; i < len(params); i++ { 92 code, err := mfaCreateScratch() 93 if err != nil { 94 return err 95 } 96 params[i] = code 97 } 98 99 _, err = s.create.Exec(params...) 100 return err 101 }