github.com/RevenueMonster/sqlike@v1.0.6/sqlike/client.go (about)

     1  package sqlike
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"strings"
     7  
     8  	semver "github.com/Masterminds/semver/v3"
     9  	"github.com/RevenueMonster/sqlike/reflext"
    10  	"github.com/RevenueMonster/sqlike/sql/charset"
    11  	"github.com/RevenueMonster/sqlike/sql/codec"
    12  	"github.com/RevenueMonster/sqlike/sql/dialect"
    13  	"github.com/RevenueMonster/sqlike/sql/driver"
    14  	sqlstmt "github.com/RevenueMonster/sqlike/sql/stmt"
    15  	"github.com/RevenueMonster/sqlike/sqlike/logs"
    16  	"github.com/RevenueMonster/sqlike/sqlike/options"
    17  	"github.com/bxcodec/dbresolver"
    18  )
    19  
// DriverInfo holds what a Client knows about the SQL driver and server it is
// connected to. The fields are unexported and exposed read-only through the
// DriverName, Version, Charset and Collate getters.
type DriverInfo struct {
	// driverName is the SQL driver name (newClient lower-cases and trims it).
	driverName string
	// version is the server version obtained by Client.getVersion in newClient.
	version    *semver.Version
	// charSet is the character set code supplied when the client was created.
	charSet    charset.Code
	// collate is the collation name supplied when the client was created.
	collate    string
}
    27  
    28  // DriverName :
    29  func (d *DriverInfo) DriverName() string {
    30  	return d.driverName
    31  }
    32  
    33  // Version :
    34  func (d *DriverInfo) Version() *semver.Version {
    35  	return d.version
    36  }
    37  
    38  // Charset :
    39  func (d *DriverInfo) Charset() charset.Code {
    40  	return d.charSet
    41  }
    42  
    43  // Collate :
    44  func (d *DriverInfo) Collate() string {
    45  	return d.collate
    46  }
    47  
// Client is a sqlike client. It embeds *sql.DB, so every *sql.DB method can be
// called on a Client directly, alongside the sqlike-specific APIs below.
type Client struct {
	// Driver name, server version, charset and collation; the DriverInfo
	// getters are promoted onto Client.
	*DriverInfo
	// The underlying connection pool used for client-level statements.
	*sql.DB
	// pk is the default primary key name copied into each Database handle
	// ("$Key" from newClient unless changed with SetPrimaryKey).
	pk      string
	// logger is passed to every driver call; newClient leaves it nil until
	// SetLogger is called.
	logger  logs.Logger
	// cache reflects struct types at runtime (reflext.DefaultMapper by default).
	cache   reflext.StructMapper
	// codec converts between Go values and driver values (codec.DefaultRegistry
	// by default).
	codec   codec.Codecer
	// dialect builds the SQL text for every statement the client issues.
	dialect dialect.Dialect
}
    58  
    59  // newClient : create a new client struct by providing driver, *sql.DB, dialect etc
    60  func newClient(ctx context.Context, driver string, db *sql.DB, dialect dialect.Dialect, code charset.Code, collate string) (*Client, error) {
    61  	driver = strings.TrimSpace(strings.ToLower(driver))
    62  	client := &Client{
    63  		DB:      db,
    64  		dialect: dialect,
    65  	}
    66  	client.pk = "$Key"
    67  	client.DriverInfo = new(DriverInfo)
    68  	client.driverName = driver
    69  	client.charSet = code
    70  	client.collate = collate
    71  	client.cache = reflext.DefaultMapper
    72  	client.codec = codec.DefaultRegistry
    73  	client.version = client.getVersion(ctx)
    74  	return client, nil
    75  }
    76  
    77  // SetLogger : this is to set the logger for debugging, it will panic if the logger input is nil
    78  func (c *Client) SetLogger(logger logs.Logger) *Client {
    79  	if logger == nil {
    80  		panic("logger cannot be nil")
    81  	}
    82  	c.logger = logger
    83  	return c
    84  }
    85  
    86  // SetPrimaryKey : this will set a default primary key for subsequent operation such as Insert, InsertOne, ModifyOne
    87  func (c *Client) SetPrimaryKey(pk string) *Client {
    88  	c.pk = pk
    89  	return c
    90  }
    91  
    92  // SetCodec : Codec is a component which handling the :
    93  // 1. encoding between input data and driver.Valuer
    94  // 2. decoding between output data and sql.Scanner
    95  func (c *Client) SetCodec(cdc codec.Codecer) *Client {
    96  	c.codec = cdc
    97  	return c
    98  }
    99  
   100  // SetStructMapper : StructMapper is a mapper to reflect a struct on runtime and provide struct info
   101  func (c *Client) SetStructMapper(mapper reflext.StructMapper) *Client {
   102  	c.cache = mapper
   103  	return c
   104  }
   105  
   106  // CreateDatabase : create database with name
   107  func (c *Client) CreateDatabase(ctx context.Context, name string) error {
   108  	return c.createDB(ctx, name, true)
   109  }
   110  
   111  // DropDatabase : drop the selected database
   112  func (c *Client) DropDatabase(ctx context.Context, name string) error {
   113  	return c.dropDB(ctx, name, true)
   114  }
   115  
   116  // ListDatabases : list all the database on current connection
   117  func (c *Client) ListDatabases(ctx context.Context) ([]string, error) {
   118  	stmt := sqlstmt.AcquireStmt(c.dialect)
   119  	defer sqlstmt.ReleaseStmt(stmt)
   120  	c.dialect.GetDatabases(stmt)
   121  	rows, err := driver.Query(
   122  		ctx,
   123  		c.DB,
   124  		stmt,
   125  		c.logger,
   126  	)
   127  	if err != nil {
   128  		return nil, err
   129  	}
   130  	defer rows.Close()
   131  	dbs := make([]string, 0)
   132  	for i := 0; rows.Next(); i++ {
   133  		dbs = append(dbs, "")
   134  		if err := rows.Scan(&dbs[i]); err != nil {
   135  			return nil, err
   136  		}
   137  	}
   138  	return dbs, nil
   139  }
   140  
   141  // Database : this api will execute `USE database`, which will point your current connection to selected database
   142  func (c *Client) Database(name string, connections ...*options.ConnectOptions) *Database {
   143  	stmt := sqlstmt.AcquireStmt(c.dialect)
   144  	defer sqlstmt.ReleaseStmt(stmt)
   145  	c.dialect.UseDatabase(stmt, name)
   146  	if _, err := driver.Execute(context.Background(), c.DB, stmt, c.logger); err != nil {
   147  		panic(err)
   148  	}
   149  
   150  	if len(connections) == 0 {
   151  		return &Database{
   152  			driverName: c.driverName,
   153  			name:       name,
   154  			pk:         c.pk,
   155  			client:     c,
   156  			dialect:    c.dialect,
   157  			driver:     c.DB,
   158  			logger:     c.logger,
   159  			codec:      c.codec,
   160  		}
   161  	}
   162  
   163  	replicas := make([]*sql.DB, len(connections)+1)
   164  	replicas[0] = c.DB
   165  
   166  	dialect := dialect.GetDialectByDriver(c.driverName)
   167  	for idx, connection := range connections {
   168  		connStr := dialect.Connect(connection)
   169  		db, err := sql.Open(c.driverName, connStr)
   170  		if err != nil {
   171  			panic(err)
   172  		}
   173  
   174  		driver.Execute(context.Background(), db, stmt, c.logger)
   175  		replicas[idx+1] = db
   176  	}
   177  
   178  	resolver := dbresolver.WrapDBs(replicas...)
   179  	return &Database{
   180  		driverName: c.driverName,
   181  		name:       name,
   182  		pk:         c.pk,
   183  		client:     c,
   184  		dialect:    c.dialect,
   185  		driver:     resolver,
   186  		logger:     c.logger,
   187  		codec:      c.codec,
   188  	}
   189  
   190  }
   191  
   192  // getVersion is a internal function to get sql driver's version
   193  func (c *Client) getVersion(ctx context.Context) (version *semver.Version) {
   194  	var (
   195  		ver string
   196  		err error
   197  	)
   198  	stmt := sqlstmt.AcquireStmt(c.dialect)
   199  	defer sqlstmt.ReleaseStmt(stmt)
   200  	c.dialect.GetVersion(stmt)
   201  	err = driver.QueryRowContext(
   202  		ctx,
   203  		c.DB,
   204  		stmt,
   205  		c.logger,
   206  	).Scan(&ver)
   207  	if err != nil {
   208  		panic(err)
   209  	}
   210  	paths := strings.Split(ver, "-")
   211  	version, err = semver.NewVersion(paths[0])
   212  	if err != nil {
   213  		panic(err)
   214  	}
   215  	return
   216  }
   217  
   218  // createDB is a internal function for create a database
   219  func (c *Client) createDB(ctx context.Context, name string, checkExists bool) error {
   220  	stmt := sqlstmt.AcquireStmt(c.dialect)
   221  	defer sqlstmt.ReleaseStmt(stmt)
   222  	c.dialect.CreateDatabase(stmt, name, checkExists)
   223  	_, err := driver.Execute(
   224  		ctx,
   225  		c.DB,
   226  		stmt,
   227  		c.logger,
   228  	)
   229  	return err
   230  }
   231  
   232  // dropDB is a internal function for drop a database
   233  func (c *Client) dropDB(ctx context.Context, name string, checkExists bool) error {
   234  	stmt := sqlstmt.AcquireStmt(c.dialect)
   235  	defer sqlstmt.ReleaseStmt(stmt)
   236  	c.dialect.DropDatabase(stmt, name, checkExists)
   237  	_, err := driver.Execute(
   238  		ctx,
   239  		c.DB,
   240  		stmt,
   241  		c.logger,
   242  	)
   243  	return err
   244  }