github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/transaction.go (about) 1 // Copyright 2021 ecodeclub 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package eorm 16 17 import ( 18 "context" 19 "database/sql" 20 21 "github.com/ecodeclub/ekit/list" 22 "github.com/ecodeclub/ekit/mapx" 23 "github.com/ecodeclub/ekit/sqlx" 24 "github.com/ecodeclub/eorm/internal/rows" 25 "github.com/valyala/bytebufferpool" 26 "golang.org/x/sync/errgroup" 27 28 "github.com/ecodeclub/eorm/internal/datasource" 29 ) 30 31 type Tx struct { 32 baseSession 33 tx datasource.Tx 34 } 35 36 func (t *Tx) queryMulti(ctx context.Context, qs []Query) (list.List[rows.Rows], error) { 37 // 事务在查询的时候,需要将同一个 DB 上的语句合并在一起 38 // 参考 https://github.com/ecodeclub/eorm/discussions/213 39 mp := mapx.NewMultiBuiltinMap[string, Query](len(qs)) 40 for _, q := range qs { 41 if err := mp.Put(q.DB+"_"+q.Datasource, q); err != nil { 42 return nil, err 43 } 44 } 45 keys := mp.Keys() 46 rowsList := &list.ConcurrentList[rows.Rows]{ 47 List: list.NewArrayList[rows.Rows](len(keys)), 48 } 49 var eg errgroup.Group 50 for _, key := range keys { 51 dbQs, _ := mp.Get(key) 52 eg.Go(func() error { 53 return t.execDBQueries(ctx, dbQs, rowsList) 54 }) 55 } 56 return rowsList, eg.Wait() 57 } 58 59 // execDBQueries 执行某个 DB 上的全部查询。 60 // 执行结果会被加入进去 rowsList 里面。虽然这种修改传入参数的做法不是很好,但是作为一个内部方法还是可以接受的。 61 func (t *Tx) execDBQueries(ctx context.Context, dbQs []Query, rowsList *list.ConcurrentList[rows.Rows]) error { 62 qsCnt := len(dbQs) 63 // 考虑到大部分都只有一个查询,我们做一个快路径的优化。 64 if qsCnt == 1 { 65 rs, err := t.tx.Query(ctx, dbQs[0]) 66 if err != nil { 67 return err 68 } 69 return rowsList.Append(rs) 70 } 71 // 慢路径,也就是必须要把同一个库的查询合并在一起 72 q := t.mergeDBQueries(dbQs) 73 rs, err := t.tx.Query(ctx, q) 74 if err != nil { 75 return err 76 } 77 // 查询之后,事务必须再次按照结果集分割开。 78 // 这样是为了让结果集的数量和查询数量保持一致。 79 return t.splitTxResultSet(rowsList, rs) 80 } 81 82 func (t *Tx) splitTxResultSet(list list.List[rows.Rows], rs *sql.Rows) error { 83 cs, err := rs.Columns() 84 if err != nil { 85 return err 86 } 87 ct, err := rs.ColumnTypes() 88 if err != nil { 89 return err 90 } 91 scanner, err := sqlx.NewSQLRowsScanner(rs) 92 if err != nil { 93 return err 94 } 95 // 虽然这里我们可以尝试不读取最后一个 ResultSet 96 // 但是这个优化目前来说不准备做, 97 // 防止用户出现因为类型转换遇到一些潜在的问题 98 // 数据库类型到 GO 类型再到用户希望的类型,是一个漫长的过程。 99 hasNext := true 100 for hasNext { 101 var data [][]any 102 data, err = scanner.ScanAll() 103 if err != nil { 104 return err 105 } 106 err = list.Append(rows.NewDataRows(data, cs, ct)) 107 if err != nil { 108 return err 109 } 110 hasNext = scanner.NextResultSet() 111 } 112 return nil 113 } 114 115 func (t *Tx) mergeDBQueries(dbQs []Query) Query { 116 buffer := bytebufferpool.Get() 117 defer bytebufferpool.Put(buffer) 118 first := dbQs[0] 119 // 预估有多少查询参数,一个查询的参数个数 * 查询个数 120 args := make([]any, 0, len(first.Args)*len(dbQs)) 121 for _, dbQ := range dbQs { 122 _, _ = buffer.WriteString(dbQ.SQL) 123 args = append(args, dbQ.Args...) 124 } 125 return Query{ 126 SQL: buffer.String(), 127 Args: args, 128 DB: first.DB, 129 Datasource: first.Datasource, 130 } 131 } 132 133 func (t *Tx) Commit() error { 134 return t.tx.Commit() 135 } 136 137 func (t *Tx) Rollback() error { 138 return t.tx.Rollback() 139 }