github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/db.go (about) 1 // Copyright 2021 ecodeclub 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package eorm 16 17 import ( 18 "context" 19 "database/sql" 20 21 "github.com/ecodeclub/eorm/internal/datasource" 22 "github.com/ecodeclub/eorm/internal/datasource/single" 23 "github.com/ecodeclub/eorm/internal/dialect" 24 "github.com/ecodeclub/eorm/internal/errs" 25 "github.com/ecodeclub/eorm/internal/model" 26 "github.com/ecodeclub/eorm/internal/valuer" 27 ) 28 29 const ( 30 SELECT = "SELECT" 31 DELETE = "DELETE" 32 UPDATE = "UPDATE" 33 INSERT = "INSERT" 34 RAW = "RAW" 35 ) 36 37 // DBOption configure DB 38 type DBOption func(db *DB) 39 40 // DB represents a database 41 type DB struct { 42 baseSession 43 ds datasource.DataSource 44 } 45 46 // DBWithMiddlewares 为 db 配置 Middleware 47 func DBWithMiddlewares(ms ...Middleware) DBOption { 48 return func(db *DB) { 49 db.ms = ms 50 } 51 } 52 53 func DBWithMetaRegistry(r model.MetaRegistry) DBOption { 54 return func(db *DB) { 55 db.metaRegistry = r 56 } 57 } 58 59 func UseReflection() DBOption { 60 return func(db *DB) { 61 db.valCreator = valuer.PrimitiveCreator{Creator: valuer.NewUnsafeValue} 62 } 63 } 64 65 // Open 创建一个 ORM 实例 66 // 注意该实例是一个无状态的对象,你应该尽可能复用它 67 func Open(driver string, dsn string, opts ...DBOption) (*DB, error) { 68 db, err := single.OpenDB(driver, dsn) 69 if err != nil { 70 return nil, err 71 } 72 return OpenDS(driver, db, opts...) 73 } 74 75 func OpenDS(driver string, ds datasource.DataSource, opts ...DBOption) (*DB, error) { 76 dl, err := dialect.Of(driver) 77 if err != nil { 78 return nil, err 79 } 80 orm := &DB{ 81 baseSession: baseSession{ 82 executor: ds, 83 core: core{ 84 metaRegistry: model.NewMetaRegistry(), 85 dialect: dl, 86 // 可以设为默认,因为原本这里也有默认 87 valCreator: valuer.PrimitiveCreator{ 88 Creator: valuer.NewUnsafeValue, 89 }, 90 }, 91 }, 92 ds: ds, 93 } 94 for _, o := range opts { 95 o(orm) 96 } 97 return orm, nil 98 } 99 100 func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { 101 inst, ok := db.ds.(datasource.TxBeginner) 102 if !ok { 103 return nil, errs.ErrNotCompleteTxBeginner 104 } 105 tx, err := inst.BeginTx(ctx, opts) 106 if err != nil { 107 return nil, err 108 } 109 return &Tx{tx: tx, baseSession: baseSession{ 110 executor: tx, 111 core: db.core, 112 }}, nil 113 } 114 115 func (db *DB) Close() error { 116 return db.ds.Close() 117 }