github.com/RevenueMonster/sqlike@v1.0.6/sqlike/table.go (about)

     1  package sqlike
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"reflect"
     7  	"strings"
     8  
     9  	"github.com/RevenueMonster/sqlike/reflext"
    10  	"github.com/RevenueMonster/sqlike/sql"
    11  
    12  	"github.com/RevenueMonster/sqlike/sql/codec"
    13  	"github.com/RevenueMonster/sqlike/sql/dialect"
    14  	sqldriver "github.com/RevenueMonster/sqlike/sql/driver"
    15  	sqlstmt "github.com/RevenueMonster/sqlike/sql/stmt"
    16  	"github.com/RevenueMonster/sqlike/sqlike/logs"
    17  )
    18  
    19  // ErrNoRecordAffected :
    20  var ErrNoRecordAffected = errors.New("no record affected")
    21  
    22  // ErrExpectedStruct :
    23  var ErrExpectedStruct = errors.New("expected struct as a source")
    24  
    25  // ErrEmptyFields :
    26  var ErrEmptyFields = errors.New("empty fields")
    27  
// Table is a handle on one table of one database. Every method builds its
// SQL through the configured dialect and executes it through the configured
// driver, passing the logger along with each statement.
type Table struct {
	// current database name
	dbName string

	// table name
	name string

	// default primary key; used when creating or altering the table
	pk string

	// client this table is bound to; migrations read its cache and
	// DriverInfo.
	client *Client

	// sql driver used to execute statements
	driver sqldriver.Driver

	// sql dialect used to build statements
	dialect dialect.Dialect

	// encoder and decoder for the value
	codec codec.Codecer

	// logger handed to every statement execution
	logger logs.Logger
}
    51  
    52  // Rename : rename the current table name to new table name
    53  func (tb *Table) Rename(ctx context.Context, name string) error {
    54  	stmt := sqlstmt.AcquireStmt(tb.dialect)
    55  	defer sqlstmt.ReleaseStmt(stmt)
    56  	tb.dialect.RenameTable(stmt, tb.dbName, tb.name, name)
    57  	_, err := sqldriver.Execute(
    58  		ctx,
    59  		tb.driver,
    60  		stmt,
    61  		tb.logger,
    62  	)
    63  	return err
    64  }
    65  
    66  // Exists : this will return true when the table exists in the database
    67  func (tb *Table) Exists(ctx context.Context) bool {
    68  	var count int
    69  	stmt := sqlstmt.AcquireStmt(tb.dialect)
    70  	defer sqlstmt.ReleaseStmt(stmt)
    71  	tb.dialect.HasTable(stmt, tb.dbName, tb.name)
    72  	if err := sqldriver.QueryRowContext(
    73  		ctx,
    74  		tb.driver,
    75  		stmt,
    76  		tb.logger,
    77  	).Scan(&count); err != nil {
    78  		panic(err)
    79  	}
    80  	return count > 0
    81  }
    82  
    83  // Columns :
    84  func (tb *Table) Columns() *ColumnView {
    85  	return &ColumnView{tb: tb}
    86  }
    87  
    88  // ListColumns : list all the column of the table.
    89  func (tb *Table) ListColumns(ctx context.Context) ([]Column, error) {
    90  	stmt := sqlstmt.AcquireStmt(tb.dialect)
    91  	defer sqlstmt.ReleaseStmt(stmt)
    92  	tb.dialect.GetColumns(stmt, tb.dbName, tb.name)
    93  	rows, err := sqldriver.Query(
    94  		ctx,
    95  		tb.driver,
    96  		stmt,
    97  		tb.logger,
    98  	)
    99  	if err != nil {
   100  		return nil, err
   101  	}
   102  	defer rows.Close()
   103  
   104  	columns := make([]Column, 0)
   105  	for i := 0; rows.Next(); i++ {
   106  		col := Column{}
   107  
   108  		if err := rows.Scan(
   109  			&col.Position,
   110  			&col.Name,
   111  			&col.Type,
   112  			&col.DefaultValue,
   113  			&col.IsNullable,
   114  			&col.DataType,
   115  			&col.Charset,
   116  			&col.Collation,
   117  			&col.Comment,
   118  			&col.Extra,
   119  		); err != nil {
   120  			return nil, err
   121  		}
   122  
   123  		col.Type = strings.ToUpper(col.Type)
   124  		col.DataType = strings.ToUpper(col.DataType)
   125  
   126  		columns = append(columns, col)
   127  	}
   128  	return columns, nil
   129  }
   130  
   131  // ListIndexes : list all the index of the table.
   132  func (tb *Table) ListIndexes(ctx context.Context) ([]Index, error) {
   133  	stmt := sqlstmt.AcquireStmt(tb.dialect)
   134  	defer sqlstmt.ReleaseStmt(stmt)
   135  	tb.dialect.GetIndexes(stmt, tb.dbName, tb.name)
   136  	rows, err := sqldriver.Query(
   137  		ctx,
   138  		tb.driver,
   139  		stmt,
   140  		tb.logger,
   141  	)
   142  	if err != nil {
   143  		return nil, err
   144  	}
   145  	defer rows.Close()
   146  
   147  	idxs := make([]Index, 0)
   148  	for i := 0; rows.Next(); i++ {
   149  		idx := Index{}
   150  		if err := rows.Scan(
   151  			&idx.Name,
   152  			&idx.Type,
   153  			&idx.IsUnique,
   154  		); err != nil {
   155  			return nil, err
   156  		}
   157  		idx.IsUnique = !idx.IsUnique
   158  		idxs = append(idxs, idx)
   159  	}
   160  	return idxs, nil
   161  }
   162  
   163  // MustMigrate : this will ensure the migrate is complete, otherwise it will panic
   164  func (tb Table) MustMigrate(ctx context.Context, entity interface{}) {
   165  	err := tb.Migrate(ctx, entity)
   166  	if err != nil {
   167  		panic(err)
   168  	}
   169  }
   170  
   171  // Migrate : migrate will create a new table follows by the definition of struct tag, alter when the table already exists
   172  func (tb *Table) Migrate(ctx context.Context, entity interface{}) error {
   173  	return tb.migrateOne(ctx, tb.client.cache, entity, false)
   174  }
   175  
   176  // UnsafeMigrate : unsafe migration will delete non-exist index and columns, beware when you use this
   177  func (tb *Table) UnsafeMigrate(ctx context.Context, entity interface{}) error {
   178  	return tb.migrateOne(ctx, tb.client.cache, entity, true)
   179  }
   180  
   181  // MustUnsafeMigrate : this will panic if it get error on unsafe migrate
   182  func (tb *Table) MustUnsafeMigrate(ctx context.Context, entity interface{}) {
   183  	err := tb.migrateOne(ctx, tb.client.cache, entity, true)
   184  	if err != nil {
   185  		panic(err)
   186  	}
   187  }
   188  
   189  // Truncate : delete all the table data.
   190  func (tb *Table) Truncate(ctx context.Context) (err error) {
   191  	stmt := sqlstmt.AcquireStmt(tb.dialect)
   192  	defer sqlstmt.ReleaseStmt(stmt)
   193  	tb.dialect.TruncateTable(stmt, tb.dbName, tb.name)
   194  	_, err = sqldriver.Execute(
   195  		ctx,
   196  		tb.driver,
   197  		stmt,
   198  		tb.logger,
   199  	)
   200  	return
   201  }
   202  
   203  // DropIfExists : will drop the table only if it exists.
   204  func (tb Table) DropIfExists(ctx context.Context) (err error) {
   205  	stmt := sqlstmt.AcquireStmt(tb.dialect)
   206  	defer sqlstmt.ReleaseStmt(stmt)
   207  	tb.dialect.DropTable(stmt, tb.dbName, tb.name, true)
   208  	_, err = sqldriver.Execute(
   209  		ctx,
   210  		tb.driver,
   211  		stmt,
   212  		tb.logger,
   213  	)
   214  	return
   215  }
   216  
   217  // Drop : drop the table, but it might throw error when the table is not exists
   218  func (tb Table) Drop(ctx context.Context) (err error) {
   219  	stmt := sqlstmt.AcquireStmt(tb.dialect)
   220  	defer sqlstmt.ReleaseStmt(stmt)
   221  	tb.dialect.DropTable(stmt, tb.dbName, tb.name, false)
   222  	_, err = sqldriver.Execute(
   223  		ctx,
   224  		tb.driver,
   225  		stmt,
   226  		tb.logger,
   227  	)
   228  	return
   229  }
   230  
   231  // Replace :
   232  func (tb *Table) Replace(ctx context.Context, fields []string, query *sql.SelectStmt) error {
   233  	stmt := sqlstmt.AcquireStmt(tb.dialect)
   234  	defer sqlstmt.ReleaseStmt(stmt)
   235  	if err := tb.dialect.Replace(
   236  		stmt,
   237  		tb.dbName,
   238  		tb.name,
   239  		fields,
   240  		query,
   241  	); err != nil {
   242  		return err
   243  	}
   244  
   245  	if _, err := sqldriver.Execute(
   246  		ctx,
   247  		tb.driver,
   248  		stmt,
   249  		tb.logger,
   250  	); err != nil {
   251  		return err
   252  	}
   253  	return nil
   254  }
   255  
   256  // Indexes :
   257  func (tb *Table) Indexes() *IndexView {
   258  	return &IndexView{tb: tb}
   259  }
   260  
   261  // HasIndexByName :
   262  func (tb *Table) HasIndexByName(ctx context.Context, name string) (bool, error) {
   263  	return isIndexExists(
   264  		ctx,
   265  		tb.dbName,
   266  		tb.name,
   267  		name,
   268  		tb.driver,
   269  		tb.dialect,
   270  		tb.logger,
   271  	)
   272  }
   273  
   274  func (tb *Table) migrateOne(ctx context.Context, cache reflext.StructMapper, entity interface{}, unsafe bool) error {
   275  	v := reflext.ValueOf(entity)
   276  	if !v.IsValid() {
   277  		return ErrInvalidInput
   278  	}
   279  
   280  	t := reflext.Deref(v.Type())
   281  	if !reflext.IsKind(t, reflect.Struct) {
   282  		return ErrExpectedStruct
   283  	}
   284  
   285  	cdc := cache.CodecByType(t)
   286  	fields := skipColumns(cdc.Properties(), nil)
   287  	if len(fields) < 1 {
   288  		return ErrEmptyFields
   289  	}
   290  
   291  	if !tb.Exists(ctx) {
   292  		return tb.createTable(ctx, fields)
   293  	}
   294  
   295  	columns, err := tb.ListColumns(ctx)
   296  	if err != nil {
   297  		return err
   298  	}
   299  	idxs, err := tb.ListIndexes(ctx)
   300  	if err != nil {
   301  		return err
   302  	}
   303  	return tb.alterTable(ctx, fields, columns, idxs, unsafe)
   304  }
   305  
   306  func (tb *Table) createTable(ctx context.Context, fields []reflext.StructFielder) error {
   307  	stmt := sqlstmt.AcquireStmt(tb.dialect)
   308  	defer sqlstmt.ReleaseStmt(stmt)
   309  	if err := tb.dialect.CreateTable(
   310  		stmt,
   311  		tb.dbName,
   312  		tb.name,
   313  		tb.pk,
   314  		tb.client.DriverInfo,
   315  		fields,
   316  	); err != nil {
   317  		return err
   318  	}
   319  	if _, err := sqldriver.Execute(
   320  		ctx,
   321  		tb.driver,
   322  		stmt,
   323  		tb.logger,
   324  	); err != nil {
   325  		return err
   326  	}
   327  	return nil
   328  }
   329  
   330  func (tb *Table) alterTable(ctx context.Context, fields []reflext.StructFielder, columns []Column, indexs []Index, unsafe bool) error {
   331  	cols := make([]string, len(columns))
   332  	for i, col := range columns {
   333  		cols[i] = col.Name
   334  	}
   335  	idxs := make([]string, len(indexs))
   336  	for i, idx := range indexs {
   337  		idxs[i] = idx.Name
   338  	}
   339  	stmt := sqlstmt.AcquireStmt(tb.dialect)
   340  	defer sqlstmt.ReleaseStmt(stmt)
   341  	tb.dialect.HasPrimaryKey(stmt, tb.dbName, tb.name)
   342  	var count uint
   343  	if err := sqldriver.QueryRowContext(
   344  		ctx,
   345  		tb.driver,
   346  		stmt,
   347  		tb.logger,
   348  	).Scan(&count); err != nil {
   349  		return err
   350  	}
   351  	stmt.Reset()
   352  	if err := tb.dialect.AlterTable(
   353  		stmt,
   354  		tb.dbName, tb.name, tb.pk, count > 0,
   355  		tb.client.DriverInfo,
   356  		fields, cols, idxs, unsafe,
   357  	); err != nil {
   358  		return err
   359  	}
   360  	if _, err := sqldriver.Execute(
   361  		ctx,
   362  		tb.driver,
   363  		stmt,
   364  		tb.logger,
   365  	); err != nil {
   366  		return err
   367  	}
   368  	return nil
   369  }