github.com/vescale/zgraph@v0.0.0-20230410094002-959c02d50f95/internal/structure/hash.go (about)

     1  // Copyright 2022 zGraph Authors. All rights reserved.
     2  //
     3  // Copyright 2015 PingCAP, Inc.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //     http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  
    17  package structure
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"strconv"
    23  
    24  	"github.com/pingcap/errors"
    25  	"github.com/vescale/zgraph/storage/kv"
    26  )
    27  
    28  // HashPair is the pair for (field, value) in a hash.
    29  type HashPair struct {
    30  	Field []byte
    31  	Value []byte
    32  }
    33  
    34  // HSet sets the string value of a hash field.
    35  func (t *TxStructure) HSet(key []byte, field []byte, value []byte) error {
    36  	if t.readWriter == nil {
    37  		return ErrWriteOnSnapshot
    38  	}
    39  	return t.updateHash(key, field, func([]byte) ([]byte, error) {
    40  		return value, nil
    41  	})
    42  }
    43  
    44  // HGet gets the value of a hash field.
    45  func (t *TxStructure) HGet(key []byte, field []byte) ([]byte, error) {
    46  	dataKey := t.encodeHashDataKey(key, field)
    47  	value, err := t.reader.Get(context.TODO(), dataKey)
    48  	if errors.Cause(err) == kv.ErrNotExist {
    49  		err = nil
    50  	}
    51  	return value, errors.Trace(err)
    52  }
    53  
    54  func (*TxStructure) hashFieldIntegerVal(val int64) []byte {
    55  	return []byte(strconv.FormatInt(val, 10))
    56  }
    57  
    58  // EncodeHashAutoIDKeyValue returns the hash key-value generated by the key and the field
    59  func (t *TxStructure) EncodeHashAutoIDKeyValue(key []byte, field []byte, val int64) (k, v []byte) {
    60  	return t.encodeHashDataKey(key, field), t.hashFieldIntegerVal(val)
    61  }
    62  
    63  // HInc increments the integer value of a hash field, by step, returns
    64  // the value after the increment.
    65  func (t *TxStructure) HInc(key []byte, field []byte, step int64) (int64, error) {
    66  	if t.readWriter == nil {
    67  		return 0, ErrWriteOnSnapshot
    68  	}
    69  	base := int64(0)
    70  	err := t.updateHash(key, field, func(oldValue []byte) ([]byte, error) {
    71  		if oldValue != nil {
    72  			var err error
    73  			base, err = strconv.ParseInt(string(oldValue), 10, 64)
    74  			if err != nil {
    75  				return nil, errors.Trace(err)
    76  			}
    77  		}
    78  		base += step
    79  		return t.hashFieldIntegerVal(base), nil
    80  	})
    81  
    82  	return base, errors.Trace(err)
    83  }
    84  
    85  // HGetInt64 gets int64 value of a hash field.
    86  func (t *TxStructure) HGetInt64(key []byte, field []byte) (int64, error) {
    87  	value, err := t.HGet(key, field)
    88  	if err != nil || value == nil {
    89  		return 0, errors.Trace(err)
    90  	}
    91  
    92  	var n int64
    93  	n, err = strconv.ParseInt(string(value), 10, 64)
    94  	return n, errors.Trace(err)
    95  }
    96  
    97  func (t *TxStructure) updateHash(key []byte, field []byte, fn func(oldValue []byte) ([]byte, error)) error {
    98  	dataKey := t.encodeHashDataKey(key, field)
    99  	oldValue, err := t.loadHashValue(dataKey)
   100  	if err != nil {
   101  		return errors.Trace(err)
   102  	}
   103  
   104  	newValue, err := fn(oldValue)
   105  	if err != nil {
   106  		return errors.Trace(err)
   107  	}
   108  
   109  	// Check if new value is equal to old value.
   110  	if bytes.Equal(oldValue, newValue) {
   111  		return nil
   112  	}
   113  
   114  	if err = t.readWriter.Set(dataKey, newValue); err != nil {
   115  		return errors.Trace(err)
   116  	}
   117  
   118  	return nil
   119  }
   120  
   121  // HDel deletes one or more hash fields.
   122  func (t *TxStructure) HDel(key []byte, fields ...[]byte) error {
   123  	if t.readWriter == nil {
   124  		return ErrWriteOnSnapshot
   125  	}
   126  
   127  	for _, field := range fields {
   128  		dataKey := t.encodeHashDataKey(key, field)
   129  
   130  		value, err := t.loadHashValue(dataKey)
   131  		if err != nil {
   132  			return errors.Trace(err)
   133  		}
   134  
   135  		if value != nil {
   136  			if err = t.readWriter.Delete(dataKey); err != nil {
   137  				return errors.Trace(err)
   138  			}
   139  		}
   140  	}
   141  
   142  	return nil
   143  }
   144  
   145  // HKeys gets all the fields in a hash.
   146  func (t *TxStructure) HKeys(key []byte) ([][]byte, error) {
   147  	var keys [][]byte
   148  	err := t.iterateHash(key, func(field []byte, value []byte) error {
   149  		keys = append(keys, append([]byte{}, field...))
   150  		return nil
   151  	})
   152  
   153  	return keys, errors.Trace(err)
   154  }
   155  
   156  // HGetAll gets all the fields and values in a hash.
   157  func (t *TxStructure) HGetAll(key []byte) ([]HashPair, error) {
   158  	var res []HashPair
   159  	err := t.iterateHash(key, func(field []byte, value []byte) error {
   160  		pair := HashPair{
   161  			Field: append([]byte{}, field...),
   162  			Value: append([]byte{}, value...),
   163  		}
   164  		res = append(res, pair)
   165  		return nil
   166  	})
   167  
   168  	return res, errors.Trace(err)
   169  }
   170  
   171  // HGetLen gets the length of hash.
   172  func (t *TxStructure) HGetLen(key []byte) (uint64, error) {
   173  	hashLen := 0
   174  	err := t.iterateHash(key, func(field []byte, value []byte) error {
   175  		hashLen++
   176  		return nil
   177  	})
   178  
   179  	return uint64(hashLen), errors.Trace(err)
   180  }
   181  
   182  // HGetLastN gets latest N fields and values in hash.
   183  func (t *TxStructure) HGetLastN(key []byte, num int) ([]HashPair, error) {
   184  	res := make([]HashPair, 0, num)
   185  	err := t.iterReverseHash(key, func(field []byte, value []byte) (bool, error) {
   186  		pair := HashPair{
   187  			Field: append([]byte{}, field...),
   188  			Value: append([]byte{}, value...),
   189  		}
   190  		res = append(res, pair)
   191  		if len(res) >= num {
   192  			return false, nil
   193  		}
   194  		return true, nil
   195  	})
   196  	return res, errors.Trace(err)
   197  }
   198  
   199  // HClear removes the hash value of the key.
   200  func (t *TxStructure) HClear(key []byte) error {
   201  	err := t.iterateHash(key, func(field []byte, value []byte) error {
   202  		k := t.encodeHashDataKey(key, field)
   203  		return errors.Trace(t.readWriter.Delete(k))
   204  	})
   205  
   206  	if err != nil {
   207  		return errors.Trace(err)
   208  	}
   209  
   210  	return nil
   211  }
   212  
   213  func (t *TxStructure) iterateHash(key []byte, fn func(k []byte, v []byte) error) error {
   214  	dataPrefix := t.hashDataKeyPrefix(key)
   215  	it, err := t.reader.Iter(dataPrefix, dataPrefix.PrefixNext())
   216  	if err != nil {
   217  		return errors.Trace(err)
   218  	}
   219  
   220  	var field []byte
   221  
   222  	for it.Valid() {
   223  		if !it.Key().HasPrefix(dataPrefix) {
   224  			break
   225  		}
   226  
   227  		_, field, err = t.decodeHashDataKey(it.Key())
   228  		if err != nil {
   229  			return errors.Trace(err)
   230  		}
   231  
   232  		if err = fn(field, it.Value()); err != nil {
   233  			return errors.Trace(err)
   234  		}
   235  
   236  		err = it.Next()
   237  		if err != nil {
   238  			return errors.Trace(err)
   239  		}
   240  	}
   241  
   242  	return nil
   243  }
   244  
   245  // ReverseHashIterator is the reverse hash iterator.
   246  type ReverseHashIterator struct {
   247  	t      *TxStructure
   248  	iter   kv.Iterator
   249  	prefix []byte
   250  	done   bool
   251  	field  []byte
   252  }
   253  
   254  // Next implements the Iterator Next.
   255  func (i *ReverseHashIterator) Next() error {
   256  	err := i.iter.Next()
   257  	if err != nil {
   258  		return errors.Trace(err)
   259  	}
   260  	if !i.iter.Key().HasPrefix(i.prefix) {
   261  		i.done = true
   262  		return nil
   263  	}
   264  
   265  	_, field, err := i.t.decodeHashDataKey(i.iter.Key())
   266  	if err != nil {
   267  		return errors.Trace(err)
   268  	}
   269  	i.field = field
   270  	return nil
   271  }
   272  
   273  // Valid implements the Iterator Valid.
   274  func (i *ReverseHashIterator) Valid() bool {
   275  	return i.iter.Valid() && !i.done
   276  }
   277  
   278  // Key implements the Iterator Key.
   279  func (i *ReverseHashIterator) Key() []byte {
   280  	return i.field
   281  }
   282  
   283  // Value implements the Iterator Value.
   284  func (i *ReverseHashIterator) Value() []byte {
   285  	return i.iter.Value()
   286  }
   287  
   288  // Close Implements the Iterator Close.
   289  func (*ReverseHashIterator) Close() {}
   290  
   291  // NewHashReverseIter creates a reverse hash iterator.
   292  func NewHashReverseIter(t *TxStructure, key []byte) (*ReverseHashIterator, error) {
   293  	return newHashReverseIter(t, key, nil)
   294  }
   295  
   296  // NewHashReverseIterBeginWithField creates a reverse hash iterator, begin with field.
   297  func NewHashReverseIterBeginWithField(t *TxStructure, key []byte, field []byte) (*ReverseHashIterator, error) {
   298  	return newHashReverseIter(t, key, field)
   299  }
   300  
   301  func newHashReverseIter(t *TxStructure, key []byte, field []byte) (*ReverseHashIterator, error) {
   302  	var iterStart kv.Key
   303  	dataPrefix := t.hashDataKeyPrefix(key)
   304  	if len(field) == 0 {
   305  		iterStart = dataPrefix.PrefixNext()
   306  	} else {
   307  		iterStart = t.encodeHashDataKey(key, field).PrefixNext()
   308  	}
   309  
   310  	it, err := t.reader.IterReverse(nil, iterStart)
   311  	if err != nil {
   312  		return nil, errors.Trace(err)
   313  	}
   314  	return &ReverseHashIterator{
   315  		t:      t,
   316  		iter:   it,
   317  		prefix: dataPrefix,
   318  	}, nil
   319  }
   320  
   321  func (t *TxStructure) iterReverseHash(key []byte, fn func(k []byte, v []byte) (bool, error)) error {
   322  	dataPrefix := t.hashDataKeyPrefix(key)
   323  	it, err := t.reader.IterReverse(nil, dataPrefix.PrefixNext())
   324  	if err != nil {
   325  		return errors.Trace(err)
   326  	}
   327  
   328  	var field []byte
   329  	for it.Valid() {
   330  		if !it.Key().HasPrefix(dataPrefix) {
   331  			break
   332  		}
   333  
   334  		_, field, err = t.decodeHashDataKey(it.Key())
   335  		if err != nil {
   336  			return errors.Trace(err)
   337  		}
   338  
   339  		more, err := fn(field, it.Value())
   340  		if !more || err != nil {
   341  			return errors.Trace(err)
   342  		}
   343  
   344  		err = it.Next()
   345  		if err != nil {
   346  			return errors.Trace(err)
   347  		}
   348  	}
   349  	return nil
   350  }
   351  
   352  func (t *TxStructure) loadHashValue(dataKey []byte) ([]byte, error) {
   353  	v, err := t.reader.Get(context.TODO(), dataKey)
   354  	if errors.Cause(err) == kv.ErrNotExist {
   355  		err = nil
   356  		v = nil
   357  	}
   358  	if err != nil {
   359  		return nil, errors.Trace(err)
   360  	}
   361  
   362  	return v, nil
   363  }