github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/internal/datasource/shardingsource/sharding_datasource.go (about) 1 // Copyright 2021 ecodeclub 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package shardingsource 16 17 import ( 18 "context" 19 "database/sql" 20 "fmt" 21 22 "github.com/ecodeclub/eorm/internal/datasource/transaction" 23 24 "github.com/ecodeclub/eorm/internal/datasource" 25 "go.uber.org/multierr" 26 27 "github.com/ecodeclub/eorm/internal/errs" 28 ) 29 30 var _ datasource.TxBeginner = &ShardingDataSource{} 31 var _ datasource.DataSource = &ShardingDataSource{} 32 var _ datasource.Finder = &ShardingDataSource{} 33 34 type ShardingDataSource struct { 35 sources map[string]datasource.DataSource 36 } 37 38 func (s *ShardingDataSource) Query(ctx context.Context, query datasource.Query) (*sql.Rows, error) { 39 ds, err := s.getTgt(query) 40 if err != nil { 41 return nil, err 42 } 43 return ds.Query(ctx, query) 44 } 45 46 func (s *ShardingDataSource) Exec(ctx context.Context, query datasource.Query) (sql.Result, error) { 47 ds, err := s.getTgt(query) 48 if err != nil { 49 return nil, err 50 } 51 return ds.Exec(ctx, query) 52 } 53 54 func (s *ShardingDataSource) FindTgt(ctx context.Context, query datasource.Query) (datasource.TxBeginner, error) { 55 ds, err := s.getTgt(query) 56 if err != nil { 57 return nil, err 58 } 59 f, ok := ds.(datasource.Finder) 60 if !ok { 61 return nil, errs.NewErrNotCompleteFinder(query.Datasource) 62 } 63 return f.FindTgt(ctx, query) 64 } 65 66 func (s *ShardingDataSource) getTgt(query datasource.Query) (datasource.DataSource, error) { 67 ds, ok := s.sources[query.Datasource] 68 if !ok { 69 return nil, errs.NewErrNotFoundTargetDataSource(query.Datasource) 70 } 71 return ds, nil 72 } 73 74 func (s *ShardingDataSource) BeginTx(ctx context.Context, opts *sql.TxOptions) (datasource.Tx, error) { 75 facade, err := transaction.NewTxFacade(ctx, s) 76 if err != nil { 77 return nil, err 78 } 79 return facade.BeginTx(ctx, opts) 80 } 81 82 func NewShardingDataSource(m map[string]datasource.DataSource) datasource.DataSource { 83 return &ShardingDataSource{ 84 sources: m, 85 } 86 } 87 88 func (s *ShardingDataSource) Close() error { 89 var err error 90 for name, inst := range s.sources { 91 if er := inst.Close(); er != nil { 92 err = multierr.Combine( 93 err, fmt.Errorf("source name [%s] error: %w", name, er)) 94 } 95 } 96 return err 97 }