github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/sharding_update.go (about)

     1  // Copyright 2021 ecodeclub
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package eorm
    16  
    17  import (
    18  	"context"
    19  	"database/sql"
    20  	"sync"
    21  
    22  	"go.uber.org/multierr"
    23  
    24  	"github.com/ecodeclub/eorm/internal/errs"
    25  	"github.com/ecodeclub/eorm/internal/sharding"
    26  	"github.com/valyala/bytebufferpool"
    27  )
    28  
    29  var _ sharding.Executor = &ShardingUpdater[any]{}
    30  
    31  type ShardingUpdater[T any] struct {
    32  	table *T
    33  	lock  sync.Mutex
    34  	db    Session
    35  	shardingUpdaterBuilder
    36  }
    37  
    38  // NewShardingUpdater 开始构建一个 Sharding UPDATE 查询
    39  func NewShardingUpdater[T any](sess Session) *ShardingUpdater[T] {
    40  	b := shardingUpdaterBuilder{}
    41  	b.core = sess.getCore()
    42  	b.buffer = bytebufferpool.Get()
    43  	return &ShardingUpdater[T]{
    44  		shardingUpdaterBuilder: b,
    45  		db:                     sess,
    46  	}
    47  }
    48  
    49  func (s *ShardingUpdater[T]) Update(val *T) *ShardingUpdater[T] {
    50  	s.table = val
    51  	return s
    52  }
    53  
    54  func (s *ShardingUpdater[T]) Set(assigns ...Assignable) *ShardingUpdater[T] {
    55  	s.assigns = assigns
    56  	return s
    57  }
    58  
    59  func (s *ShardingUpdater[T]) Where(predicates ...Predicate) *ShardingUpdater[T] {
    60  	s.where = predicates
    61  	return s
    62  }
    63  
    64  // Build returns UPDATE []sharding.Query
    65  func (s *ShardingUpdater[T]) Build(ctx context.Context) ([]sharding.Query, error) {
    66  	if s.table == nil {
    67  		s.table = new(T)
    68  	}
    69  	var err error
    70  	if s.meta == nil {
    71  		s.meta, err = s.metaRegistry.Get(s.table)
    72  		if err != nil {
    73  			return nil, err
    74  		}
    75  	}
    76  	shardingRes, err := s.findDst(ctx, s.where...)
    77  	if err != nil {
    78  		return nil, err
    79  	}
    80  
    81  	res := make([]sharding.Query, 0, len(shardingRes.Dsts))
    82  	defer bytebufferpool.Put(s.buffer)
    83  	for _, dst := range shardingRes.Dsts {
    84  		q, err := s.buildQuery(dst.DB, dst.Table, dst.Name)
    85  		if err != nil {
    86  			return nil, err
    87  		}
    88  		res = append(res, q)
    89  		s.args = nil
    90  		s.buffer.Reset()
    91  	}
    92  	return res, nil
    93  }
    94  
    95  func (s *ShardingUpdater[T]) buildQuery(db, tbl, ds string) (sharding.Query, error) {
    96  	var err error
    97  
    98  	s.val = s.valCreator.NewPrimitiveValue(s.table, s.meta)
    99  
   100  	s.writeString("UPDATE ")
   101  	s.quote(db)
   102  	s.writeByte('.')
   103  	s.quote(tbl)
   104  	s.writeString(" SET ")
   105  	if len(s.assigns) == 0 {
   106  		err = s.buildDefaultColumns()
   107  	} else {
   108  		err = s.buildAssigns()
   109  	}
   110  	if err != nil {
   111  		return sharding.EmptyQuery, err
   112  	}
   113  
   114  	if len(s.where) > 0 {
   115  		s.writeString(" WHERE ")
   116  		err = s.buildPredicates(s.where)
   117  		if err != nil {
   118  			return sharding.EmptyQuery, err
   119  		}
   120  	}
   121  	s.end()
   122  
   123  	return sharding.Query{SQL: s.buffer.String(), Args: s.args, Datasource: ds, DB: db}, nil
   124  }
   125  
   126  func (s *ShardingUpdater[T]) buildAssigns() error {
   127  	has := false
   128  	shardingKey := s.meta.ShardingAlgorithm.ShardingKeys()[0]
   129  	for _, assign := range s.assigns {
   130  		if has {
   131  			s.comma()
   132  		}
   133  		switch a := assign.(type) {
   134  		case Column:
   135  			if a.name == shardingKey {
   136  				return errs.NewErrUpdateShardingKeyUnsupported(a.name)
   137  			}
   138  			c, ok := s.meta.FieldMap[a.name]
   139  			if !ok {
   140  				return errs.NewInvalidFieldError(a.name)
   141  			}
   142  			refVal, err := s.val.Field(a.name)
   143  			if err != nil {
   144  				return err
   145  			}
   146  			s.quote(c.ColumnName)
   147  			_ = s.buffer.WriteByte('=')
   148  			s.parameter(refVal.Interface())
   149  			has = true
   150  		case columns:
   151  			for _, name := range a.cs {
   152  				if name == shardingKey {
   153  					return errs.NewErrUpdateShardingKeyUnsupported(name)
   154  				}
   155  				c, ok := s.meta.FieldMap[name]
   156  				if !ok {
   157  					return errs.NewInvalidFieldError(name)
   158  				}
   159  				refVal, err := s.val.Field(name)
   160  				if err != nil {
   161  					return err
   162  				}
   163  				if has {
   164  					s.comma()
   165  				}
   166  				s.quote(c.ColumnName)
   167  				_ = s.buffer.WriteByte('=')
   168  				s.parameter(refVal.Interface())
   169  				has = true
   170  			}
   171  		case Assignment:
   172  			if err := s.buildExpr(binaryExpr(a)); err != nil {
   173  				return err
   174  			}
   175  			has = true
   176  		default:
   177  			return errs.ErrUnsupportedAssignment
   178  		}
   179  	}
   180  	if !has {
   181  		return errs.NewValueNotSetError()
   182  	}
   183  	return nil
   184  }
   185  
   186  func (s *ShardingUpdater[T]) buildDefaultColumns() error {
   187  	has := false
   188  	shardingKey := s.meta.ShardingAlgorithm.ShardingKeys()[0]
   189  	for _, c := range s.meta.Columns {
   190  		fieldName := c.FieldName
   191  		if fieldName == shardingKey {
   192  			continue
   193  		}
   194  		refVal, _ := s.val.Field(fieldName)
   195  		if s.ignoreZeroVal && isZeroValue(refVal) {
   196  			continue
   197  		}
   198  		if s.ignoreNilVal && isNilValue(refVal) {
   199  			continue
   200  		}
   201  		if has {
   202  			_ = s.buffer.WriteByte(',')
   203  		}
   204  		s.quote(c.ColumnName)
   205  		_ = s.buffer.WriteByte('=')
   206  		s.parameter(refVal.Interface())
   207  		has = true
   208  	}
   209  	if !has {
   210  		return errs.NewValueNotSetError()
   211  	}
   212  	return nil
   213  }
   214  
   215  func (s *ShardingUpdater[T]) SkipNilValue() *ShardingUpdater[T] {
   216  	s.ignoreNilVal = true
   217  	return s
   218  }
   219  
   220  func (s *ShardingUpdater[T]) SkipZeroValue() *ShardingUpdater[T] {
   221  	s.ignoreZeroVal = true
   222  	return s
   223  }
   224  
   225  func (s *ShardingUpdater[T]) Exec(ctx context.Context) sharding.Result {
   226  	qs, err := s.Build(ctx)
   227  	if err != nil {
   228  		return sharding.NewResult(nil, err)
   229  	}
   230  	errList := make([]error, len(qs))
   231  	resList := make([]sql.Result, len(qs))
   232  	var wg sync.WaitGroup
   233  	wg.Add(len(qs))
   234  	for idx, q := range qs {
   235  		go func(idx int, q Query) {
   236  			defer wg.Done()
   237  			res, err := s.db.execContext(ctx, q)
   238  			s.lock.Lock()
   239  			errList[idx] = err
   240  			resList[idx] = res
   241  			s.lock.Unlock()
   242  		}(idx, q)
   243  	}
   244  	wg.Wait()
   245  	shardingRes := sharding.NewResult(resList, multierr.Combine(errList...))
   246  	return shardingRes
   247  }