github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/sharding_insert.go (about) 1 // Copyright 2021 ecodeclub 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package eorm 16 17 import ( 18 "context" 19 "database/sql" 20 "errors" 21 "reflect" 22 "sync" 23 24 "github.com/ecodeclub/ekit/mapx" 25 26 "github.com/ecodeclub/eorm/internal/errs" 27 "github.com/ecodeclub/eorm/internal/model" 28 "github.com/ecodeclub/eorm/internal/sharding" 29 "github.com/valyala/bytebufferpool" 30 "go.uber.org/multierr" 31 ) 32 33 var _ sharding.Executor = &ShardingInserter[any]{} 34 35 type ShardingInserter[T any] struct { 36 shardingInserterBuilder 37 values []*T 38 db Session 39 lock sync.RWMutex 40 } 41 42 func (si *ShardingInserter[T]) Build(ctx context.Context) ([]sharding.Query, error) { 43 defer bytebufferpool.Put(si.buffer) 44 var err error 45 if len(si.values) == 0 { 46 return nil, errors.New("插入0行") 47 } 48 si.meta, err = si.metaRegistry.Get(si.values[0]) 49 if err != nil { 50 return nil, err 51 } 52 colMetaData, err := si.getColumns() 53 if err != nil { 54 return nil, err 55 } 56 skNames := si.meta.ShardingAlgorithm.ShardingKeys() 57 if err := si.checkColumns(colMetaData, skNames); err != nil { 58 return nil, err 59 } 60 61 // ds-db => 目标表 62 //dsDBMap, err := mapx.NewTreeMap[key, *mapx.TreeMap[key, []*T]](compareDSDB) 63 dsDBTabMap, err := mapx.NewMultiTreeMap[sharding.Dst, *T](sharding.CompareDSDBTab) 64 if err != nil { 65 return nil, err 66 } 67 for _, value := range si.values { 68 dst, err := si.findDst(ctx, value) 69 if err != nil { 70 return nil, err 71 } 72 // 一个value只能命中一个库表如果不满足就报错 73 if len(dst.Dsts) != 1 { 74 return nil, errs.ErrInsertFindingDst 75 } 76 err = dsDBTabMap.Put(dst.Dsts[0], value) 77 if err != nil { 78 return nil, err 79 } 80 } 81 82 // 针对每一个目标库,生成一个 insert 语句 83 //dsDBKeys := dsDBMap.Keys() 84 dsts := dsDBTabMap.Keys() 85 ansQuery := make([]sharding.Query, 0, len(dsts)) 86 for _, dst := range dsts { 87 vals, _ := dsDBTabMap.Get(dst) 88 err = si.buildQuery(dst.DB, dst.Table, colMetaData, vals) 89 if err != nil { 90 return nil, err 91 } 92 ansQuery = append(ansQuery, sharding.Query{ 93 SQL: si.buffer.String(), 94 Args: si.args, 95 DB: dst.DB, 96 Datasource: dst.Name, 97 }) 98 si.buffer.Reset() 99 si.args = []any{} 100 } 101 return ansQuery, nil 102 } 103 104 func (si *ShardingInserter[T]) buildQuery(db, table string, colMetas []*model.ColumnMeta, values []*T) error { 105 var err error 106 si.writeString("INSERT INTO ") 107 si.quote(db) 108 si.writeByte('.') 109 si.quote(table) 110 si.writeString("(") 111 err = si.buildColumns(colMetas) 112 if err != nil { 113 return err 114 } 115 si.writeString(")") 116 si.writeString(" VALUES") 117 for index, val := range values { 118 if index > 0 { 119 si.comma() 120 } 121 si.writeString("(") 122 refVal := si.valCreator.NewPrimitiveValue(val, si.meta) 123 for j, v := range colMetas { 124 fdVal, err := refVal.Field(v.FieldName) 125 if err != nil { 126 return err 127 } 128 si.parameter(fdVal.Interface()) 129 if j != len(colMetas)-1 { 130 si.comma() 131 } 132 } 133 si.writeString(")") 134 } 135 si.end() 136 return nil 137 } 138 139 // checkColumns 判断sk是否存在于meta中,如果不存在会返回报错 140 func (*ShardingInserter[T]) checkColumns(colMetas []*model.ColumnMeta, sks []string) error { 141 colMetasMap := make(map[string]struct{}, len(colMetas)) 142 for _, colMeta := range colMetas { 143 colMetasMap[colMeta.FieldName] = struct{}{} 144 } 145 for _, sk := range sks { 146 if _, ok := colMetasMap[sk]; !ok { 147 return errs.ErrInsertShardingKeyNotFound 148 } 149 } 150 return nil 151 } 152 153 func (si *ShardingInserter[T]) findDst(ctx context.Context, val *T) (sharding.Response, error) { 154 sks := si.meta.ShardingAlgorithm.ShardingKeys() 155 skValues := make(map[string]any) 156 for _, sk := range sks { 157 refVal := reflect.ValueOf(val).Elem().FieldByName(sk).Interface() 158 skValues[sk] = refVal 159 } 160 return si.meta.ShardingAlgorithm.Sharding(ctx, sharding.Request{ 161 Op: opEQ, 162 SkValues: skValues, 163 }) 164 } 165 166 func (si *ShardingInserter[T]) getColumns() ([]*model.ColumnMeta, error) { 167 cs := make([]*model.ColumnMeta, 0, len(si.columns)) 168 if len(si.columns) != 0 { 169 for _, c := range si.columns { 170 v, isOk := si.meta.FieldMap[c] 171 if !isOk { 172 return cs, errs.NewInvalidFieldError(c) 173 } 174 cs = append(cs, v) 175 } 176 } else { 177 for _, val := range si.meta.Columns { 178 if si.ignorePK && val.IsPrimaryKey { 179 continue 180 } 181 cs = append(cs, val) 182 } 183 } 184 return cs, nil 185 } 186 187 func (si *ShardingInserter[T]) buildColumns(colMetas []*model.ColumnMeta) error { 188 for idx, colMeta := range colMetas { 189 si.quote(colMeta.ColumnName) 190 if idx != len(colMetas)-1 { 191 si.comma() 192 } 193 } 194 return nil 195 } 196 197 func (si *ShardingInserter[T]) Values(values []*T) *ShardingInserter[T] { 198 si.values = values 199 return si 200 } 201 202 func (si *ShardingInserter[T]) Columns(cols []string) *ShardingInserter[T] { 203 si.columns = cols 204 return si 205 } 206 207 func (si *ShardingInserter[T]) IgnorePK() *ShardingInserter[T] { 208 si.ignorePK = true 209 return si 210 } 211 212 func NewShardingInsert[T any](db Session) *ShardingInserter[T] { 213 b := shardingInserterBuilder{} 214 b.core = db.getCore() 215 b.buffer = bytebufferpool.Get() 216 b.columns = []string{} 217 return &ShardingInserter[T]{ 218 db: db, 219 shardingInserterBuilder: b, 220 } 221 } 222 223 func (si *ShardingInserter[T]) Exec(ctx context.Context) sharding.Result { 224 qs, err := si.Build(ctx) 225 if err != nil { 226 return sharding.NewResult(nil, err) 227 } 228 errList := make([]error, len(qs)) 229 resList := make([]sql.Result, len(qs)) 230 var wg sync.WaitGroup 231 wg.Add(len(qs)) 232 for idx, q := range qs { 233 go func(idx int, q Query) { 234 defer wg.Done() 235 res, er := si.db.execContext(ctx, q) 236 si.lock.Lock() 237 errList[idx] = er 238 resList[idx] = res 239 si.lock.Unlock() 240 }(idx, q) 241 } 242 wg.Wait() 243 shardingRes := sharding.NewResult(resList, multierr.Combine(errList...)) 244 return shardingRes 245 }