github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/transaction.go (about)

     1  // Copyright 2021 ecodeclub
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package eorm
    16  
    17  import (
    18  	"context"
    19  	"database/sql"
    20  
    21  	"github.com/ecodeclub/ekit/list"
    22  	"github.com/ecodeclub/ekit/mapx"
    23  	"github.com/ecodeclub/ekit/sqlx"
    24  	"github.com/ecodeclub/eorm/internal/rows"
    25  	"github.com/valyala/bytebufferpool"
    26  	"golang.org/x/sync/errgroup"
    27  
    28  	"github.com/ecodeclub/eorm/internal/datasource"
    29  )
    30  
    31  type Tx struct {
    32  	baseSession
    33  	tx datasource.Tx
    34  }
    35  
    36  func (t *Tx) queryMulti(ctx context.Context, qs []Query) (list.List[rows.Rows], error) {
    37  	// 事务在查询的时候,需要将同一个 DB 上的语句合并在一起
    38  	// 参考 https://github.com/ecodeclub/eorm/discussions/213
    39  	mp := mapx.NewMultiBuiltinMap[string, Query](len(qs))
    40  	for _, q := range qs {
    41  		if err := mp.Put(q.DB+"_"+q.Datasource, q); err != nil {
    42  			return nil, err
    43  		}
    44  	}
    45  	keys := mp.Keys()
    46  	rowsList := &list.ConcurrentList[rows.Rows]{
    47  		List: list.NewArrayList[rows.Rows](len(keys)),
    48  	}
    49  	var eg errgroup.Group
    50  	for _, key := range keys {
    51  		dbQs, _ := mp.Get(key)
    52  		eg.Go(func() error {
    53  			return t.execDBQueries(ctx, dbQs, rowsList)
    54  		})
    55  	}
    56  	return rowsList, eg.Wait()
    57  }
    58  
    59  // execDBQueries 执行某个 DB 上的全部查询。
    60  // 执行结果会被加入进去 rowsList 里面。虽然这种修改传入参数的做法不是很好,但是作为一个内部方法还是可以接受的。
    61  func (t *Tx) execDBQueries(ctx context.Context, dbQs []Query, rowsList *list.ConcurrentList[rows.Rows]) error {
    62  	qsCnt := len(dbQs)
    63  	// 考虑到大部分都只有一个查询,我们做一个快路径的优化。
    64  	if qsCnt == 1 {
    65  		rs, err := t.tx.Query(ctx, dbQs[0])
    66  		if err != nil {
    67  			return err
    68  		}
    69  		return rowsList.Append(rs)
    70  	}
    71  	// 慢路径,也就是必须要把同一个库的查询合并在一起
    72  	q := t.mergeDBQueries(dbQs)
    73  	rs, err := t.tx.Query(ctx, q)
    74  	if err != nil {
    75  		return err
    76  	}
    77  	// 查询之后,事务必须再次按照结果集分割开。
    78  	// 这样是为了让结果集的数量和查询数量保持一致。
    79  	return t.splitTxResultSet(rowsList, rs)
    80  }
    81  
    82  func (t *Tx) splitTxResultSet(list list.List[rows.Rows], rs *sql.Rows) error {
    83  	cs, err := rs.Columns()
    84  	if err != nil {
    85  		return err
    86  	}
    87  	ct, err := rs.ColumnTypes()
    88  	if err != nil {
    89  		return err
    90  	}
    91  	scanner, err := sqlx.NewSQLRowsScanner(rs)
    92  	if err != nil {
    93  		return err
    94  	}
    95  	// 虽然这里我们可以尝试不读取最后一个 ResultSet
    96  	// 但是这个优化目前来说不准备做,
    97  	// 防止用户出现因为类型转换遇到一些潜在的问题
    98  	// 数据库类型到 GO 类型再到用户希望的类型,是一个漫长的过程。
    99  	hasNext := true
   100  	for hasNext {
   101  		var data [][]any
   102  		data, err = scanner.ScanAll()
   103  		if err != nil {
   104  			return err
   105  		}
   106  		err = list.Append(rows.NewDataRows(data, cs, ct))
   107  		if err != nil {
   108  			return err
   109  		}
   110  		hasNext = scanner.NextResultSet()
   111  	}
   112  	return nil
   113  }
   114  
   115  func (t *Tx) mergeDBQueries(dbQs []Query) Query {
   116  	buffer := bytebufferpool.Get()
   117  	defer bytebufferpool.Put(buffer)
   118  	first := dbQs[0]
   119  	// 预估有多少查询参数,一个查询的参数个数 * 查询个数
   120  	args := make([]any, 0, len(first.Args)*len(dbQs))
   121  	for _, dbQ := range dbQs {
   122  		_, _ = buffer.WriteString(dbQ.SQL)
   123  		args = append(args, dbQ.Args...)
   124  	}
   125  	return Query{
   126  		SQL:        buffer.String(),
   127  		Args:       args,
   128  		DB:         first.DB,
   129  		Datasource: first.Datasource,
   130  	}
   131  }
   132  
   133  func (t *Tx) Commit() error {
   134  	return t.tx.Commit()
   135  }
   136  
   137  func (t *Tx) Rollback() error {
   138  	return t.tx.Rollback()
   139  }