github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/db.go (about)

     1  // Copyright 2021 ecodeclub
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package eorm
    16  
    17  import (
    18  	"context"
    19  	"database/sql"
    20  
    21  	"github.com/ecodeclub/eorm/internal/datasource"
    22  	"github.com/ecodeclub/eorm/internal/datasource/single"
    23  	"github.com/ecodeclub/eorm/internal/dialect"
    24  	"github.com/ecodeclub/eorm/internal/errs"
    25  	"github.com/ecodeclub/eorm/internal/model"
    26  	"github.com/ecodeclub/eorm/internal/valuer"
    27  )
    28  
    29  const (
    30  	SELECT = "SELECT"
    31  	DELETE = "DELETE"
    32  	UPDATE = "UPDATE"
    33  	INSERT = "INSERT"
    34  	RAW    = "RAW"
    35  )
    36  
    37  // DBOption configure DB
    38  type DBOption func(db *DB)
    39  
    40  // DB represents a database
    41  type DB struct {
    42  	baseSession
    43  	ds datasource.DataSource
    44  }
    45  
    46  // DBWithMiddlewares 为 db 配置 Middleware
    47  func DBWithMiddlewares(ms ...Middleware) DBOption {
    48  	return func(db *DB) {
    49  		db.ms = ms
    50  	}
    51  }
    52  
    53  func DBWithMetaRegistry(r model.MetaRegistry) DBOption {
    54  	return func(db *DB) {
    55  		db.metaRegistry = r
    56  	}
    57  }
    58  
    59  func UseReflection() DBOption {
    60  	return func(db *DB) {
    61  		db.valCreator = valuer.PrimitiveCreator{Creator: valuer.NewUnsafeValue}
    62  	}
    63  }
    64  
    65  // Open 创建一个 ORM 实例
    66  // 注意该实例是一个无状态的对象,你应该尽可能复用它
    67  func Open(driver string, dsn string, opts ...DBOption) (*DB, error) {
    68  	db, err := single.OpenDB(driver, dsn)
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  	return OpenDS(driver, db, opts...)
    73  }
    74  
    75  func OpenDS(driver string, ds datasource.DataSource, opts ...DBOption) (*DB, error) {
    76  	dl, err := dialect.Of(driver)
    77  	if err != nil {
    78  		return nil, err
    79  	}
    80  	orm := &DB{
    81  		baseSession: baseSession{
    82  			executor: ds,
    83  			core: core{
    84  				metaRegistry: model.NewMetaRegistry(),
    85  				dialect:      dl,
    86  				// 可以设为默认,因为原本这里也有默认
    87  				valCreator: valuer.PrimitiveCreator{
    88  					Creator: valuer.NewUnsafeValue,
    89  				},
    90  			},
    91  		},
    92  		ds: ds,
    93  	}
    94  	for _, o := range opts {
    95  		o(orm)
    96  	}
    97  	return orm, nil
    98  }
    99  
   100  func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
   101  	inst, ok := db.ds.(datasource.TxBeginner)
   102  	if !ok {
   103  		return nil, errs.ErrNotCompleteTxBeginner
   104  	}
   105  	tx, err := inst.BeginTx(ctx, opts)
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  	return &Tx{tx: tx, baseSession: baseSession{
   110  		executor: tx,
   111  		core:     db.core,
   112  	}}, nil
   113  }
   114  
   115  func (db *DB) Close() error {
   116  	return db.ds.Close()
   117  }