github.com/RevenueMonster/sqlike@v1.0.6/sqlike/table.go (about) 1 package sqlike 2 3 import ( 4 "context" 5 "errors" 6 "reflect" 7 "strings" 8 9 "github.com/RevenueMonster/sqlike/reflext" 10 "github.com/RevenueMonster/sqlike/sql" 11 12 "github.com/RevenueMonster/sqlike/sql/codec" 13 "github.com/RevenueMonster/sqlike/sql/dialect" 14 sqldriver "github.com/RevenueMonster/sqlike/sql/driver" 15 sqlstmt "github.com/RevenueMonster/sqlike/sql/stmt" 16 "github.com/RevenueMonster/sqlike/sqlike/logs" 17 ) 18 19 // ErrNoRecordAffected : 20 var ErrNoRecordAffected = errors.New("no record affected") 21 22 // ErrExpectedStruct : 23 var ErrExpectedStruct = errors.New("expected struct as a source") 24 25 // ErrEmptyFields : 26 var ErrEmptyFields = errors.New("empty fields") 27 28 // Table : 29 type Table struct { 30 // current database name 31 dbName string 32 33 // table name 34 name string 35 36 // default primary key 37 pk string 38 39 client *Client 40 41 // sql driver 42 driver sqldriver.Driver 43 44 // sql dialect 45 dialect dialect.Dialect 46 47 // encoder and decoder for the value 48 codec codec.Codecer 49 logger logs.Logger 50 } 51 52 // Rename : rename the current table name to new table name 53 func (tb *Table) Rename(ctx context.Context, name string) error { 54 stmt := sqlstmt.AcquireStmt(tb.dialect) 55 defer sqlstmt.ReleaseStmt(stmt) 56 tb.dialect.RenameTable(stmt, tb.dbName, tb.name, name) 57 _, err := sqldriver.Execute( 58 ctx, 59 tb.driver, 60 stmt, 61 tb.logger, 62 ) 63 return err 64 } 65 66 // Exists : this will return true when the table exists in the database 67 func (tb *Table) Exists(ctx context.Context) bool { 68 var count int 69 stmt := sqlstmt.AcquireStmt(tb.dialect) 70 defer sqlstmt.ReleaseStmt(stmt) 71 tb.dialect.HasTable(stmt, tb.dbName, tb.name) 72 if err := sqldriver.QueryRowContext( 73 ctx, 74 tb.driver, 75 stmt, 76 tb.logger, 77 ).Scan(&count); err != nil { 78 panic(err) 79 } 80 return count > 0 81 } 82 83 // Columns : 84 func (tb *Table) Columns() *ColumnView { 85 return &ColumnView{tb: tb} 86 } 87 88 // ListColumns : list all the column of the table. 89 func (tb *Table) ListColumns(ctx context.Context) ([]Column, error) { 90 stmt := sqlstmt.AcquireStmt(tb.dialect) 91 defer sqlstmt.ReleaseStmt(stmt) 92 tb.dialect.GetColumns(stmt, tb.dbName, tb.name) 93 rows, err := sqldriver.Query( 94 ctx, 95 tb.driver, 96 stmt, 97 tb.logger, 98 ) 99 if err != nil { 100 return nil, err 101 } 102 defer rows.Close() 103 104 columns := make([]Column, 0) 105 for i := 0; rows.Next(); i++ { 106 col := Column{} 107 108 if err := rows.Scan( 109 &col.Position, 110 &col.Name, 111 &col.Type, 112 &col.DefaultValue, 113 &col.IsNullable, 114 &col.DataType, 115 &col.Charset, 116 &col.Collation, 117 &col.Comment, 118 &col.Extra, 119 ); err != nil { 120 return nil, err 121 } 122 123 col.Type = strings.ToUpper(col.Type) 124 col.DataType = strings.ToUpper(col.DataType) 125 126 columns = append(columns, col) 127 } 128 return columns, nil 129 } 130 131 // ListIndexes : list all the index of the table. 132 func (tb *Table) ListIndexes(ctx context.Context) ([]Index, error) { 133 stmt := sqlstmt.AcquireStmt(tb.dialect) 134 defer sqlstmt.ReleaseStmt(stmt) 135 tb.dialect.GetIndexes(stmt, tb.dbName, tb.name) 136 rows, err := sqldriver.Query( 137 ctx, 138 tb.driver, 139 stmt, 140 tb.logger, 141 ) 142 if err != nil { 143 return nil, err 144 } 145 defer rows.Close() 146 147 idxs := make([]Index, 0) 148 for i := 0; rows.Next(); i++ { 149 idx := Index{} 150 if err := rows.Scan( 151 &idx.Name, 152 &idx.Type, 153 &idx.IsUnique, 154 ); err != nil { 155 return nil, err 156 } 157 idx.IsUnique = !idx.IsUnique 158 idxs = append(idxs, idx) 159 } 160 return idxs, nil 161 } 162 163 // MustMigrate : this will ensure the migrate is complete, otherwise it will panic 164 func (tb Table) MustMigrate(ctx context.Context, entity interface{}) { 165 err := tb.Migrate(ctx, entity) 166 if err != nil { 167 panic(err) 168 } 169 } 170 171 // Migrate : migrate will create a new table follows by the definition of struct tag, alter when the table already exists 172 func (tb *Table) Migrate(ctx context.Context, entity interface{}) error { 173 return tb.migrateOne(ctx, tb.client.cache, entity, false) 174 } 175 176 // UnsafeMigrate : unsafe migration will delete non-exist index and columns, beware when you use this 177 func (tb *Table) UnsafeMigrate(ctx context.Context, entity interface{}) error { 178 return tb.migrateOne(ctx, tb.client.cache, entity, true) 179 } 180 181 // MustUnsafeMigrate : this will panic if it get error on unsafe migrate 182 func (tb *Table) MustUnsafeMigrate(ctx context.Context, entity interface{}) { 183 err := tb.migrateOne(ctx, tb.client.cache, entity, true) 184 if err != nil { 185 panic(err) 186 } 187 } 188 189 // Truncate : delete all the table data. 190 func (tb *Table) Truncate(ctx context.Context) (err error) { 191 stmt := sqlstmt.AcquireStmt(tb.dialect) 192 defer sqlstmt.ReleaseStmt(stmt) 193 tb.dialect.TruncateTable(stmt, tb.dbName, tb.name) 194 _, err = sqldriver.Execute( 195 ctx, 196 tb.driver, 197 stmt, 198 tb.logger, 199 ) 200 return 201 } 202 203 // DropIfExists : will drop the table only if it exists. 204 func (tb Table) DropIfExists(ctx context.Context) (err error) { 205 stmt := sqlstmt.AcquireStmt(tb.dialect) 206 defer sqlstmt.ReleaseStmt(stmt) 207 tb.dialect.DropTable(stmt, tb.dbName, tb.name, true) 208 _, err = sqldriver.Execute( 209 ctx, 210 tb.driver, 211 stmt, 212 tb.logger, 213 ) 214 return 215 } 216 217 // Drop : drop the table, but it might throw error when the table is not exists 218 func (tb Table) Drop(ctx context.Context) (err error) { 219 stmt := sqlstmt.AcquireStmt(tb.dialect) 220 defer sqlstmt.ReleaseStmt(stmt) 221 tb.dialect.DropTable(stmt, tb.dbName, tb.name, false) 222 _, err = sqldriver.Execute( 223 ctx, 224 tb.driver, 225 stmt, 226 tb.logger, 227 ) 228 return 229 } 230 231 // Replace : 232 func (tb *Table) Replace(ctx context.Context, fields []string, query *sql.SelectStmt) error { 233 stmt := sqlstmt.AcquireStmt(tb.dialect) 234 defer sqlstmt.ReleaseStmt(stmt) 235 if err := tb.dialect.Replace( 236 stmt, 237 tb.dbName, 238 tb.name, 239 fields, 240 query, 241 ); err != nil { 242 return err 243 } 244 245 if _, err := sqldriver.Execute( 246 ctx, 247 tb.driver, 248 stmt, 249 tb.logger, 250 ); err != nil { 251 return err 252 } 253 return nil 254 } 255 256 // Indexes : 257 func (tb *Table) Indexes() *IndexView { 258 return &IndexView{tb: tb} 259 } 260 261 // HasIndexByName : 262 func (tb *Table) HasIndexByName(ctx context.Context, name string) (bool, error) { 263 return isIndexExists( 264 ctx, 265 tb.dbName, 266 tb.name, 267 name, 268 tb.driver, 269 tb.dialect, 270 tb.logger, 271 ) 272 } 273 274 func (tb *Table) migrateOne(ctx context.Context, cache reflext.StructMapper, entity interface{}, unsafe bool) error { 275 v := reflext.ValueOf(entity) 276 if !v.IsValid() { 277 return ErrInvalidInput 278 } 279 280 t := reflext.Deref(v.Type()) 281 if !reflext.IsKind(t, reflect.Struct) { 282 return ErrExpectedStruct 283 } 284 285 cdc := cache.CodecByType(t) 286 fields := skipColumns(cdc.Properties(), nil) 287 if len(fields) < 1 { 288 return ErrEmptyFields 289 } 290 291 if !tb.Exists(ctx) { 292 return tb.createTable(ctx, fields) 293 } 294 295 columns, err := tb.ListColumns(ctx) 296 if err != nil { 297 return err 298 } 299 idxs, err := tb.ListIndexes(ctx) 300 if err != nil { 301 return err 302 } 303 return tb.alterTable(ctx, fields, columns, idxs, unsafe) 304 } 305 306 func (tb *Table) createTable(ctx context.Context, fields []reflext.StructFielder) error { 307 stmt := sqlstmt.AcquireStmt(tb.dialect) 308 defer sqlstmt.ReleaseStmt(stmt) 309 if err := tb.dialect.CreateTable( 310 stmt, 311 tb.dbName, 312 tb.name, 313 tb.pk, 314 tb.client.DriverInfo, 315 fields, 316 ); err != nil { 317 return err 318 } 319 if _, err := sqldriver.Execute( 320 ctx, 321 tb.driver, 322 stmt, 323 tb.logger, 324 ); err != nil { 325 return err 326 } 327 return nil 328 } 329 330 func (tb *Table) alterTable(ctx context.Context, fields []reflext.StructFielder, columns []Column, indexs []Index, unsafe bool) error { 331 cols := make([]string, len(columns)) 332 for i, col := range columns { 333 cols[i] = col.Name 334 } 335 idxs := make([]string, len(indexs)) 336 for i, idx := range indexs { 337 idxs[i] = idx.Name 338 } 339 stmt := sqlstmt.AcquireStmt(tb.dialect) 340 defer sqlstmt.ReleaseStmt(stmt) 341 tb.dialect.HasPrimaryKey(stmt, tb.dbName, tb.name) 342 var count uint 343 if err := sqldriver.QueryRowContext( 344 ctx, 345 tb.driver, 346 stmt, 347 tb.logger, 348 ).Scan(&count); err != nil { 349 return err 350 } 351 stmt.Reset() 352 if err := tb.dialect.AlterTable( 353 stmt, 354 tb.dbName, tb.name, tb.pk, count > 0, 355 tb.client.DriverInfo, 356 fields, cols, idxs, unsafe, 357 ); err != nil { 358 return err 359 } 360 if _, err := sqldriver.Execute( 361 ctx, 362 tb.driver, 363 stmt, 364 tb.logger, 365 ); err != nil { 366 return err 367 } 368 return nil 369 }