github.com/ngicks/gokugen@v0.0.5/impl/repository/sqlite3_repository.go (about) 1 package repository 2 3 import ( 4 "database/sql" 5 "encoding/hex" 6 "encoding/json" 7 "strings" 8 "sync" 9 "time" 10 11 _ "github.com/mattn/go-sqlite3" 12 taskstorage "github.com/ngicks/gokugen/task_storage" 13 ) 14 15 var _ taskstorage.RepositoryUpdater = &Sqlite3Repo{} 16 17 const tableSqlite string = ` 18 CREATE TABLE IF NOT EXISTS taskInfo( 19 id TEXT NOT NULL PRIMARY KEY, 20 work_id TEXT NOT NULL, 21 param BLOB, 22 scheduled_time INTEGER NOT NULL, 23 state TEXT NOT NULL, 24 inserted_at INTEGER NOT NULL DEFAULT (strftime('%s','now') || substr(strftime('%f','now'),4)), 25 last_modified_at INTEGER NOT NULL DEFAULT (strftime('%s','now') || substr(strftime('%f','now'),4)) 26 ) STRICT;` 27 28 const triggerUpdateSqlite string = ` 29 CREATE TRIGGER IF NOT EXISTS trigger_update_taskInfo_last_modified_at AFTER UPDATE ON taskInfo 30 BEGIN 31 UPDATE taskInfo SET last_modified_at = strftime('%s','now') || substr(strftime('%f','now'),4) WHERE id == NEW.id; 32 END;` 33 34 const insertSqlite string = ` 35 INSERT INTO taskInfo( 36 id, 37 work_id, 38 param, 39 scheduled_time, 40 state 41 ) VALUES(?,?,?,?,?) 42 ` 43 44 type sqlite3TaskInfo struct { 45 Id string 46 Work_id string 47 Param []byte 48 Scheduled_time int64 49 State string 50 Inserted_at int64 51 Last_modified_at int64 52 } 53 54 func fromHighLevel( 55 taskId string, 56 taskInfo taskstorage.TaskInfo, 57 ) (lowTaskInfo sqlite3TaskInfo, err error) { 58 if taskInfo.State < 0 { 59 err = taskstorage.ErrInvalidEnt 60 return 61 } 62 // Implementation detail: param must be json marshalable 63 var paramMarshaled []byte 64 if taskInfo.Param != nil { 65 paramMarshaled, err = json.Marshal(taskInfo.Param) 66 if err != nil { 67 return 68 } 69 } 70 71 return sqlite3TaskInfo{ 72 Id: taskId, 73 Work_id: taskInfo.WorkId, 74 Param: paramMarshaled, 75 Scheduled_time: taskInfo.ScheduledTime.UnixMilli(), 76 State: taskInfo.State.String(), 77 }, nil 78 } 79 80 func (ti sqlite3TaskInfo) toHighLevel() (taskInfo taskstorage.TaskInfo, err error) { 81 var param any 82 if ti.Param != nil && len(ti.Param) != 0 { 83 err = json.Unmarshal(ti.Param, ¶m) 84 if err != nil { 85 return 86 } 87 } 88 state := taskstorage.NewStateFromString(ti.State) 89 if state < 0 { 90 err = taskstorage.ErrInvalidEnt 91 return 92 } 93 taskInfo = taskstorage.TaskInfo{ 94 Id: ti.Id, 95 WorkId: ti.Work_id, 96 Param: param, 97 ScheduledTime: time.UnixMilli(ti.Scheduled_time), 98 State: state, 99 LastModified: time.UnixMilli(ti.Last_modified_at), 100 } 101 return 102 } 103 104 type Sqlite3Repo struct { 105 randomStr *RandStringGenerator 106 mu sync.RWMutex 107 db *sql.DB 108 } 109 110 func NewSql3Repo(dbName string) (repo *Sqlite3Repo, err error) { 111 db, err := sql.Open("sqlite3", dbName) 112 if err != nil { 113 return 114 } 115 116 for _, stmt := range []string{tableSqlite, triggerUpdateSqlite} { 117 if err := execStmt(db, stmt); err != nil { 118 return nil, err 119 } 120 } 121 122 return &Sqlite3Repo{ 123 randomStr: NewRandStringGenerator(time.Now().UnixMicro(), 16, hex.NewEncoder), 124 db: db, 125 }, nil 126 } 127 128 func createTable(db *sql.DB) (err error) { 129 stmt, err := db.Prepare(tableSqlite) 130 if err != nil { 131 return 132 } 133 _, err = stmt.Exec() 134 if err != nil { 135 return 136 } 137 return 138 } 139 140 func execStmt(db *sql.DB, stmtQuery string) (err error) { 141 stmt, err := db.Prepare(stmtQuery) 142 if err != nil { 143 return 144 } 145 _, err = stmt.Exec() 146 if err != nil { 147 return 148 } 149 return 150 } 151 152 func (r *Sqlite3Repo) Close() error { 153 r.mu.Lock() 154 defer r.mu.Unlock() 155 156 return r.db.Close() 157 } 158 159 func (r *Sqlite3Repo) Insert(taskInfo taskstorage.TaskInfo) (taskId string, err error) { 160 r.mu.Lock() 161 defer r.mu.Unlock() 162 taskId, err = r.randomStr.Generate() 163 if err != nil { 164 return 165 } 166 167 lowlevel, err := fromHighLevel(taskId, taskInfo) 168 if err != nil { 169 return 170 } 171 172 stmt, err := r.db.Prepare(insertSqlite) 173 if err != nil { 174 return 175 } 176 defer stmt.Close() 177 178 _, err = stmt.Exec( 179 lowlevel.Id, 180 lowlevel.Work_id, 181 lowlevel.Param, 182 lowlevel.Scheduled_time, 183 lowlevel.State, 184 ) 185 186 if err != nil { 187 return 188 } 189 190 return 191 } 192 193 func (r *Sqlite3Repo) fetchAllForQuery(query string, exec ...any) (taskInfos []taskstorage.TaskInfo, err error) { 194 r.mu.Lock() 195 defer r.mu.Unlock() 196 197 stmt, err := r.db.Prepare(query) 198 if err != nil { 199 return 200 } 201 rows, err := stmt.Query(exec...) 202 if err != nil { 203 return 204 } 205 206 defer rows.Close() 207 208 for rows.Next() { 209 lowlevel := sqlite3TaskInfo{} 210 err = rows.Scan( 211 &lowlevel.Id, 212 &lowlevel.Work_id, 213 &lowlevel.Param, 214 &lowlevel.Scheduled_time, 215 &lowlevel.State, 216 &lowlevel.Inserted_at, // this is unused. TODO: change `select *`` to like `select all but inserted_at` query. 217 &lowlevel.Last_modified_at, 218 ) 219 if err != nil { 220 return 221 } 222 var taskInfo taskstorage.TaskInfo 223 taskInfo, err = lowlevel.toHighLevel() 224 if err != nil { 225 return 226 } 227 taskInfos = append(taskInfos, taskInfo) 228 } 229 return 230 } 231 232 func (r *Sqlite3Repo) GetAll() (taskInfos []taskstorage.TaskInfo, err error) { 233 return r.fetchAllForQuery("SELECT * FROM taskInfo ORDER BY last_modified_at ASC") 234 } 235 236 func (r *Sqlite3Repo) GetUpdatedSince(since time.Time) ([]taskstorage.TaskInfo, error) { 237 return r.fetchAllForQuery("SELECT * FROM taskInfo WHERE last_modified_at >= ? ORDER BY last_modified_at ASC", since.UnixMilli()) 238 } 239 240 func (r *Sqlite3Repo) GetById(taskId string) (taskInfo taskstorage.TaskInfo, err error) { 241 infos, err := r.fetchAllForQuery("SELECT * FROM taskInfo WHERE id = ?", taskId) 242 if err != nil { 243 return 244 } 245 if len(infos) == 0 { 246 err = taskstorage.ErrNoEnt 247 return 248 } 249 return infos[0], nil 250 } 251 252 func (r *Sqlite3Repo) markState(id string, new taskstorage.TaskState) (ok bool, err error) { 253 r.mu.Lock() 254 defer r.mu.Unlock() 255 stmt, err := r.db.Prepare("UPDATE taskInfo SET state = ? WHERE id = ? AND (state = ? OR state = ?)") 256 if err != nil { 257 return 258 } 259 res, err := stmt.Exec(new.String(), id, taskstorage.Initialized.String(), taskstorage.Working.String()) 260 if err != nil { 261 return 262 } 263 affected, err := res.RowsAffected() 264 if err != nil { 265 return 266 } 267 if affected <= 0 { 268 return false, nil 269 } 270 return true, nil 271 } 272 273 func (r *Sqlite3Repo) MarkAsDone(id string) (ok bool, err error) { 274 return r.markState(id, taskstorage.Done) 275 } 276 func (r *Sqlite3Repo) MarkAsCancelled(id string) (ok bool, err error) { 277 return r.markState(id, taskstorage.Cancelled) 278 } 279 func (r *Sqlite3Repo) MarkAsFailed(id string) (ok bool, err error) { 280 return r.markState(id, taskstorage.Failed) 281 } 282 283 func (r *Sqlite3Repo) UpdateState(id string, old, new taskstorage.TaskState) (swapped bool, err error) { 284 r.mu.Lock() 285 defer r.mu.Unlock() 286 stmt, err := r.db.Prepare("UPDATE taskInfo SET state = ? WHERE id = ? AND state = ?") 287 if err != nil { 288 return 289 } 290 res, err := stmt.Exec(new.String(), id, old.String()) 291 if err != nil { 292 return 293 } 294 affected, err := res.RowsAffected() 295 if err != nil { 296 return 297 } 298 if affected <= 0 { 299 return false, nil 300 } 301 return true, nil 302 } 303 304 func (r *Sqlite3Repo) Update(id string, diff taskstorage.UpdateDiff) (err error) { 305 r.mu.Lock() 306 defer r.mu.Unlock() 307 308 args := make([]any, 0) 309 310 setter := make([]string, 0) 311 lowlevel, err := fromHighLevel(id, diff.Diff) 312 if err != nil { 313 return 314 } 315 if diff.UpdateKey.WorkId { 316 setter = append(setter, "work_id = ?") 317 args = append(args, lowlevel.Work_id) 318 } 319 if diff.UpdateKey.Param { 320 setter = append(setter, "param = ?") 321 args = append(args, lowlevel.Param) 322 } 323 if diff.UpdateKey.ScheduledTime { 324 setter = append(setter, "scheduled_time = ?") 325 args = append(args, lowlevel.Scheduled_time) 326 } 327 if diff.UpdateKey.State { 328 setter = append(setter, "state = ?") 329 args = append(args, lowlevel.State) 330 } 331 332 if len(args) == 0 { 333 // no-op 334 return 335 } 336 337 query := "UPDATE taskInfo SET " + strings.Join(setter, ", ") + " WHERE id = ? AND state = ?" 338 args = append(args, []any{id, taskstorage.Initialized.String()}...) 339 340 stmt, err := r.db.Prepare(query) 341 if err != nil { 342 return 343 } 344 res, err := stmt.Exec(args...) 345 if err != nil { 346 return 347 } 348 affected, err := res.RowsAffected() 349 if err != nil { 350 return 351 } 352 if affected <= 0 { 353 return taskstorage.ErrNoEnt 354 } 355 return 356 357 }