code.vegaprotocol.io/vega@v0.79.0/core/staking/accounting_snapshot.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package staking
    17  
    18  import (
    19  	"context"
    20  	"fmt"
    21  
    22  	"code.vegaprotocol.io/vega/core/events"
    23  	"code.vegaprotocol.io/vega/core/types"
    24  	vgcontext "code.vegaprotocol.io/vega/libs/context"
    25  	"code.vegaprotocol.io/vega/libs/proto"
    26  	"code.vegaprotocol.io/vega/logging"
    27  )
    28  
    29  var accountsKey = (&types.PayloadStakingAccounts{}).Key()
    30  
// accountingSnapshotState holds snapshot bookkeeping for the Accounting engine.
type accountingSnapshotState struct {
	// serialised caches the most recent snapshot bytes: written by serialise
	// after producing a payload, and by restoreStakingAccounts after a restore.
	serialised []byte
	// isRestoring is set by OnStateLoadStarts and cleared by OnStateLoaded.
	isRestoring bool
}
    35  
    36  func (a *Accounting) serialiseStakingAccounts() ([]byte, error) {
    37  	accounts := make([]*types.StakingAccount, 0, len(a.hashableAccounts))
    38  	a.log.Debug("serialsing staking accounts", logging.Int("n", len(a.hashableAccounts)))
    39  	for _, acc := range a.hashableAccounts {
    40  		// dedup transactions with the same eth hash from different block heights and recalc the balance
    41  		acc.Events = a.dedupHack(acc.Events)
    42  		acc.computeOngoingBalance()
    43  		accounts = append(accounts,
    44  			&types.StakingAccount{
    45  				Party:   acc.Party,
    46  				Balance: acc.Balance,
    47  				Events:  acc.Events,
    48  			})
    49  	}
    50  
    51  	var psts *types.StakeTotalSupply
    52  	if a.pendingStakeTotalSupply != nil {
    53  		psts = a.pendingStakeTotalSupply.sts
    54  	}
    55  
    56  	pl := types.Payload{
    57  		Data: &types.PayloadStakingAccounts{
    58  			PendingStakeTotalSupply: psts,
    59  			StakingAccounts: &types.StakingAccounts{
    60  				Accounts:                accounts,
    61  				StakingAssetTotalSupply: a.stakingAssetTotalSupply.Clone(),
    62  			},
    63  		},
    64  	}
    65  
    66  	return proto.Marshal(pl.IntoProto())
    67  }
    68  
    69  // get the serialised form and hash of the given key.
    70  func (a *Accounting) serialise(k string) ([]byte, error) {
    71  	if k != accountsKey {
    72  		return nil, types.ErrSnapshotKeyDoesNotExist
    73  	}
    74  
    75  	data, err := a.serialiseStakingAccounts()
    76  	if err != nil {
    77  		return nil, err
    78  	}
    79  
    80  	a.accState.serialised = data
    81  	return data, nil
    82  }
    83  
    84  func (a *Accounting) OnStateLoaded(_ context.Context) error {
    85  	a.accState.isRestoring = false
    86  	return nil
    87  }
    88  
    89  func (a *Accounting) OnStateLoadStarts(_ context.Context) error {
    90  	a.accState.isRestoring = true
    91  	return nil
    92  }
    93  
    94  func (a *Accounting) Namespace() types.SnapshotNamespace {
    95  	return types.StakingSnapshot
    96  }
    97  
    98  func (a *Accounting) Keys() []string {
    99  	return []string{accountsKey}
   100  }
   101  
   102  func (a *Accounting) Stopped() bool {
   103  	return false
   104  }
   105  
   106  func (a *Accounting) GetState(k string) ([]byte, []types.StateProvider, error) {
   107  	data, err := a.serialise(k)
   108  	return data, nil, err
   109  }
   110  
   111  func (a *Accounting) LoadState(ctx context.Context, payload *types.Payload) ([]types.StateProvider, error) {
   112  	if a.Namespace() != payload.Data.Namespace() {
   113  		return nil, types.ErrInvalidSnapshotNamespace
   114  	}
   115  
   116  	switch pl := payload.Data.(type) {
   117  	case *types.PayloadStakingAccounts:
   118  
   119  		return nil, a.restoreStakingAccounts(ctx, pl.StakingAccounts, pl.PendingStakeTotalSupply, payload)
   120  	default:
   121  		return nil, types.ErrUnknownSnapshotType
   122  	}
   123  }
   124  
   125  // dedupHack takes care of events with the same ethereum tx hash originating from
   126  // reorg - the result is that duplicates are removed and the branch with the latest block height is kept
   127  // after calling this function, the balance should be recalculated.
   128  func (a *Accounting) dedupHack(evts []*types.StakeLinking) []*types.StakeLinking {
   129  	hashToEvt := map[string]*types.StakeLinking{}
   130  	for _, sl := range evts {
   131  		evt, ok := hashToEvt[sl.TxHash]
   132  		if !ok {
   133  			hashToEvt[sl.TxHash] = sl
   134  		} else {
   135  			if sl.BlockHeight > evt.BlockHeight {
   136  				a.log.Warn("duplicate events with identical transaction hash found", logging.String("tx-hash", sl.TxHash), logging.Uint64("block-height1", sl.BlockHeight), logging.Uint64("block-height2", evt.BlockHeight))
   137  				hashToEvt[sl.TxHash] = sl
   138  			}
   139  		}
   140  	}
   141  	newEvts := make([]*types.StakeLinking, 0, len(hashToEvt))
   142  	for _, sl := range evts {
   143  		evt := hashToEvt[sl.TxHash]
   144  		if evt.BlockHeight == sl.BlockHeight {
   145  			newEvts = append(newEvts, sl)
   146  		}
   147  	}
   148  	return newEvts
   149  }
   150  
   151  func (a *Accounting) restoreStakingAccounts(ctx context.Context, accounts *types.StakingAccounts, pendingSupply *types.StakeTotalSupply, p *types.Payload) error {
   152  	a.hashableAccounts = make([]*Account, 0, len(accounts.Accounts))
   153  	a.log.Debug("restoring staking accounts",
   154  		logging.Int("n", len(accounts.Accounts)),
   155  	)
   156  	evts := []events.Event{}
   157  	pevts := []events.Event{}
   158  	for _, acc := range accounts.Accounts {
   159  		stakingAcc := &Account{
   160  			Party:   acc.Party,
   161  			Balance: acc.Balance,
   162  			Events:  a.dedupHack(acc.Events),
   163  		}
   164  		stakingAcc.computeOngoingBalance()
   165  		a.hashableAccounts = append(a.hashableAccounts, stakingAcc)
   166  		a.accounts[acc.Party] = stakingAcc
   167  		pevts = append(pevts, events.NewPartyEvent(ctx, types.Party{Id: acc.Party}))
   168  		for _, e := range acc.Events {
   169  			evts = append(evts, events.NewStakeLinking(ctx, *e))
   170  		}
   171  	}
   172  
   173  	if pendingSupply != nil {
   174  		expectedSupply := pendingSupply.TotalSupply.Clone()
   175  		a.pendingStakeTotalSupply = &pendingStakeTotalSupply{
   176  			sts:     pendingSupply,
   177  			chainID: a.chainID,
   178  			check: func() error {
   179  				totalSupply, err := a.getStakeAssetTotalSupply(a.stakingBridgeAddresses[0])
   180  				if err != nil {
   181  					return err
   182  				}
   183  
   184  				if totalSupply.NEQ(expectedSupply) {
   185  					return fmt.Errorf(
   186  						"invalid stake asset total supply, expected %s got %s",
   187  						expectedSupply.String(), totalSupply.String(),
   188  					)
   189  				}
   190  
   191  				return nil
   192  			},
   193  		}
   194  		a.witness.RestoreResource(a.pendingStakeTotalSupply, a.onStakeTotalSupplyVerified)
   195  	}
   196  
   197  	if vgcontext.InProgressUpgradeFrom(ctx, "v0.76.8") {
   198  		lastSeen := a.getLastBlockSeen()
   199  		for _, addr := range a.stakingBridgeAddresses {
   200  			a.log.Info("migration code updating multisig last seen",
   201  				logging.String("address", addr.Hex()),
   202  				logging.Uint64("last-seen", lastSeen),
   203  				logging.String("chain-id", a.chainID),
   204  			)
   205  			a.ethSource.UpdateContractBlock(addr.Hex(), a.chainID, lastSeen)
   206  		}
   207  	}
   208  
   209  	a.stakingAssetTotalSupply = accounts.StakingAssetTotalSupply.Clone()
   210  	var err error
   211  	a.accState.serialised, err = proto.Marshal(p.IntoProto())
   212  	a.broker.SendBatch(evts)
   213  	a.broker.SendBatch(pevts)
   214  	return err
   215  }