github.com/ari-anchor/sei-tendermint@v0.0.0-20230519144642-dc826b7b56bb/test/e2e/app/state.go (about)

     1  //nolint: gosec
     2  package app
     3  
     4  import (
     5  	"crypto/sha256"
     6  	"encoding/binary"
     7  	"encoding/json"
     8  	"errors"
     9  	"fmt"
    10  	"os"
    11  	"path/filepath"
    12  	"sort"
    13  	"sync"
    14  )
    15  
    16  const stateFileName = "app_state.json"
    17  const prevStateFileName = "prev_app_state.json"
    18  
    19  // State is the application state.
    20  type State struct {
    21  	sync.RWMutex
    22  	Height uint64
    23  	Values map[string]string
    24  	Hash   []byte
    25  
    26  	// private fields aren't marshaled to disk.
    27  	currentFile string
    28  	// app saves current and previous state for rollback functionality
    29  	previousFile    string
    30  	persistInterval uint64
    31  	initialHeight   uint64
    32  }
    33  
    34  // NewState creates a new state.
    35  func NewState(dir string, persistInterval uint64) (*State, error) {
    36  	state := &State{
    37  		Values:          make(map[string]string),
    38  		currentFile:     filepath.Join(dir, stateFileName),
    39  		previousFile:    filepath.Join(dir, prevStateFileName),
    40  		persistInterval: persistInterval,
    41  	}
    42  	state.Hash = hashItems(state.Values, state.Height)
    43  	err := state.load()
    44  	switch {
    45  	case errors.Is(err, os.ErrNotExist):
    46  	case err != nil:
    47  		return nil, err
    48  	}
    49  	return state, nil
    50  }
    51  
    52  // load loads state from disk. It does not take out a lock, since it is called
    53  // during construction.
    54  func (s *State) load() error {
    55  	bz, err := os.ReadFile(s.currentFile)
    56  	if err != nil {
    57  		// if the current state doesn't exist then we try recover from the previous state
    58  		if errors.Is(err, os.ErrNotExist) {
    59  			bz, err = os.ReadFile(s.previousFile)
    60  			if err != nil {
    61  				return fmt.Errorf("failed to read both current and previous state (%q): %w",
    62  					s.previousFile, err)
    63  			}
    64  		} else {
    65  			return fmt.Errorf("failed to read state from %q: %w", s.currentFile, err)
    66  		}
    67  	}
    68  	err = json.Unmarshal(bz, s)
    69  	if err != nil {
    70  		return fmt.Errorf("invalid state data in %q: %w", s.currentFile, err)
    71  	}
    72  	return nil
    73  }
    74  
    75  // save saves the state to disk. It does not take out a lock since it is called
    76  // internally by Commit which does lock.
    77  func (s *State) save() error {
    78  	bz, err := json.Marshal(s)
    79  	if err != nil {
    80  		return fmt.Errorf("failed to marshal state: %w", err)
    81  	}
    82  	// We write the state to a separate file and move it to the destination, to
    83  	// make it atomic.
    84  	newFile := fmt.Sprintf("%v.new", s.currentFile)
    85  	err = os.WriteFile(newFile, bz, 0644)
    86  	if err != nil {
    87  		return fmt.Errorf("failed to write state to %q: %w", s.currentFile, err)
    88  	}
    89  	// We take the current state and move it to the previous state, replacing it
    90  	if _, err := os.Stat(s.currentFile); err == nil {
    91  		if err := os.Rename(s.currentFile, s.previousFile); err != nil {
    92  			return fmt.Errorf("failed to replace previous state: %w", err)
    93  		}
    94  	}
    95  	// Finally, we take the new state and replace the current state.
    96  	return os.Rename(newFile, s.currentFile)
    97  }
    98  
    99  // Export exports key/value pairs as JSON, used for state sync snapshots.
   100  func (s *State) Export() ([]byte, error) {
   101  	s.RLock()
   102  	defer s.RUnlock()
   103  	return json.Marshal(s.Values)
   104  }
   105  
   106  // Import imports key/value pairs from JSON bytes, used for InitChain.AppStateBytes and
   107  // state sync snapshots. It also saves the state once imported.
   108  func (s *State) Import(height uint64, jsonBytes []byte) error {
   109  	s.Lock()
   110  	defer s.Unlock()
   111  	values := map[string]string{}
   112  	err := json.Unmarshal(jsonBytes, &values)
   113  	if err != nil {
   114  		return fmt.Errorf("failed to decode imported JSON data: %w", err)
   115  	}
   116  	s.Height = height
   117  	s.Values = values
   118  	s.Hash = hashItems(values, height)
   119  	return s.save()
   120  }
   121  
   122  // Get fetches a value. A missing value is returned as an empty string.
   123  func (s *State) Get(key string) string {
   124  	s.RLock()
   125  	defer s.RUnlock()
   126  	return s.Values[key]
   127  }
   128  
   129  // Set sets a value. Setting an empty value is equivalent to deleting it.
   130  func (s *State) Set(key, value string) {
   131  	s.Lock()
   132  	defer s.Unlock()
   133  	if value == "" {
   134  		delete(s.Values, key)
   135  	} else {
   136  		s.Values[key] = value
   137  	}
   138  }
   139  
   140  // Finalize is called after applying a block, updating the height and returning the new app_hash
   141  func (s *State) Finalize() []byte {
   142  	s.Lock()
   143  	defer s.Unlock()
   144  	switch {
   145  	case s.Height > 0:
   146  		s.Height++
   147  	case s.initialHeight > 0:
   148  		s.Height = s.initialHeight
   149  	default:
   150  		s.Height = 1
   151  	}
   152  	s.Hash = hashItems(s.Values, s.Height)
   153  	return s.Hash
   154  }
   155  
   156  // Commit commits the current state.
   157  func (s *State) Commit() (uint64, error) {
   158  	s.Lock()
   159  	defer s.Unlock()
   160  	if s.persistInterval > 0 && s.Height%s.persistInterval == 0 {
   161  		err := s.save()
   162  		if err != nil {
   163  			return 0, err
   164  		}
   165  	}
   166  	return s.Height, nil
   167  }
   168  
   169  func (s *State) Rollback() error {
   170  	bz, err := os.ReadFile(s.previousFile)
   171  	if err != nil {
   172  		return fmt.Errorf("failed to read state from %q: %w", s.previousFile, err)
   173  	}
   174  	err = json.Unmarshal(bz, s)
   175  	if err != nil {
   176  		return fmt.Errorf("invalid state data in %q: %w", s.previousFile, err)
   177  	}
   178  	return nil
   179  }
   180  
   181  // hashItems hashes a set of key/value items.
   182  func hashItems(items map[string]string, height uint64) []byte {
   183  	keys := make([]string, 0, len(items))
   184  	for key := range items {
   185  		keys = append(keys, key)
   186  	}
   187  	sort.Strings(keys)
   188  
   189  	hasher := sha256.New()
   190  	var b [8]byte
   191  	binary.BigEndian.PutUint64(b[:], height)
   192  	_, _ = hasher.Write(b[:])
   193  	for _, key := range keys {
   194  		_, _ = hasher.Write([]byte(key))
   195  		_, _ = hasher.Write([]byte{0})
   196  		_, _ = hasher.Write([]byte(items[key]))
   197  		_, _ = hasher.Write([]byte{0})
   198  	}
   199  	return hasher.Sum(nil)
   200  }