github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/sharding_update.go (about) 1 // Copyright 2021 ecodeclub 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package eorm 16 17 import ( 18 "context" 19 "database/sql" 20 "sync" 21 22 "go.uber.org/multierr" 23 24 "github.com/ecodeclub/eorm/internal/errs" 25 "github.com/ecodeclub/eorm/internal/sharding" 26 "github.com/valyala/bytebufferpool" 27 ) 28 29 var _ sharding.Executor = &ShardingUpdater[any]{} 30 31 type ShardingUpdater[T any] struct { 32 table *T 33 lock sync.Mutex 34 db Session 35 shardingUpdaterBuilder 36 } 37 38 // NewShardingUpdater 开始构建一个 Sharding UPDATE 查询 39 func NewShardingUpdater[T any](sess Session) *ShardingUpdater[T] { 40 b := shardingUpdaterBuilder{} 41 b.core = sess.getCore() 42 b.buffer = bytebufferpool.Get() 43 return &ShardingUpdater[T]{ 44 shardingUpdaterBuilder: b, 45 db: sess, 46 } 47 } 48 49 func (s *ShardingUpdater[T]) Update(val *T) *ShardingUpdater[T] { 50 s.table = val 51 return s 52 } 53 54 func (s *ShardingUpdater[T]) Set(assigns ...Assignable) *ShardingUpdater[T] { 55 s.assigns = assigns 56 return s 57 } 58 59 func (s *ShardingUpdater[T]) Where(predicates ...Predicate) *ShardingUpdater[T] { 60 s.where = predicates 61 return s 62 } 63 64 // Build returns UPDATE []sharding.Query 65 func (s *ShardingUpdater[T]) Build(ctx context.Context) ([]sharding.Query, error) { 66 if s.table == nil { 67 s.table = new(T) 68 } 69 var err error 70 if s.meta == nil { 71 s.meta, err = s.metaRegistry.Get(s.table) 72 if err != nil { 73 return nil, err 74 } 75 } 76 shardingRes, err := s.findDst(ctx, s.where...) 77 if err != nil { 78 return nil, err 79 } 80 81 res := make([]sharding.Query, 0, len(shardingRes.Dsts)) 82 defer bytebufferpool.Put(s.buffer) 83 for _, dst := range shardingRes.Dsts { 84 q, err := s.buildQuery(dst.DB, dst.Table, dst.Name) 85 if err != nil { 86 return nil, err 87 } 88 res = append(res, q) 89 s.args = nil 90 s.buffer.Reset() 91 } 92 return res, nil 93 } 94 95 func (s *ShardingUpdater[T]) buildQuery(db, tbl, ds string) (sharding.Query, error) { 96 var err error 97 98 s.val = s.valCreator.NewPrimitiveValue(s.table, s.meta) 99 100 s.writeString("UPDATE ") 101 s.quote(db) 102 s.writeByte('.') 103 s.quote(tbl) 104 s.writeString(" SET ") 105 if len(s.assigns) == 0 { 106 err = s.buildDefaultColumns() 107 } else { 108 err = s.buildAssigns() 109 } 110 if err != nil { 111 return sharding.EmptyQuery, err 112 } 113 114 if len(s.where) > 0 { 115 s.writeString(" WHERE ") 116 err = s.buildPredicates(s.where) 117 if err != nil { 118 return sharding.EmptyQuery, err 119 } 120 } 121 s.end() 122 123 return sharding.Query{SQL: s.buffer.String(), Args: s.args, Datasource: ds, DB: db}, nil 124 } 125 126 func (s *ShardingUpdater[T]) buildAssigns() error { 127 has := false 128 shardingKey := s.meta.ShardingAlgorithm.ShardingKeys()[0] 129 for _, assign := range s.assigns { 130 if has { 131 s.comma() 132 } 133 switch a := assign.(type) { 134 case Column: 135 if a.name == shardingKey { 136 return errs.NewErrUpdateShardingKeyUnsupported(a.name) 137 } 138 c, ok := s.meta.FieldMap[a.name] 139 if !ok { 140 return errs.NewInvalidFieldError(a.name) 141 } 142 refVal, err := s.val.Field(a.name) 143 if err != nil { 144 return err 145 } 146 s.quote(c.ColumnName) 147 _ = s.buffer.WriteByte('=') 148 s.parameter(refVal.Interface()) 149 has = true 150 case columns: 151 for _, name := range a.cs { 152 if name == shardingKey { 153 return errs.NewErrUpdateShardingKeyUnsupported(name) 154 } 155 c, ok := s.meta.FieldMap[name] 156 if !ok { 157 return errs.NewInvalidFieldError(name) 158 } 159 refVal, err := s.val.Field(name) 160 if err != nil { 161 return err 162 } 163 if has { 164 s.comma() 165 } 166 s.quote(c.ColumnName) 167 _ = s.buffer.WriteByte('=') 168 s.parameter(refVal.Interface()) 169 has = true 170 } 171 case Assignment: 172 if err := s.buildExpr(binaryExpr(a)); err != nil { 173 return err 174 } 175 has = true 176 default: 177 return errs.ErrUnsupportedAssignment 178 } 179 } 180 if !has { 181 return errs.NewValueNotSetError() 182 } 183 return nil 184 } 185 186 func (s *ShardingUpdater[T]) buildDefaultColumns() error { 187 has := false 188 shardingKey := s.meta.ShardingAlgorithm.ShardingKeys()[0] 189 for _, c := range s.meta.Columns { 190 fieldName := c.FieldName 191 if fieldName == shardingKey { 192 continue 193 } 194 refVal, _ := s.val.Field(fieldName) 195 if s.ignoreZeroVal && isZeroValue(refVal) { 196 continue 197 } 198 if s.ignoreNilVal && isNilValue(refVal) { 199 continue 200 } 201 if has { 202 _ = s.buffer.WriteByte(',') 203 } 204 s.quote(c.ColumnName) 205 _ = s.buffer.WriteByte('=') 206 s.parameter(refVal.Interface()) 207 has = true 208 } 209 if !has { 210 return errs.NewValueNotSetError() 211 } 212 return nil 213 } 214 215 func (s *ShardingUpdater[T]) SkipNilValue() *ShardingUpdater[T] { 216 s.ignoreNilVal = true 217 return s 218 } 219 220 func (s *ShardingUpdater[T]) SkipZeroValue() *ShardingUpdater[T] { 221 s.ignoreZeroVal = true 222 return s 223 } 224 225 func (s *ShardingUpdater[T]) Exec(ctx context.Context) sharding.Result { 226 qs, err := s.Build(ctx) 227 if err != nil { 228 return sharding.NewResult(nil, err) 229 } 230 errList := make([]error, len(qs)) 231 resList := make([]sql.Result, len(qs)) 232 var wg sync.WaitGroup 233 wg.Add(len(qs)) 234 for idx, q := range qs { 235 go func(idx int, q Query) { 236 defer wg.Done() 237 res, err := s.db.execContext(ctx, q) 238 s.lock.Lock() 239 errList[idx] = err 240 resList[idx] = res 241 s.lock.Unlock() 242 }(idx, q) 243 } 244 wg.Wait() 245 shardingRes := sharding.NewResult(resList, multierr.Combine(errList...)) 246 return shardingRes 247 }