github.com/ngicks/gokugen@v0.0.5/impl/repository/sqlite3_repository.go (about)

     1  package repository
     2  
     3  import (
     4  	"database/sql"
     5  	"encoding/hex"
     6  	"encoding/json"
     7  	"strings"
     8  	"sync"
     9  	"time"
    10  
    11  	_ "github.com/mattn/go-sqlite3"
    12  	taskstorage "github.com/ngicks/gokugen/task_storage"
    13  )
    14  
    15  var _ taskstorage.RepositoryUpdater = &Sqlite3Repo{}
    16  
    17  const tableSqlite string = `
    18  CREATE TABLE IF NOT EXISTS taskInfo(
    19      id TEXT NOT NULL PRIMARY KEY,
    20      work_id TEXT NOT NULL,
    21      param BLOB,
    22      scheduled_time INTEGER NOT NULL,
    23      state TEXT NOT NULL,
    24      inserted_at INTEGER NOT NULL DEFAULT (strftime('%s','now') || substr(strftime('%f','now'),4)),
    25      last_modified_at INTEGER NOT NULL DEFAULT (strftime('%s','now') || substr(strftime('%f','now'),4))
    26  ) STRICT;`
    27  
    28  const triggerUpdateSqlite string = `
    29  CREATE TRIGGER IF NOT EXISTS trigger_update_taskInfo_last_modified_at AFTER UPDATE ON taskInfo
    30  BEGIN
    31      UPDATE taskInfo SET last_modified_at = strftime('%s','now') || substr(strftime('%f','now'),4) WHERE id == NEW.id;
    32  END;`
    33  
    34  const insertSqlite string = `
    35  INSERT INTO taskInfo(
    36  	id,
    37  	work_id,
    38  	param,
    39  	scheduled_time,
    40  	state
    41  ) VALUES(?,?,?,?,?)
    42  `
    43  
    44  type sqlite3TaskInfo struct {
    45  	Id               string
    46  	Work_id          string
    47  	Param            []byte
    48  	Scheduled_time   int64
    49  	State            string
    50  	Inserted_at      int64
    51  	Last_modified_at int64
    52  }
    53  
    54  func fromHighLevel(
    55  	taskId string,
    56  	taskInfo taskstorage.TaskInfo,
    57  ) (lowTaskInfo sqlite3TaskInfo, err error) {
    58  	if taskInfo.State < 0 {
    59  		err = taskstorage.ErrInvalidEnt
    60  		return
    61  	}
    62  	// Implementation detail: param must be json marshalable
    63  	var paramMarshaled []byte
    64  	if taskInfo.Param != nil {
    65  		paramMarshaled, err = json.Marshal(taskInfo.Param)
    66  		if err != nil {
    67  			return
    68  		}
    69  	}
    70  
    71  	return sqlite3TaskInfo{
    72  		Id:             taskId,
    73  		Work_id:        taskInfo.WorkId,
    74  		Param:          paramMarshaled,
    75  		Scheduled_time: taskInfo.ScheduledTime.UnixMilli(),
    76  		State:          taskInfo.State.String(),
    77  	}, nil
    78  }
    79  
    80  func (ti sqlite3TaskInfo) toHighLevel() (taskInfo taskstorage.TaskInfo, err error) {
    81  	var param any
    82  	if ti.Param != nil && len(ti.Param) != 0 {
    83  		err = json.Unmarshal(ti.Param, &param)
    84  		if err != nil {
    85  			return
    86  		}
    87  	}
    88  	state := taskstorage.NewStateFromString(ti.State)
    89  	if state < 0 {
    90  		err = taskstorage.ErrInvalidEnt
    91  		return
    92  	}
    93  	taskInfo = taskstorage.TaskInfo{
    94  		Id:            ti.Id,
    95  		WorkId:        ti.Work_id,
    96  		Param:         param,
    97  		ScheduledTime: time.UnixMilli(ti.Scheduled_time),
    98  		State:         state,
    99  		LastModified:  time.UnixMilli(ti.Last_modified_at),
   100  	}
   101  	return
   102  }
   103  
   104  type Sqlite3Repo struct {
   105  	randomStr *RandStringGenerator
   106  	mu        sync.RWMutex
   107  	db        *sql.DB
   108  }
   109  
   110  func NewSql3Repo(dbName string) (repo *Sqlite3Repo, err error) {
   111  	db, err := sql.Open("sqlite3", dbName)
   112  	if err != nil {
   113  		return
   114  	}
   115  
   116  	for _, stmt := range []string{tableSqlite, triggerUpdateSqlite} {
   117  		if err := execStmt(db, stmt); err != nil {
   118  			return nil, err
   119  		}
   120  	}
   121  
   122  	return &Sqlite3Repo{
   123  		randomStr: NewRandStringGenerator(time.Now().UnixMicro(), 16, hex.NewEncoder),
   124  		db:        db,
   125  	}, nil
   126  }
   127  
   128  func createTable(db *sql.DB) (err error) {
   129  	stmt, err := db.Prepare(tableSqlite)
   130  	if err != nil {
   131  		return
   132  	}
   133  	_, err = stmt.Exec()
   134  	if err != nil {
   135  		return
   136  	}
   137  	return
   138  }
   139  
   140  func execStmt(db *sql.DB, stmtQuery string) (err error) {
   141  	stmt, err := db.Prepare(stmtQuery)
   142  	if err != nil {
   143  		return
   144  	}
   145  	_, err = stmt.Exec()
   146  	if err != nil {
   147  		return
   148  	}
   149  	return
   150  }
   151  
   152  func (r *Sqlite3Repo) Close() error {
   153  	r.mu.Lock()
   154  	defer r.mu.Unlock()
   155  
   156  	return r.db.Close()
   157  }
   158  
   159  func (r *Sqlite3Repo) Insert(taskInfo taskstorage.TaskInfo) (taskId string, err error) {
   160  	r.mu.Lock()
   161  	defer r.mu.Unlock()
   162  	taskId, err = r.randomStr.Generate()
   163  	if err != nil {
   164  		return
   165  	}
   166  
   167  	lowlevel, err := fromHighLevel(taskId, taskInfo)
   168  	if err != nil {
   169  		return
   170  	}
   171  
   172  	stmt, err := r.db.Prepare(insertSqlite)
   173  	if err != nil {
   174  		return
   175  	}
   176  	defer stmt.Close()
   177  
   178  	_, err = stmt.Exec(
   179  		lowlevel.Id,
   180  		lowlevel.Work_id,
   181  		lowlevel.Param,
   182  		lowlevel.Scheduled_time,
   183  		lowlevel.State,
   184  	)
   185  
   186  	if err != nil {
   187  		return
   188  	}
   189  
   190  	return
   191  }
   192  
   193  func (r *Sqlite3Repo) fetchAllForQuery(query string, exec ...any) (taskInfos []taskstorage.TaskInfo, err error) {
   194  	r.mu.Lock()
   195  	defer r.mu.Unlock()
   196  
   197  	stmt, err := r.db.Prepare(query)
   198  	if err != nil {
   199  		return
   200  	}
   201  	rows, err := stmt.Query(exec...)
   202  	if err != nil {
   203  		return
   204  	}
   205  
   206  	defer rows.Close()
   207  
   208  	for rows.Next() {
   209  		lowlevel := sqlite3TaskInfo{}
   210  		err = rows.Scan(
   211  			&lowlevel.Id,
   212  			&lowlevel.Work_id,
   213  			&lowlevel.Param,
   214  			&lowlevel.Scheduled_time,
   215  			&lowlevel.State,
   216  			&lowlevel.Inserted_at, // this is unused. TODO: change `select *`` to like `select all but inserted_at` query.
   217  			&lowlevel.Last_modified_at,
   218  		)
   219  		if err != nil {
   220  			return
   221  		}
   222  		var taskInfo taskstorage.TaskInfo
   223  		taskInfo, err = lowlevel.toHighLevel()
   224  		if err != nil {
   225  			return
   226  		}
   227  		taskInfos = append(taskInfos, taskInfo)
   228  	}
   229  	return
   230  }
   231  
   232  func (r *Sqlite3Repo) GetAll() (taskInfos []taskstorage.TaskInfo, err error) {
   233  	return r.fetchAllForQuery("SELECT * FROM taskInfo ORDER BY last_modified_at ASC")
   234  }
   235  
   236  func (r *Sqlite3Repo) GetUpdatedSince(since time.Time) ([]taskstorage.TaskInfo, error) {
   237  	return r.fetchAllForQuery("SELECT * FROM taskInfo WHERE last_modified_at >= ? ORDER BY last_modified_at ASC", since.UnixMilli())
   238  }
   239  
   240  func (r *Sqlite3Repo) GetById(taskId string) (taskInfo taskstorage.TaskInfo, err error) {
   241  	infos, err := r.fetchAllForQuery("SELECT * FROM taskInfo WHERE id = ?", taskId)
   242  	if err != nil {
   243  		return
   244  	}
   245  	if len(infos) == 0 {
   246  		err = taskstorage.ErrNoEnt
   247  		return
   248  	}
   249  	return infos[0], nil
   250  }
   251  
   252  func (r *Sqlite3Repo) markState(id string, new taskstorage.TaskState) (ok bool, err error) {
   253  	r.mu.Lock()
   254  	defer r.mu.Unlock()
   255  	stmt, err := r.db.Prepare("UPDATE taskInfo SET state = ? WHERE id = ? AND (state = ? OR state = ?)")
   256  	if err != nil {
   257  		return
   258  	}
   259  	res, err := stmt.Exec(new.String(), id, taskstorage.Initialized.String(), taskstorage.Working.String())
   260  	if err != nil {
   261  		return
   262  	}
   263  	affected, err := res.RowsAffected()
   264  	if err != nil {
   265  		return
   266  	}
   267  	if affected <= 0 {
   268  		return false, nil
   269  	}
   270  	return true, nil
   271  }
   272  
   273  func (r *Sqlite3Repo) MarkAsDone(id string) (ok bool, err error) {
   274  	return r.markState(id, taskstorage.Done)
   275  }
   276  func (r *Sqlite3Repo) MarkAsCancelled(id string) (ok bool, err error) {
   277  	return r.markState(id, taskstorage.Cancelled)
   278  }
   279  func (r *Sqlite3Repo) MarkAsFailed(id string) (ok bool, err error) {
   280  	return r.markState(id, taskstorage.Failed)
   281  }
   282  
   283  func (r *Sqlite3Repo) UpdateState(id string, old, new taskstorage.TaskState) (swapped bool, err error) {
   284  	r.mu.Lock()
   285  	defer r.mu.Unlock()
   286  	stmt, err := r.db.Prepare("UPDATE taskInfo SET state = ? WHERE id = ? AND state = ?")
   287  	if err != nil {
   288  		return
   289  	}
   290  	res, err := stmt.Exec(new.String(), id, old.String())
   291  	if err != nil {
   292  		return
   293  	}
   294  	affected, err := res.RowsAffected()
   295  	if err != nil {
   296  		return
   297  	}
   298  	if affected <= 0 {
   299  		return false, nil
   300  	}
   301  	return true, nil
   302  }
   303  
   304  func (r *Sqlite3Repo) Update(id string, diff taskstorage.UpdateDiff) (err error) {
   305  	r.mu.Lock()
   306  	defer r.mu.Unlock()
   307  
   308  	args := make([]any, 0)
   309  
   310  	setter := make([]string, 0)
   311  	lowlevel, err := fromHighLevel(id, diff.Diff)
   312  	if err != nil {
   313  		return
   314  	}
   315  	if diff.UpdateKey.WorkId {
   316  		setter = append(setter, "work_id = ?")
   317  		args = append(args, lowlevel.Work_id)
   318  	}
   319  	if diff.UpdateKey.Param {
   320  		setter = append(setter, "param = ?")
   321  		args = append(args, lowlevel.Param)
   322  	}
   323  	if diff.UpdateKey.ScheduledTime {
   324  		setter = append(setter, "scheduled_time = ?")
   325  		args = append(args, lowlevel.Scheduled_time)
   326  	}
   327  	if diff.UpdateKey.State {
   328  		setter = append(setter, "state = ?")
   329  		args = append(args, lowlevel.State)
   330  	}
   331  
   332  	if len(args) == 0 {
   333  		// no-op
   334  		return
   335  	}
   336  
   337  	query := "UPDATE taskInfo SET " + strings.Join(setter, ", ") + " WHERE id = ? AND state = ?"
   338  	args = append(args, []any{id, taskstorage.Initialized.String()}...)
   339  
   340  	stmt, err := r.db.Prepare(query)
   341  	if err != nil {
   342  		return
   343  	}
   344  	res, err := stmt.Exec(args...)
   345  	if err != nil {
   346  		return
   347  	}
   348  	affected, err := res.RowsAffected()
   349  	if err != nil {
   350  		return
   351  	}
   352  	if affected <= 0 {
   353  		return taskstorage.ErrNoEnt
   354  	}
   355  	return
   356  
   357  }