github.com/RevenueMonster/sqlike@v1.0.6/sqlike/database.go (about) 1 package sqlike 2 3 import ( 4 "context" 5 "database/sql" 6 "errors" 7 "fmt" 8 "io/ioutil" 9 "os" 10 "path/filepath" 11 "strings" 12 "time" 13 14 "github.com/RevenueMonster/sqlike/sql/codec" 15 "github.com/RevenueMonster/sqlike/sql/dialect" 16 "github.com/RevenueMonster/sqlike/sql/driver" 17 sqlstmt "github.com/RevenueMonster/sqlike/sql/stmt" 18 "github.com/RevenueMonster/sqlike/sqlike/indexes" 19 "github.com/RevenueMonster/sqlike/sqlike/logs" 20 "github.com/RevenueMonster/sqlike/sqlike/options" 21 "gopkg.in/yaml.v3" 22 ) 23 24 type txCallback func(ctx SessionContext) error 25 26 // Database : 27 type Database struct { 28 driverName string 29 name string 30 pk string 31 client *Client 32 driver driver.Driver 33 dialect dialect.Dialect 34 codec codec.Codecer 35 logger logs.Logger 36 } 37 38 // Name : to get current database name 39 func (db *Database) Name() string { 40 return db.name 41 } 42 43 // Table : use the table under this database 44 func (db *Database) Table(name string) *Table { 45 return &Table{ 46 dbName: db.name, 47 name: name, 48 pk: db.pk, 49 client: db.client, 50 driver: db.driver, 51 dialect: db.dialect, 52 codec: db.codec, 53 logger: db.logger, 54 } 55 } 56 57 func (db *Database) QueryRow(ctx context.Context, query string, args ...interface{}) SingleResult { 58 rslt := new(Result) 59 rslt.cache = db.client.cache 60 rslt.codec = db.codec 61 rows, err := db.driver.QueryContext(ctx, query, args...) 62 if err != nil { 63 rslt.err = err 64 return rslt 65 } 66 rslt.rows = rows 67 rslt.columnTypes, rslt.err = rows.ColumnTypes() 68 if rslt.err != nil { 69 defer rslt.rows.Close() 70 } 71 for _, col := range rslt.columnTypes { 72 rslt.columns = append(rslt.columns, col.Name()) 73 } 74 rslt.close = true 75 if !rslt.Next() { 76 rslt.err = sql.ErrNoRows 77 } 78 return rslt 79 } 80 81 // QueryStmt : 82 func (db *Database) QueryStmt(ctx context.Context, query interface{}) (*Result, error) { 83 if query == nil { 84 return nil, errors.New("sqlike: empty query statement") 85 } 86 87 stmt := sqlstmt.AcquireStmt(db.dialect) 88 defer sqlstmt.ReleaseStmt(stmt) 89 if err := db.dialect.SelectStmt(stmt, query); err != nil { 90 return nil, err 91 } 92 93 rows, err := driver.Query( 94 ctx, 95 db.driver, 96 stmt, 97 getLogger(db.logger, true), 98 ) 99 if err != nil { 100 return nil, err 101 } 102 103 rslt := new(Result) 104 rslt.cache = db.client.cache 105 rslt.codec = db.codec 106 rslt.rows = rows 107 rslt.columnTypes, rslt.err = rows.ColumnTypes() 108 if rslt.err != nil { 109 defer rslt.rows.Close() 110 } 111 for _, col := range rslt.columnTypes { 112 rslt.columns = append(rslt.columns, col.Name()) 113 } 114 return rslt, rslt.err 115 } 116 117 // BeginTransaction : 118 func (db *Database) BeginTransaction(ctx context.Context, opts ...*sql.TxOptions) (*Transaction, error) { 119 opt := &sql.TxOptions{} 120 if len(opts) > 0 { 121 opt = opts[0] 122 } 123 return db.beginTrans(ctx, opt) 124 } 125 126 func (db *Database) beginTrans(ctx context.Context, opt *sql.TxOptions) (*Transaction, error) { 127 tx, err := db.client.BeginTx(ctx, opt) 128 if err != nil { 129 return nil, err 130 } 131 return &Transaction{ 132 Context: ctx, 133 dbName: db.name, 134 pk: db.pk, 135 client: db.client, 136 driver: tx, 137 dialect: db.dialect, 138 logger: db.logger, 139 codec: db.codec, 140 }, nil 141 } 142 143 // RunInTransaction : 144 func (db *Database) RunInTransaction(ctx context.Context, cb txCallback, opts ...*options.TransactionOptions) error { 145 opt := new(options.TransactionOptions) 146 if len(opts) > 0 && opts[0] != nil { 147 opt = opts[0] 148 } 149 duration := 60 * time.Second 150 if opt.Duration.Seconds() > 0 { 151 duration = opt.Duration 152 } 153 c, cancel := context.WithTimeout(ctx, duration) 154 defer cancel() 155 tx, err := db.beginTrans(c, &sql.TxOptions{ 156 Isolation: opt.IsolationLevel, 157 ReadOnly: opt.ReadOnly, 158 }) 159 if err != nil { 160 return err 161 } 162 defer tx.RollbackTransaction() 163 if err := cb(tx); err != nil { 164 return err 165 } 166 return tx.CommitTransaction() 167 } 168 169 type indexDefinition struct { 170 Indexes []struct { 171 Table string `yaml:"table"` 172 Name string `yaml:"name"` 173 Type string `yaml:"type"` 174 Cast string `yaml:"cast"` 175 As string `yaml:"as"` 176 Comment string `yaml:"comment"` 177 Columns []struct { 178 Name string `yaml:"name"` 179 Direction string `yaml:"direction"` 180 } `yaml:"columns"` 181 } `yaml:"indexes"` 182 } 183 184 // BuildIndexes : 185 func (db *Database) BuildIndexes(ctx context.Context, paths ...string) error { 186 var ( 187 path string 188 err error 189 fi os.FileInfo 190 ) 191 if len(paths) > 0 { 192 path = paths[0] 193 fi, err = os.Stat(path) 194 if err != nil { 195 return err 196 } 197 } else { 198 pwd, _ := os.Getwd() 199 files := []string{pwd + "/index.yml", pwd + "/index.yaml"} 200 for _, f := range files { 201 fi, err = os.Stat(f) 202 if !os.IsNotExist(err) { 203 path = f 204 break 205 } 206 } 207 if err != nil { 208 return err 209 } 210 } 211 212 switch v := fi.Mode(); { 213 case v.IsDir(): 214 if err := filepath.Walk(path, func(fp string, info os.FileInfo, err error) error { 215 if info.IsDir() { 216 return nil 217 } 218 219 ext := filepath.Ext(info.Name()) 220 // only interested on yaml and yml files 221 if ext != ".yaml" && ext != ".yml" { 222 return nil 223 } 224 225 return db.readAndBuildIndexes(ctx, fp) 226 }); err != nil { 227 return err 228 } 229 230 case v.IsRegular(): 231 if err := db.readAndBuildIndexes(ctx, path); err != nil { 232 return err 233 } 234 } 235 236 return nil 237 } 238 239 func (db *Database) readAndBuildIndexes(ctx context.Context, path string) error { 240 var id indexDefinition 241 b, err := ioutil.ReadFile(path) 242 if err != nil { 243 return err 244 } 245 if err := yaml.Unmarshal(b, &id); err != nil { 246 return err 247 } 248 249 for _, idx := range id.Indexes { 250 length := len(idx.Columns) 251 columns := make([]indexes.Col, length) 252 for i, col := range idx.Columns { 253 dir := indexes.Ascending 254 col.Direction = strings.TrimSpace(strings.ToLower(col.Direction)) 255 if col.Direction == "desc" || col.Direction == "descending" { 256 dir = indexes.Descending 257 } 258 columns[i] = indexes.Col{ 259 Name: col.Name, 260 Direction: dir, 261 } 262 } 263 264 index := indexes.Index{ 265 Name: strings.TrimSpace(idx.Name), 266 Type: parseIndexType(idx.Type), 267 Cast: strings.TrimSpace(idx.Cast), 268 As: strings.TrimSpace(idx.As), 269 Columns: columns, 270 Comment: strings.TrimSpace(idx.Comment), 271 } 272 273 if exists, err := isIndexExists( 274 ctx, 275 db.name, 276 idx.Table, 277 index.GetName(), 278 db.driver, 279 db.dialect, 280 db.logger, 281 ); err != nil { 282 return err 283 } else if exists { 284 continue 285 } 286 287 iv := db.Table(idx.Table).Indexes() 288 if err := iv.CreateOne(ctx, index); err != nil { 289 return err 290 } 291 } 292 return nil 293 } 294 295 func parseIndexType(name string) (idxType indexes.Type) { 296 name = strings.TrimSpace(strings.ToLower(name)) 297 if name == "" { 298 idxType = indexes.BTree 299 return 300 } 301 302 switch name { 303 case "spatial": 304 idxType = indexes.Spatial 305 case "unique": 306 idxType = indexes.Unique 307 case "btree": 308 idxType = indexes.BTree 309 case "fulltext": 310 idxType = indexes.FullText 311 case "primary": 312 idxType = indexes.Primary 313 case "multi-valued": 314 idxType = indexes.MultiValued 315 default: 316 panic(fmt.Errorf("invalid index type %q", name)) 317 } 318 return 319 }