github.com/readium/readium-lcp-server@v0.0.0-20240101192032-6e95190e99f1/frontend/webuser/webuser.go (about)

     1  // Copyright 2020 Readium Foundation. All rights reserved.
     2  // Use of this source code is governed by a BSD-style license
     3  // that can be found in the LICENSE file exposed on Github (readium) in the project repository.
     4  
     5  package webuser
     6  
     7  import (
     8  	"database/sql"
     9  	"errors"
    10  	"log"
    11  
    12  	"github.com/readium/readium-lcp-server/config"
    13  	uuid "github.com/satori/go.uuid"
    14  )
    15  
    // Sentinel errors exposed by this package.
var (
	// ErrNotFound reports that a requested user does not exist. Get returns it
	// when no matching row is found, and the iterator returned by ListUsers
	// returns it once the result set has been fully consumed.
	//
	// NOTE(review): the message is capitalized, contrary to Go convention; it
	// is kept as-is because it may be surfaced verbatim to API clients.
	ErrNotFound = errors.New("User not found")
)
    18  
    // Types making up the frontend user store.
type (
	// WebUser is the storage interface through which the frontend reads and
	// writes users.
	WebUser interface {
		// Get loads the user with the given database id.
		Get(id int64) (User, error)
		// GetByEmail loads a user by email address.
		GetByEmail(email string) (User, error)
		// Add inserts a new user; the SQL implementation assigns it a fresh UUID.
		Add(c User) error
		// Update rewrites the stored name, email, password and hint of the user
		// whose id is c.ID.
		Update(c User) error
		// DeleteUser removes the user with the given id; the SQL implementation
		// also deletes that user's purchases.
		DeleteUser(UserID int64) error
		// ListUsers returns an iterator over one page of users: page is the page
		// size and pageNum the zero-based page index. The iterator returns
		// ErrNotFound once the page is exhausted.
		ListUsers(page, pageNum int) func() (User, error)
	}

	// User is one row of the "user" table, also exchanged as JSON.
	// Field order matters: keep it stable for positional composite literals.
	User struct {
		ID    int64  `json:"id"`
		UUID  string `json:"uuid"`
		Name  string `json:"name,omitempty"`
		Email string `json:"email,omitempty"`
		// NOTE(review): Password is serialized whenever it is non-empty; make
		// sure it is cleared before a User is written back to a client.
		Password string `json:"password,omitempty"`
		Hint     string `json:"hint"`
	}

	// dbUser is the database/sql implementation of WebUser. Open fills it with
	// a positional literal, so the field order below must not change.
	dbUser struct {
		db           *sql.DB   // handle used for ad-hoc Exec calls
		dbGetUser    *sql.Stmt // lookup by id
		dbGetByEmail *sql.Stmt // lookup by email, limited to one row
		dbList       *sql.Stmt // paginated listing ordered by email
	}
)
    45  
    46  // Get returns a user
    47  func (user dbUser) Get(id int64) (User, error) {
    48  
    49  	row := user.dbGetUser.QueryRow(id)
    50  	var c User
    51  	err := row.Scan(&c.ID, &c.UUID, &c.Name, &c.Email, &c.Password, &c.Hint)
    52  	if err != nil {
    53  		return User{}, ErrNotFound
    54  	}
    55  	return c, err
    56  }
    57  
    58  // GetByEmail returns a user
    59  func (user dbUser) GetByEmail(email string) (User, error) {
    60  
    61  	row := user.dbGetByEmail.QueryRow(email)
    62  	var c User
    63  	err := row.Scan(&c.ID, &c.UUID, &c.Name, &c.Email, &c.Password, &c.Hint)
    64  	return c, err
    65  }
    66  
    67  // Add inserts a user
    68  func (user dbUser) Add(newUser User) error {
    69  
    70  	// Create uuid
    71  	uid, err_u := uuid.NewV4()
    72  	if err_u != nil {
    73  		return err_u
    74  	}
    75  	newUser.UUID = uid.String()
    76  
    77  	_, err := user.db.Exec("INSERT INTO \"user\" (uuid, name, email, password, hint) VALUES (?, ?, ?, ?, ?)",
    78  		newUser.UUID, newUser.Name, newUser.Email, newUser.Password, newUser.Hint)
    79  	return err
    80  }
    81  
    82  // Update updates a user
    83  func (user dbUser) Update(changedUser User) error {
    84  
    85  	_, err := user.db.Exec("UPDATE \"user\" SET name=? , email=?, password=?, hint=? WHERE id=?",
    86  		changedUser.Name, changedUser.Email, changedUser.Password, changedUser.Hint, changedUser.ID)
    87  	return err
    88  }
    89  
    90  // DeleteUser deletes a user
    91  func (user dbUser) DeleteUser(userID int64) error {
    92  
    93  	// delete user purchases
    94  	_, err := user.db.Exec("DELETE FROM purchase WHERE user_id=?", userID)
    95  	if err != nil {
    96  		return err
    97  	}
    98  
    99  	// delete user
   100  	_, err = user.db.Exec("DELETE FROM \"user\" WHERE id=?", userID)
   101  	return err
   102  }
   103  
   104  // ListUsers lists users
   105  func (user dbUser) ListUsers(page int, pageNum int) func() (User, error) {
   106  
   107  	var rows *sql.Rows
   108  	var err error
   109  	driver, _ := config.GetDatabase(config.Config.FrontendServer.Database)
   110  	if driver == "mssql" {
   111  		rows, err = user.dbList.Query(pageNum*page, page)
   112  	} else {
   113  		rows, err = user.dbList.Query(page, pageNum*page)
   114  	}
   115  	if err != nil {
   116  		return func() (User, error) { return User{}, err }
   117  	}
   118  
   119  	return func() (User, error) {
   120  		var u User
   121  		var err error
   122  		if rows.Next() {
   123  			err = rows.Scan(&u.ID, &u.UUID, &u.Name, &u.Email, &u.Password, &u.Hint)
   124  		} else {
   125  			rows.Close()
   126  			err = ErrNotFound
   127  		}
   128  		return u, err
   129  	}
   130  }
   131  
   132  //Open  returns a WebUser interface (db interaction)
   133  func Open(db *sql.DB) (i WebUser, err error) {
   134  
   135  	driver, _ := config.GetDatabase(config.Config.FrontendServer.Database)
   136  	// if sqlite, create the content table in the frontend db if it does not exist
   137  	if driver == "sqlite3" {
   138  		_, err = db.Exec(tableDef)
   139  		if err != nil {
   140  			log.Println("Error creating user table")
   141  			return
   142  		}
   143  	}
   144  
   145  	var dbGetUser *sql.Stmt
   146  	dbGetUser, err = db.Prepare("SELECT id, uuid, name, email, password, hint FROM \"user\" WHERE id = ?")
   147  	if err != nil {
   148  		return
   149  	}
   150  
   151  	var dbGetByEmail *sql.Stmt
   152  	if driver == "mssql" {
   153  		dbGetByEmail, err = db.Prepare("SELECT TOP 1 id, uuid, name, email, password, hint FROM \"user\" WHERE email = ?")
   154  	} else {
   155  		dbGetByEmail, err = db.Prepare("SELECT id, uuid, name, email, password, hint FROM \"user\" WHERE email = ? LIMIT 1")
   156  	}
   157  	if err != nil {
   158  		return
   159  	}
   160  
   161  	var dbList *sql.Stmt
   162  	if driver == "mssql" {
   163  		dbList, err = db.Prepare("SELECT id, uuid, name, email, password, hint	FROM \"user\" ORDER BY email desc OFFSET ? ROWS FETCH NEXT ? ROWS ONLY")
   164  	} else {
   165  		dbList, err = db.Prepare("SELECT id, uuid, name, email, password, hint	FROM \"user\" ORDER BY email desc LIMIT ? OFFSET ?")
   166  	}
   167  	if err != nil {
   168  		return
   169  	}
   170  
   171  	i = dbUser{db, dbGetUser, dbGetByEmail, dbList}
   172  	return
   173  }
   174  
   // tableDef is the DDL that Open executes, only for the sqlite3 driver, to
// create the "user" table when it does not exist yet.
const tableDef = `CREATE TABLE IF NOT EXISTS "user" (id integer NOT NULL PRIMARY KEY,uuid varchar(255) NOT NULL,name varchar(64) NOT NULL,email varchar(64) NOT NULL,password varchar(64) NOT NULL,hint varchar(64) NOT NULL)`