github.com/ecodeclub/eorm@v0.0.2-0.20231001112437-dae71da914d0/sharding_insert.go (about)

     1  // Copyright 2021 ecodeclub
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package eorm
    16  
    17  import (
    18  	"context"
    19  	"database/sql"
    20  	"errors"
    21  	"reflect"
    22  	"sync"
    23  
    24  	"github.com/ecodeclub/ekit/mapx"
    25  
    26  	"github.com/ecodeclub/eorm/internal/errs"
    27  	"github.com/ecodeclub/eorm/internal/model"
    28  	"github.com/ecodeclub/eorm/internal/sharding"
    29  	"github.com/valyala/bytebufferpool"
    30  	"go.uber.org/multierr"
    31  )
    32  
    33  var _ sharding.Executor = &ShardingInserter[any]{}
    34  
    35  type ShardingInserter[T any] struct {
    36  	shardingInserterBuilder
    37  	values []*T
    38  	db     Session
    39  	lock   sync.RWMutex
    40  }
    41  
    42  func (si *ShardingInserter[T]) Build(ctx context.Context) ([]sharding.Query, error) {
    43  	defer bytebufferpool.Put(si.buffer)
    44  	var err error
    45  	if len(si.values) == 0 {
    46  		return nil, errors.New("插入0行")
    47  	}
    48  	si.meta, err = si.metaRegistry.Get(si.values[0])
    49  	if err != nil {
    50  		return nil, err
    51  	}
    52  	colMetaData, err := si.getColumns()
    53  	if err != nil {
    54  		return nil, err
    55  	}
    56  	skNames := si.meta.ShardingAlgorithm.ShardingKeys()
    57  	if err := si.checkColumns(colMetaData, skNames); err != nil {
    58  		return nil, err
    59  	}
    60  
    61  	// ds-db => 目标表
    62  	//dsDBMap, err := mapx.NewTreeMap[key, *mapx.TreeMap[key, []*T]](compareDSDB)
    63  	dsDBTabMap, err := mapx.NewMultiTreeMap[sharding.Dst, *T](sharding.CompareDSDBTab)
    64  	if err != nil {
    65  		return nil, err
    66  	}
    67  	for _, value := range si.values {
    68  		dst, err := si.findDst(ctx, value)
    69  		if err != nil {
    70  			return nil, err
    71  		}
    72  		// 一个value只能命中一个库表如果不满足就报错
    73  		if len(dst.Dsts) != 1 {
    74  			return nil, errs.ErrInsertFindingDst
    75  		}
    76  		err = dsDBTabMap.Put(dst.Dsts[0], value)
    77  		if err != nil {
    78  			return nil, err
    79  		}
    80  	}
    81  
    82  	// 针对每一个目标库,生成一个 insert 语句
    83  	//dsDBKeys := dsDBMap.Keys()
    84  	dsts := dsDBTabMap.Keys()
    85  	ansQuery := make([]sharding.Query, 0, len(dsts))
    86  	for _, dst := range dsts {
    87  		vals, _ := dsDBTabMap.Get(dst)
    88  		err = si.buildQuery(dst.DB, dst.Table, colMetaData, vals)
    89  		if err != nil {
    90  			return nil, err
    91  		}
    92  		ansQuery = append(ansQuery, sharding.Query{
    93  			SQL:        si.buffer.String(),
    94  			Args:       si.args,
    95  			DB:         dst.DB,
    96  			Datasource: dst.Name,
    97  		})
    98  		si.buffer.Reset()
    99  		si.args = []any{}
   100  	}
   101  	return ansQuery, nil
   102  }
   103  
   104  func (si *ShardingInserter[T]) buildQuery(db, table string, colMetas []*model.ColumnMeta, values []*T) error {
   105  	var err error
   106  	si.writeString("INSERT INTO ")
   107  	si.quote(db)
   108  	si.writeByte('.')
   109  	si.quote(table)
   110  	si.writeString("(")
   111  	err = si.buildColumns(colMetas)
   112  	if err != nil {
   113  		return err
   114  	}
   115  	si.writeString(")")
   116  	si.writeString(" VALUES")
   117  	for index, val := range values {
   118  		if index > 0 {
   119  			si.comma()
   120  		}
   121  		si.writeString("(")
   122  		refVal := si.valCreator.NewPrimitiveValue(val, si.meta)
   123  		for j, v := range colMetas {
   124  			fdVal, err := refVal.Field(v.FieldName)
   125  			if err != nil {
   126  				return err
   127  			}
   128  			si.parameter(fdVal.Interface())
   129  			if j != len(colMetas)-1 {
   130  				si.comma()
   131  			}
   132  		}
   133  		si.writeString(")")
   134  	}
   135  	si.end()
   136  	return nil
   137  }
   138  
   139  // checkColumns 判断sk是否存在于meta中,如果不存在会返回报错
   140  func (*ShardingInserter[T]) checkColumns(colMetas []*model.ColumnMeta, sks []string) error {
   141  	colMetasMap := make(map[string]struct{}, len(colMetas))
   142  	for _, colMeta := range colMetas {
   143  		colMetasMap[colMeta.FieldName] = struct{}{}
   144  	}
   145  	for _, sk := range sks {
   146  		if _, ok := colMetasMap[sk]; !ok {
   147  			return errs.ErrInsertShardingKeyNotFound
   148  		}
   149  	}
   150  	return nil
   151  }
   152  
   153  func (si *ShardingInserter[T]) findDst(ctx context.Context, val *T) (sharding.Response, error) {
   154  	sks := si.meta.ShardingAlgorithm.ShardingKeys()
   155  	skValues := make(map[string]any)
   156  	for _, sk := range sks {
   157  		refVal := reflect.ValueOf(val).Elem().FieldByName(sk).Interface()
   158  		skValues[sk] = refVal
   159  	}
   160  	return si.meta.ShardingAlgorithm.Sharding(ctx, sharding.Request{
   161  		Op:       opEQ,
   162  		SkValues: skValues,
   163  	})
   164  }
   165  
   166  func (si *ShardingInserter[T]) getColumns() ([]*model.ColumnMeta, error) {
   167  	cs := make([]*model.ColumnMeta, 0, len(si.columns))
   168  	if len(si.columns) != 0 {
   169  		for _, c := range si.columns {
   170  			v, isOk := si.meta.FieldMap[c]
   171  			if !isOk {
   172  				return cs, errs.NewInvalidFieldError(c)
   173  			}
   174  			cs = append(cs, v)
   175  		}
   176  	} else {
   177  		for _, val := range si.meta.Columns {
   178  			if si.ignorePK && val.IsPrimaryKey {
   179  				continue
   180  			}
   181  			cs = append(cs, val)
   182  		}
   183  	}
   184  	return cs, nil
   185  }
   186  
   187  func (si *ShardingInserter[T]) buildColumns(colMetas []*model.ColumnMeta) error {
   188  	for idx, colMeta := range colMetas {
   189  		si.quote(colMeta.ColumnName)
   190  		if idx != len(colMetas)-1 {
   191  			si.comma()
   192  		}
   193  	}
   194  	return nil
   195  }
   196  
   197  func (si *ShardingInserter[T]) Values(values []*T) *ShardingInserter[T] {
   198  	si.values = values
   199  	return si
   200  }
   201  
   202  func (si *ShardingInserter[T]) Columns(cols []string) *ShardingInserter[T] {
   203  	si.columns = cols
   204  	return si
   205  }
   206  
   207  func (si *ShardingInserter[T]) IgnorePK() *ShardingInserter[T] {
   208  	si.ignorePK = true
   209  	return si
   210  }
   211  
   212  func NewShardingInsert[T any](db Session) *ShardingInserter[T] {
   213  	b := shardingInserterBuilder{}
   214  	b.core = db.getCore()
   215  	b.buffer = bytebufferpool.Get()
   216  	b.columns = []string{}
   217  	return &ShardingInserter[T]{
   218  		db:                      db,
   219  		shardingInserterBuilder: b,
   220  	}
   221  }
   222  
   223  func (si *ShardingInserter[T]) Exec(ctx context.Context) sharding.Result {
   224  	qs, err := si.Build(ctx)
   225  	if err != nil {
   226  		return sharding.NewResult(nil, err)
   227  	}
   228  	errList := make([]error, len(qs))
   229  	resList := make([]sql.Result, len(qs))
   230  	var wg sync.WaitGroup
   231  	wg.Add(len(qs))
   232  	for idx, q := range qs {
   233  		go func(idx int, q Query) {
   234  			defer wg.Done()
   235  			res, er := si.db.execContext(ctx, q)
   236  			si.lock.Lock()
   237  			errList[idx] = er
   238  			resList[idx] = res
   239  			si.lock.Unlock()
   240  		}(idx, q)
   241  	}
   242  	wg.Wait()
   243  	shardingRes := sharding.NewResult(resList, multierr.Combine(errList...))
   244  	return shardingRes
   245  }