github.com/matrixorigin/matrixone@v1.2.0/pkg/vm/engine/memoryengine/shard_hash.go (about)

     1  // Copyright 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package memoryengine
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"hash/fnv"
    21  	"sort"
    22  	"unsafe"
    23  
    24  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    25  	"github.com/matrixorigin/matrixone/pkg/common/mpool"
    26  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    27  	"github.com/matrixorigin/matrixone/pkg/container/types"
    28  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    29  	"github.com/matrixorigin/matrixone/pkg/pb/metadata"
    30  	"github.com/matrixorigin/matrixone/pkg/vm/engine"
    31  )
    32  
// HashShard is a shard policy that picks the target shard of a row by hashing
// its primary key value(s) with FNV-32.
type HashShard struct {
	// mp is the memory pool handed to vector appends when Vector builds the
	// per-shard vectors.
	mp *mpool.MPool
}
    36  
    37  func NewHashShard(mp *mpool.MPool) *HashShard {
    38  	return &HashShard{
    39  		mp: mp,
    40  	}
    41  }
    42  
    43  func (*HashShard) Batch(
    44  	ctx context.Context,
    45  	tableID ID,
    46  	getDefs getDefsFunc,
    47  	bat *batch.Batch,
    48  	nodes []metadata.TNService,
    49  ) (
    50  	sharded []*ShardedBatch,
    51  	err error,
    52  ) {
    53  
    54  	// get defs
    55  	defs, err := getDefs(ctx)
    56  	if err != nil {
    57  		return nil, err
    58  	}
    59  
    60  	// get shard key
    61  	var primaryAttrs []engine.Attribute
    62  	for _, def := range defs {
    63  		attr, ok := def.(*engine.AttributeDef)
    64  		if !ok {
    65  			continue
    66  		}
    67  		if attr.Attr.Primary {
    68  			primaryAttrs = append(primaryAttrs, attr.Attr)
    69  		}
    70  	}
    71  	sort.Slice(primaryAttrs, func(i, j int) bool {
    72  		return primaryAttrs[i].Name < primaryAttrs[j].Name
    73  	})
    74  	if len(primaryAttrs) == 0 {
    75  		// no shard key
    76  		return nil, nil
    77  	}
    78  	type keyInfo struct {
    79  		Attr  engine.Attribute
    80  		Index int
    81  	}
    82  	var infos []keyInfo
    83  	for _, attr := range primaryAttrs {
    84  		for i, name := range bat.Attrs {
    85  			if name == attr.Name {
    86  				infos = append(infos, keyInfo{
    87  					Attr:  attr,
    88  					Index: i,
    89  				})
    90  			}
    91  		}
    92  	}
    93  
    94  	// shards
    95  	var shards []*Shard
    96  	for _, store := range nodes {
    97  		for _, info := range store.Shards {
    98  			shards = append(shards, &Shard{
    99  				TNShardRecord: metadata.TNShardRecord{
   100  					ShardID: info.ShardID,
   101  				},
   102  				ReplicaID: info.ReplicaID,
   103  				Address:   store.TxnServiceAddress,
   104  			})
   105  		}
   106  	}
   107  	sort.Slice(shards, func(i, j int) bool {
   108  		return shards[i].ShardID < shards[j].ShardID
   109  	})
   110  
   111  	type batValue struct {
   112  		bat   *batch.Batch
   113  		empty bool
   114  	}
   115  	m := make(map[*Shard]batValue)
   116  
   117  	for _, shard := range shards {
   118  		batchCopy := *bat
   119  		m[shard] = batValue{&batchCopy, true}
   120  	}
   121  
   122  	// shard batch
   123  	for i := 0; i < bat.RowCount(); i++ {
   124  		hasher := fnv.New32()
   125  		for _, info := range infos {
   126  			vec := bat.Vecs[info.Index]
   127  			bs, err := getBytesFromPrimaryVectorForHash(ctx, vec, i, info.Attr.Type)
   128  			if err != nil {
   129  				return nil, err
   130  			}
   131  			_, err = hasher.Write(bs)
   132  			if err != nil {
   133  				panic(err)
   134  			}
   135  		}
   136  		n := int(hasher.Sum32())
   137  		shard := shards[n%len(shards)]
   138  		m[shard] = batValue{m[shard].bat, false}
   139  	}
   140  
   141  	for shard, value := range m {
   142  		if value.empty {
   143  			continue
   144  		}
   145  		sharded = append(sharded, &ShardedBatch{
   146  			Shard: *shard,
   147  			Batch: value.bat,
   148  		})
   149  	}
   150  
   151  	return
   152  }
   153  
   154  func (h *HashShard) Vector(
   155  	ctx context.Context,
   156  	tableID ID,
   157  	getDefs getDefsFunc,
   158  	colName string,
   159  	vec *vector.Vector,
   160  	nodes []metadata.TNService,
   161  ) (
   162  	sharded []*ShardedVector,
   163  	err error,
   164  ) {
   165  
   166  	//TODO use vector nulls mask
   167  
   168  	// get defs
   169  	defs, err := getDefs(ctx)
   170  	if err != nil {
   171  		return nil, err
   172  	}
   173  
   174  	// get shard key
   175  	var shardAttr *engine.Attribute
   176  	for _, def := range defs {
   177  		attr, ok := def.(*engine.AttributeDef)
   178  		if !ok {
   179  			continue
   180  		}
   181  		if attr.Attr.Primary {
   182  			if attr.Attr.Name == colName {
   183  				shardAttr = &attr.Attr
   184  				break
   185  			}
   186  		}
   187  	}
   188  	if shardAttr == nil {
   189  		// no shard key
   190  		return nil, nil
   191  	}
   192  
   193  	// shards
   194  	var shards []*Shard
   195  	for _, store := range nodes {
   196  		for _, info := range store.Shards {
   197  			shards = append(shards, &Shard{
   198  				TNShardRecord: metadata.TNShardRecord{
   199  					ShardID: info.ShardID,
   200  				},
   201  				ReplicaID: info.ReplicaID,
   202  				Address:   store.TxnServiceAddress,
   203  			})
   204  		}
   205  	}
   206  	sort.Slice(shards, func(i, j int) bool {
   207  		return shards[i].ShardID < shards[j].ShardID
   208  	})
   209  	m := make(map[*Shard]*vector.Vector)
   210  
   211  	// shard vector
   212  	for i := 0; i < vec.Length(); i++ {
   213  		hasher := fnv.New32()
   214  		bs, err := getBytesFromPrimaryVectorForHash(ctx, vec, i, shardAttr.Type)
   215  		if err != nil {
   216  			return nil, err
   217  		}
   218  		_, err = hasher.Write(bs)
   219  		if err != nil {
   220  			panic(err)
   221  		}
   222  		n := int(hasher.Sum32())
   223  		shard := shards[n%len(shards)]
   224  		shardVec, ok := m[shard]
   225  		if !ok {
   226  			shardVec = vector.NewVec(shardAttr.Type)
   227  			m[shard] = shardVec
   228  		}
   229  		v := getNullableValueFromVector(vec, i)
   230  		appendNullableValueToVector(shardVec, v, h.mp)
   231  	}
   232  
   233  	for shard, vec := range m {
   234  		if vec.Length() == 0 {
   235  			continue
   236  		}
   237  		sharded = append(sharded, &ShardedVector{
   238  			Shard:  *shard,
   239  			Vector: vec,
   240  		})
   241  	}
   242  
   243  	return
   244  }
   245  
   246  var _ ShardPolicy = new(HashShard)
   247  
   248  func getBytesFromPrimaryVectorForHash(
   249  	ctx context.Context,
   250  	vec *vector.Vector,
   251  	i int,
   252  	typ types.Type) ([]byte, error) {
   253  	if vec.IsConst() {
   254  		panic("primary value vector should not be const")
   255  	}
   256  	if vec.GetNulls().Any() {
   257  		//TODO mimic to pass BVT
   258  		return nil, moerr.NewDuplicate(ctx)
   259  		//panic("primary value vector should not contain nulls")
   260  	}
   261  	if vec.GetType().IsFixedLen() {
   262  		// WTF is this?   Fix later when vector is refactored.
   263  		// is slice
   264  		size := vec.GetType().TypeSize()
   265  		l := vec.Length() * size
   266  		data := unsafe.Slice(vector.GetPtrAt[byte](vec, 0), l)
   267  		end := (i + 1) * size
   268  		if end > len(data) {
   269  			//TODO mimic to pass BVT
   270  			return nil, moerr.NewDuplicate(ctx)
   271  			//return nil, moerr.NewInvalidInput("vector size not match")
   272  		}
   273  		return data[i*size : (i+1)*size], nil
   274  	} else if vec.GetType().IsVarlen() {
   275  		slice := vector.MustBytesCol(vec)
   276  		if i >= len(slice) {
   277  			return []byte{}, nil
   278  		}
   279  		return slice[i], nil
   280  	}
   281  	panic(fmt.Sprintf("unknown type: %v", typ))
   282  }
   283  
// Nullable pairs a value read from a vector with a flag telling whether that
// value is NULL.
type Nullable struct {
	// IsNull reports whether the value is NULL.
	IsNull bool
	// Value is the Go value of the element; its dynamic type depends on the
	// column type it was read from.
	Value any
}
   288  
   289  func getNullableValueFromVector(vec *vector.Vector, i int) (value Nullable) {
   290  	if vec.IsConst() {
   291  		i = 0
   292  	}
   293  	switch vec.GetType().Oid {
   294  
   295  	case types.T_bool:
   296  		if vec.IsConstNull() {
   297  			value = Nullable{
   298  				IsNull: true,
   299  				Value:  false,
   300  			}
   301  			return
   302  		}
   303  		value = Nullable{
   304  			IsNull: vec.GetNulls().Contains(uint64(i)),
   305  			Value:  vector.MustFixedCol[bool](vec)[i],
   306  		}
   307  		return
   308  
   309  	case types.T_bit:
   310  		if vec.IsConstNull() {
   311  			value = Nullable{
   312  				IsNull: true,
   313  				Value:  uint64(0),
   314  			}
   315  			return
   316  		}
   317  		value = Nullable{
   318  			IsNull: vec.GetNulls().Contains(uint64(i)),
   319  			Value:  vector.MustFixedCol[uint64](vec)[i],
   320  		}
   321  		return
   322  
   323  	case types.T_int8:
   324  		if vec.IsConstNull() {
   325  			value = Nullable{
   326  				IsNull: true,
   327  				Value:  int8(0),
   328  			}
   329  			return
   330  		}
   331  		value = Nullable{
   332  			IsNull: vec.GetNulls().Contains(uint64(i)),
   333  			Value:  vector.MustFixedCol[int8](vec)[i],
   334  		}
   335  		return
   336  
   337  	case types.T_int16:
   338  		if vec.IsConstNull() {
   339  			value = Nullable{
   340  				IsNull: true,
   341  				Value:  int16(0),
   342  			}
   343  			return
   344  		}
   345  		value = Nullable{
   346  			IsNull: vec.GetNulls().Contains(uint64(i)),
   347  			Value:  vector.MustFixedCol[int16](vec)[i],
   348  		}
   349  		return
   350  
   351  	case types.T_int32:
   352  		if vec.IsConstNull() {
   353  			value = Nullable{
   354  				IsNull: true,
   355  				Value:  int32(0),
   356  			}
   357  			return
   358  		}
   359  		value = Nullable{
   360  			IsNull: vec.GetNulls().Contains(uint64(i)),
   361  			Value:  vector.MustFixedCol[int32](vec)[i],
   362  		}
   363  		return
   364  
   365  	case types.T_int64:
   366  		if vec.IsConstNull() {
   367  			value = Nullable{
   368  				IsNull: true,
   369  				Value:  int64(0),
   370  			}
   371  			return
   372  		}
   373  		value = Nullable{
   374  			IsNull: vec.GetNulls().Contains(uint64(i)),
   375  			Value:  vector.MustFixedCol[int64](vec)[i],
   376  		}
   377  		return
   378  
   379  	case types.T_uint8:
   380  		if vec.IsConstNull() {
   381  			value = Nullable{
   382  				IsNull: true,
   383  				Value:  uint8(0),
   384  			}
   385  			return
   386  		}
   387  		value = Nullable{
   388  			IsNull: vec.GetNulls().Contains(uint64(i)),
   389  			Value:  vector.MustFixedCol[uint8](vec)[i],
   390  		}
   391  		return
   392  
   393  	case types.T_uint16:
   394  		if vec.IsConstNull() {
   395  			value = Nullable{
   396  				IsNull: true,
   397  				Value:  uint16(0),
   398  			}
   399  			return
   400  		}
   401  		value = Nullable{
   402  			IsNull: vec.GetNulls().Contains(uint64(i)),
   403  			Value:  vector.MustFixedCol[uint16](vec)[i],
   404  		}
   405  		return
   406  
   407  	case types.T_uint32:
   408  		if vec.IsConstNull() {
   409  			value = Nullable{
   410  				IsNull: true,
   411  				Value:  uint32(0),
   412  			}
   413  			return
   414  		}
   415  		value = Nullable{
   416  			IsNull: vec.GetNulls().Contains(uint64(i)),
   417  			Value:  vector.MustFixedCol[uint32](vec)[i],
   418  		}
   419  		return
   420  
   421  	case types.T_uint64:
   422  		if vec.IsConstNull() {
   423  			value = Nullable{
   424  				IsNull: true,
   425  				Value:  uint64(0),
   426  			}
   427  			return
   428  		}
   429  		value = Nullable{
   430  			IsNull: vec.GetNulls().Contains(uint64(i)),
   431  			Value:  vector.MustFixedCol[uint64](vec)[i],
   432  		}
   433  		return
   434  
   435  	case types.T_float32:
   436  		if vec.IsConstNull() {
   437  			value = Nullable{
   438  				IsNull: true,
   439  				Value:  float32(0),
   440  			}
   441  			return
   442  		}
   443  		value = Nullable{
   444  			IsNull: vec.GetNulls().Contains(uint64(i)),
   445  			Value:  vector.MustFixedCol[float32](vec)[i],
   446  		}
   447  		return
   448  
   449  	case types.T_float64:
   450  		if vec.IsConstNull() {
   451  			value = Nullable{
   452  				IsNull: true,
   453  				Value:  float64(0),
   454  			}
   455  			return
   456  		}
   457  		value = Nullable{
   458  			IsNull: vec.GetNulls().Contains(uint64(i)),
   459  			Value:  vector.MustFixedCol[float64](vec)[i],
   460  		}
   461  		return
   462  
   463  	case types.T_tuple:
   464  		if vec.IsConstNull() {
   465  			value = Nullable{
   466  				IsNull: true,
   467  				Value:  []any{},
   468  			}
   469  			return
   470  		}
   471  		value = Nullable{
   472  			IsNull: vec.GetNulls().Contains(uint64(i)),
   473  			Value:  vector.MustFixedCol[[]any](vec)[i],
   474  		}
   475  		return
   476  
   477  	case types.T_char, types.T_varchar, types.T_binary, types.T_varbinary, types.T_json, types.T_blob, types.T_text,
   478  		types.T_array_float32, types.T_array_float64:
   479  		if vec.IsConstNull() {
   480  			value = Nullable{
   481  				IsNull: true,
   482  				Value:  []byte{},
   483  			}
   484  			return
   485  		}
   486  		value = Nullable{
   487  			IsNull: vec.GetNulls().Contains(uint64(i)),
   488  			Value:  vec.GetBytesAt(i),
   489  		}
   490  		return
   491  
   492  	case types.T_date:
   493  		if vec.IsConstNull() {
   494  			var zero types.Date
   495  			value = Nullable{
   496  				IsNull: true,
   497  				Value:  zero,
   498  			}
   499  			return
   500  		}
   501  		value = Nullable{
   502  			IsNull: vec.GetNulls().Contains(uint64(i)),
   503  			Value:  vector.MustFixedCol[types.Date](vec)[i],
   504  		}
   505  		return
   506  
   507  	case types.T_time:
   508  		if vec.IsConstNull() {
   509  			var zero types.Time
   510  			value = Nullable{
   511  				IsNull: true,
   512  				Value:  zero,
   513  			}
   514  			return
   515  		}
   516  		value = Nullable{
   517  			IsNull: vec.GetNulls().Contains(uint64(i)),
   518  			Value:  vector.MustFixedCol[types.Time](vec)[i],
   519  		}
   520  		return
   521  
   522  	case types.T_datetime:
   523  		if vec.IsConstNull() {
   524  			var zero types.Datetime
   525  			value = Nullable{
   526  				IsNull: true,
   527  				Value:  zero,
   528  			}
   529  			return
   530  		}
   531  		value = Nullable{
   532  			IsNull: vec.GetNulls().Contains(uint64(i)),
   533  			Value:  vector.MustFixedCol[types.Datetime](vec)[i],
   534  		}
   535  		return
   536  
   537  	case types.T_timestamp:
   538  		if vec.IsConstNull() {
   539  			var zero types.Timestamp
   540  			value = Nullable{
   541  				IsNull: true,
   542  				Value:  zero,
   543  			}
   544  			return
   545  		}
   546  		value = Nullable{
   547  			IsNull: vec.GetNulls().Contains(uint64(i)),
   548  			Value:  vector.MustFixedCol[types.Timestamp](vec)[i],
   549  		}
   550  		return
   551  
   552  	case types.T_enum:
   553  		if vec.IsConstNull() {
   554  			var zero types.Enum
   555  			value = Nullable{
   556  				IsNull: true,
   557  				Value:  zero,
   558  			}
   559  			return
   560  		}
   561  		value = Nullable{
   562  			IsNull: vec.GetNulls().Contains(uint64(i)),
   563  			Value:  vector.MustFixedCol[types.Enum](vec)[i],
   564  		}
   565  		return
   566  
   567  	case types.T_decimal64:
   568  		if vec.IsConstNull() {
   569  			var zero types.Decimal64
   570  			value = Nullable{
   571  				IsNull: true,
   572  				Value:  zero,
   573  			}
   574  			return
   575  		}
   576  		value = Nullable{
   577  			IsNull: vec.GetNulls().Contains(uint64(i)),
   578  			Value:  vector.MustFixedCol[types.Decimal64](vec)[i],
   579  		}
   580  		return
   581  
   582  	case types.T_decimal128:
   583  		if vec.IsConstNull() {
   584  			var zero types.Decimal128
   585  			value = Nullable{
   586  				IsNull: true,
   587  				Value:  zero,
   588  			}
   589  			return
   590  		}
   591  		value = Nullable{
   592  			IsNull: vec.GetNulls().Contains(uint64(i)),
   593  			Value:  vector.MustFixedCol[types.Decimal128](vec)[i],
   594  		}
   595  		return
   596  
   597  	case types.T_Rowid:
   598  		if vec.IsConstNull() {
   599  			var zero types.Rowid
   600  			value = Nullable{
   601  				IsNull: true,
   602  				Value:  zero,
   603  			}
   604  			return
   605  		}
   606  		value = Nullable{
   607  			IsNull: vec.GetNulls().Contains(uint64(i)),
   608  			Value:  vector.MustFixedCol[types.Rowid](vec)[i],
   609  		}
   610  		return
   611  	case types.T_Blockid:
   612  		if vec.IsConstNull() {
   613  			var zero types.Blockid
   614  			value = Nullable{
   615  				IsNull: true,
   616  				Value:  zero,
   617  			}
   618  			return
   619  		}
   620  		value = Nullable{
   621  			IsNull: vec.GetNulls().Contains(uint64(i)),
   622  			Value:  vector.MustFixedCol[types.Blockid](vec)[i],
   623  		}
   624  		return
   625  	case types.T_uuid:
   626  		if vec.IsConstNull() {
   627  			var zero types.Uuid
   628  			value = Nullable{
   629  				IsNull: true,
   630  				Value:  zero,
   631  			}
   632  			return
   633  		}
   634  		value = Nullable{
   635  			IsNull: vec.GetNulls().Contains(uint64(i)),
   636  			Value:  vector.MustFixedCol[types.Uuid](vec)[i],
   637  		}
   638  		return
   639  
   640  	}
   641  
   642  	panic(fmt.Sprintf("unknown column type: %v", *vec.GetType()))
   643  }
   644  
   645  func appendNullableValueToVector(vec *vector.Vector, value Nullable, mp *mpool.MPool) {
   646  	str, ok := value.Value.(string)
   647  	if ok {
   648  		value.Value = []byte(str)
   649  	}
   650  	vector.AppendAny(vec, value.Value, value.IsNull, mp)
   651  }