github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/internal/datasource/cluster/cluster_db.go (about)

     1  // Copyright 2021 ecodeclub
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package cluster
    16  
    17  import (
    18  	"context"
    19  	"database/sql"
    20  	"fmt"
    21  
    22  	"github.com/ecodeclub/eorm/internal/datasource/transaction"
    23  
    24  	"github.com/ecodeclub/eorm/internal/datasource"
    25  	"github.com/ecodeclub/eorm/internal/datasource/masterslave"
    26  	"github.com/ecodeclub/eorm/internal/errs"
    27  	"go.uber.org/multierr"
    28  )
    29  
    30  var _ datasource.TxBeginner = &clusterDB{}
    31  var _ datasource.DataSource = &clusterDB{}
    32  var _ datasource.Finder = &clusterDB{}
    33  
    34  // clusterDB 以 DB 名称作为索引目标数据库
    35  type clusterDB struct {
    36  	// DataSource  应实现为 *masterSlavesDB
    37  	masterSlavesDBs map[string]*masterslave.MasterSlavesDB
    38  }
    39  
    40  func (c *clusterDB) Query(ctx context.Context, query datasource.Query) (*sql.Rows, error) {
    41  	ms, err := c.getTgt(query)
    42  	if err != nil {
    43  		return nil, err
    44  	}
    45  	return ms.Query(ctx, query)
    46  }
    47  
    48  func (c *clusterDB) Exec(ctx context.Context, query datasource.Query) (sql.Result, error) {
    49  	ms, ok := c.masterSlavesDBs[query.DB]
    50  	if !ok {
    51  		return nil, errs.NewErrNotFoundTargetDB(query.DB)
    52  	}
    53  	return ms.Exec(ctx, query)
    54  }
    55  
    56  func (c *clusterDB) Close() error {
    57  	var err error
    58  	for name, inst := range c.masterSlavesDBs {
    59  		if er := inst.Close(); er != nil {
    60  			err = multierr.Combine(
    61  				err, fmt.Errorf("masterslave DB name [%s] error: %w", name, er))
    62  		}
    63  	}
    64  	return err
    65  }
    66  
    67  func (c *clusterDB) FindTgt(_ context.Context, query datasource.Query) (datasource.TxBeginner, error) {
    68  	db, err := c.getTgt(query)
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  	return db, nil
    73  }
    74  
    75  func (c *clusterDB) getTgt(query datasource.Query) (*masterslave.MasterSlavesDB, error) {
    76  	db, ok := c.masterSlavesDBs[query.DB]
    77  	if !ok {
    78  		return nil, errs.NewErrNotFoundTargetDB(query.DB)
    79  	}
    80  	return db, nil
    81  }
    82  
    83  func (c *clusterDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (datasource.Tx, error) {
    84  	facade, err := transaction.NewTxFacade(ctx, c)
    85  	if err != nil {
    86  		return nil, err
    87  	}
    88  
    89  	return facade.BeginTx(ctx, opts)
    90  }
    91  
    92  func NewClusterDB(ms map[string]*masterslave.MasterSlavesDB) datasource.DataSource {
    93  	return &clusterDB{masterSlavesDBs: ms}
    94  }