github.com/RevenueMonster/sqlike@v1.0.6/sqlike/database.go (about)

     1  package sqlike
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"errors"
     7  	"fmt"
     8  	"io/ioutil"
     9  	"os"
    10  	"path/filepath"
    11  	"strings"
    12  	"time"
    13  
    14  	"github.com/RevenueMonster/sqlike/sql/codec"
    15  	"github.com/RevenueMonster/sqlike/sql/dialect"
    16  	"github.com/RevenueMonster/sqlike/sql/driver"
    17  	sqlstmt "github.com/RevenueMonster/sqlike/sql/stmt"
    18  	"github.com/RevenueMonster/sqlike/sqlike/indexes"
    19  	"github.com/RevenueMonster/sqlike/sqlike/logs"
    20  	"github.com/RevenueMonster/sqlike/sqlike/options"
    21  	"gopkg.in/yaml.v3"
    22  )
    23  
    24  type txCallback func(ctx SessionContext) error
    25  
    26  // Database :
    27  type Database struct {
    28  	driverName string
    29  	name       string
    30  	pk         string
    31  	client     *Client
    32  	driver     driver.Driver
    33  	dialect    dialect.Dialect
    34  	codec      codec.Codecer
    35  	logger     logs.Logger
    36  }
    37  
    38  // Name : to get current database name
    39  func (db *Database) Name() string {
    40  	return db.name
    41  }
    42  
    43  // Table : use the table under this database
    44  func (db *Database) Table(name string) *Table {
    45  	return &Table{
    46  		dbName:  db.name,
    47  		name:    name,
    48  		pk:      db.pk,
    49  		client:  db.client,
    50  		driver:  db.driver,
    51  		dialect: db.dialect,
    52  		codec:   db.codec,
    53  		logger:  db.logger,
    54  	}
    55  }
    56  
    57  func (db *Database) QueryRow(ctx context.Context, query string, args ...interface{}) SingleResult {
    58  	rslt := new(Result)
    59  	rslt.cache = db.client.cache
    60  	rslt.codec = db.codec
    61  	rows, err := db.driver.QueryContext(ctx, query, args...)
    62  	if err != nil {
    63  		rslt.err = err
    64  		return rslt
    65  	}
    66  	rslt.rows = rows
    67  	rslt.columnTypes, rslt.err = rows.ColumnTypes()
    68  	if rslt.err != nil {
    69  		defer rslt.rows.Close()
    70  	}
    71  	for _, col := range rslt.columnTypes {
    72  		rslt.columns = append(rslt.columns, col.Name())
    73  	}
    74  	rslt.close = true
    75  	if !rslt.Next() {
    76  		rslt.err = sql.ErrNoRows
    77  	}
    78  	return rslt
    79  }
    80  
    81  // QueryStmt :
    82  func (db *Database) QueryStmt(ctx context.Context, query interface{}) (*Result, error) {
    83  	if query == nil {
    84  		return nil, errors.New("sqlike: empty query statement")
    85  	}
    86  
    87  	stmt := sqlstmt.AcquireStmt(db.dialect)
    88  	defer sqlstmt.ReleaseStmt(stmt)
    89  	if err := db.dialect.SelectStmt(stmt, query); err != nil {
    90  		return nil, err
    91  	}
    92  
    93  	rows, err := driver.Query(
    94  		ctx,
    95  		db.driver,
    96  		stmt,
    97  		getLogger(db.logger, true),
    98  	)
    99  	if err != nil {
   100  		return nil, err
   101  	}
   102  
   103  	rslt := new(Result)
   104  	rslt.cache = db.client.cache
   105  	rslt.codec = db.codec
   106  	rslt.rows = rows
   107  	rslt.columnTypes, rslt.err = rows.ColumnTypes()
   108  	if rslt.err != nil {
   109  		defer rslt.rows.Close()
   110  	}
   111  	for _, col := range rslt.columnTypes {
   112  		rslt.columns = append(rslt.columns, col.Name())
   113  	}
   114  	return rslt, rslt.err
   115  }
   116  
   117  // BeginTransaction :
   118  func (db *Database) BeginTransaction(ctx context.Context, opts ...*sql.TxOptions) (*Transaction, error) {
   119  	opt := &sql.TxOptions{}
   120  	if len(opts) > 0 {
   121  		opt = opts[0]
   122  	}
   123  	return db.beginTrans(ctx, opt)
   124  }
   125  
   126  func (db *Database) beginTrans(ctx context.Context, opt *sql.TxOptions) (*Transaction, error) {
   127  	tx, err := db.client.BeginTx(ctx, opt)
   128  	if err != nil {
   129  		return nil, err
   130  	}
   131  	return &Transaction{
   132  		Context: ctx,
   133  		dbName:  db.name,
   134  		pk:      db.pk,
   135  		client:  db.client,
   136  		driver:  tx,
   137  		dialect: db.dialect,
   138  		logger:  db.logger,
   139  		codec:   db.codec,
   140  	}, nil
   141  }
   142  
   143  // RunInTransaction :
   144  func (db *Database) RunInTransaction(ctx context.Context, cb txCallback, opts ...*options.TransactionOptions) error {
   145  	opt := new(options.TransactionOptions)
   146  	if len(opts) > 0 && opts[0] != nil {
   147  		opt = opts[0]
   148  	}
   149  	duration := 60 * time.Second
   150  	if opt.Duration.Seconds() > 0 {
   151  		duration = opt.Duration
   152  	}
   153  	c, cancel := context.WithTimeout(ctx, duration)
   154  	defer cancel()
   155  	tx, err := db.beginTrans(c, &sql.TxOptions{
   156  		Isolation: opt.IsolationLevel,
   157  		ReadOnly:  opt.ReadOnly,
   158  	})
   159  	if err != nil {
   160  		return err
   161  	}
   162  	defer tx.RollbackTransaction()
   163  	if err := cb(tx); err != nil {
   164  		return err
   165  	}
   166  	return tx.CommitTransaction()
   167  }
   168  
// indexDefinition is the YAML document shape decoded by readAndBuildIndexes.
// It has a top-level "indexes" list; each entry names the target table, the
// index name and type, optional cast/as/comment values and its ordered
// columns.
type indexDefinition struct {
	Indexes []struct {
		Table   string `yaml:"table"`   // table the index is created on
		Name    string `yaml:"name"`    // index name; surrounding whitespace is trimmed
		Type    string `yaml:"type"`    // index type name (see parseIndexType); empty means btree
		Cast    string `yaml:"cast"`    // trimmed and passed through as indexes.Index.Cast
		As      string `yaml:"as"`      // trimmed and passed through as indexes.Index.As
		Comment string `yaml:"comment"` // trimmed and passed through as indexes.Index.Comment
		Columns []struct {
			Name      string `yaml:"name"`
			Direction string `yaml:"direction"` // "desc"/"descending" (case-insensitive) is descending; anything else ascending
		} `yaml:"columns"`
	} `yaml:"indexes"`
}
   183  
   184  // BuildIndexes :
   185  func (db *Database) BuildIndexes(ctx context.Context, paths ...string) error {
   186  	var (
   187  		path string
   188  		err  error
   189  		fi   os.FileInfo
   190  	)
   191  	if len(paths) > 0 {
   192  		path = paths[0]
   193  		fi, err = os.Stat(path)
   194  		if err != nil {
   195  			return err
   196  		}
   197  	} else {
   198  		pwd, _ := os.Getwd()
   199  		files := []string{pwd + "/index.yml", pwd + "/index.yaml"}
   200  		for _, f := range files {
   201  			fi, err = os.Stat(f)
   202  			if !os.IsNotExist(err) {
   203  				path = f
   204  				break
   205  			}
   206  		}
   207  		if err != nil {
   208  			return err
   209  		}
   210  	}
   211  
   212  	switch v := fi.Mode(); {
   213  	case v.IsDir():
   214  		if err := filepath.Walk(path, func(fp string, info os.FileInfo, err error) error {
   215  			if info.IsDir() {
   216  				return nil
   217  			}
   218  
   219  			ext := filepath.Ext(info.Name())
   220  			// only interested on yaml and yml files
   221  			if ext != ".yaml" && ext != ".yml" {
   222  				return nil
   223  			}
   224  
   225  			return db.readAndBuildIndexes(ctx, fp)
   226  		}); err != nil {
   227  			return err
   228  		}
   229  
   230  	case v.IsRegular():
   231  		if err := db.readAndBuildIndexes(ctx, path); err != nil {
   232  			return err
   233  		}
   234  	}
   235  
   236  	return nil
   237  }
   238  
   239  func (db *Database) readAndBuildIndexes(ctx context.Context, path string) error {
   240  	var id indexDefinition
   241  	b, err := ioutil.ReadFile(path)
   242  	if err != nil {
   243  		return err
   244  	}
   245  	if err := yaml.Unmarshal(b, &id); err != nil {
   246  		return err
   247  	}
   248  
   249  	for _, idx := range id.Indexes {
   250  		length := len(idx.Columns)
   251  		columns := make([]indexes.Col, length)
   252  		for i, col := range idx.Columns {
   253  			dir := indexes.Ascending
   254  			col.Direction = strings.TrimSpace(strings.ToLower(col.Direction))
   255  			if col.Direction == "desc" || col.Direction == "descending" {
   256  				dir = indexes.Descending
   257  			}
   258  			columns[i] = indexes.Col{
   259  				Name:      col.Name,
   260  				Direction: dir,
   261  			}
   262  		}
   263  
   264  		index := indexes.Index{
   265  			Name:    strings.TrimSpace(idx.Name),
   266  			Type:    parseIndexType(idx.Type),
   267  			Cast:    strings.TrimSpace(idx.Cast),
   268  			As:      strings.TrimSpace(idx.As),
   269  			Columns: columns,
   270  			Comment: strings.TrimSpace(idx.Comment),
   271  		}
   272  
   273  		if exists, err := isIndexExists(
   274  			ctx,
   275  			db.name,
   276  			idx.Table,
   277  			index.GetName(),
   278  			db.driver,
   279  			db.dialect,
   280  			db.logger,
   281  		); err != nil {
   282  			return err
   283  		} else if exists {
   284  			continue
   285  		}
   286  
   287  		iv := db.Table(idx.Table).Indexes()
   288  		if err := iv.CreateOne(ctx, index); err != nil {
   289  			return err
   290  		}
   291  	}
   292  	return nil
   293  }
   294  
   295  func parseIndexType(name string) (idxType indexes.Type) {
   296  	name = strings.TrimSpace(strings.ToLower(name))
   297  	if name == "" {
   298  		idxType = indexes.BTree
   299  		return
   300  	}
   301  
   302  	switch name {
   303  	case "spatial":
   304  		idxType = indexes.Spatial
   305  	case "unique":
   306  		idxType = indexes.Unique
   307  	case "btree":
   308  		idxType = indexes.BTree
   309  	case "fulltext":
   310  		idxType = indexes.FullText
   311  	case "primary":
   312  		idxType = indexes.Primary
   313  	case "multi-valued":
   314  		idxType = indexes.MultiValued
   315  	default:
   316  		panic(fmt.Errorf("invalid index type %q", name))
   317  	}
   318  	return
   319  }