github.com/lazyledger/lazyledger-core@v0.35.0-dev.0.20210613111200-4c651f053571/test/e2e/app/state.go (about)

     1  //nolint: gosec
     2  package main
     3  
     4  import (
     5  	"crypto/sha256"
     6  	"encoding/json"
     7  	"errors"
     8  	"fmt"
     9  	"io/ioutil"
    10  	"os"
    11  	"sort"
    12  	"sync"
    13  )
    14  
    15  // State is the application state.
    16  type State struct {
    17  	sync.RWMutex
    18  	Height uint64
    19  	Values map[string]string
    20  	Hash   []byte
    21  
    22  	// private fields aren't marshalled to disk.
    23  	file            string
    24  	persistInterval uint64
    25  	initialHeight   uint64
    26  }
    27  
    28  // NewState creates a new state.
    29  func NewState(file string, persistInterval uint64) (*State, error) {
    30  	state := &State{
    31  		Values:          make(map[string]string),
    32  		file:            file,
    33  		persistInterval: persistInterval,
    34  	}
    35  	state.Hash = hashItems(state.Values)
    36  	err := state.load()
    37  	switch {
    38  	case errors.Is(err, os.ErrNotExist):
    39  	case err != nil:
    40  		return nil, err
    41  	}
    42  	return state, nil
    43  }
    44  
    45  // load loads state from disk. It does not take out a lock, since it is called
    46  // during construction.
    47  func (s *State) load() error {
    48  	bz, err := ioutil.ReadFile(s.file)
    49  	if err != nil {
    50  		return fmt.Errorf("failed to read state from %q: %w", s.file, err)
    51  	}
    52  	err = json.Unmarshal(bz, s)
    53  	if err != nil {
    54  		return fmt.Errorf("invalid state data in %q: %w", s.file, err)
    55  	}
    56  	return nil
    57  }
    58  
    59  // save saves the state to disk. It does not take out a lock since it is called
    60  // internally by Commit which does lock.
    61  func (s *State) save() error {
    62  	bz, err := json.Marshal(s)
    63  	if err != nil {
    64  		return fmt.Errorf("failed to marshal state: %w", err)
    65  	}
    66  	// We write the state to a separate file and move it to the destination, to
    67  	// make it atomic.
    68  	newFile := fmt.Sprintf("%v.new", s.file)
    69  	err = ioutil.WriteFile(newFile, bz, 0644)
    70  	if err != nil {
    71  		return fmt.Errorf("failed to write state to %q: %w", s.file, err)
    72  	}
    73  	return os.Rename(newFile, s.file)
    74  }
    75  
    76  // Export exports key/value pairs as JSON, used for state sync snapshots.
    77  func (s *State) Export() ([]byte, error) {
    78  	s.RLock()
    79  	defer s.RUnlock()
    80  	return json.Marshal(s.Values)
    81  }
    82  
    83  // Import imports key/value pairs from JSON bytes, used for InitChain.AppStateBytes and
    84  // state sync snapshots. It also saves the state once imported.
    85  func (s *State) Import(height uint64, jsonBytes []byte) error {
    86  	s.Lock()
    87  	defer s.Unlock()
    88  	values := map[string]string{}
    89  	err := json.Unmarshal(jsonBytes, &values)
    90  	if err != nil {
    91  		return fmt.Errorf("failed to decode imported JSON data: %w", err)
    92  	}
    93  	s.Height = height
    94  	s.Values = values
    95  	s.Hash = hashItems(values)
    96  	return s.save()
    97  }
    98  
    99  // Get fetches a value. A missing value is returned as an empty string.
   100  func (s *State) Get(key string) string {
   101  	s.RLock()
   102  	defer s.RUnlock()
   103  	return s.Values[key]
   104  }
   105  
   106  // Set sets a value. Setting an empty value is equivalent to deleting it.
   107  func (s *State) Set(key, value string) {
   108  	s.Lock()
   109  	defer s.Unlock()
   110  	if value == "" {
   111  		delete(s.Values, key)
   112  	} else {
   113  		s.Values[key] = value
   114  	}
   115  }
   116  
   117  // Commit commits the current state.
   118  func (s *State) Commit() (uint64, []byte, error) {
   119  	s.Lock()
   120  	defer s.Unlock()
   121  	s.Hash = hashItems(s.Values)
   122  	switch {
   123  	case s.Height > 0:
   124  		s.Height++
   125  	case s.initialHeight > 0:
   126  		s.Height = s.initialHeight
   127  	default:
   128  		s.Height = 1
   129  	}
   130  	if s.persistInterval > 0 && s.Height%s.persistInterval == 0 {
   131  		err := s.save()
   132  		if err != nil {
   133  			return 0, nil, err
   134  		}
   135  	}
   136  	return s.Height, s.Hash, nil
   137  }
   138  
   139  // hashItems hashes a set of key/value items.
   140  func hashItems(items map[string]string) []byte {
   141  	keys := make([]string, 0, len(items))
   142  	for key := range items {
   143  		keys = append(keys, key)
   144  	}
   145  	sort.Strings(keys)
   146  
   147  	hasher := sha256.New()
   148  	for _, key := range keys {
   149  		_, _ = hasher.Write([]byte(key))
   150  		_, _ = hasher.Write([]byte{0})
   151  		_, _ = hasher.Write([]byte(items[key]))
   152  		_, _ = hasher.Write([]byte{0})
   153  	}
   154  	return hasher.Sum(nil)
   155  }