github.com/nspcc-dev/neo-go@v0.105.2-0.20240517133400-6be757af3eba/pkg/crypto/keys/nep2.go (about)

     1  package keys
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  
     8  	"github.com/nspcc-dev/neo-go/pkg/crypto/hash"
     9  	"github.com/nspcc-dev/neo-go/pkg/encoding/base58"
    10  	"github.com/nspcc-dev/neo-go/pkg/util/slice"
    11  	"golang.org/x/crypto/scrypt"
    12  	"golang.org/x/text/unicode/norm"
    13  )
    14  
    15  // NEP-2 standard implementation for encrypting and decrypting private keys.
    16  
    17  // NEP-2 specified parameters used for cryptography.
    18  const (
    19  	n       = 16384
    20  	r       = 8
    21  	p       = 8
    22  	keyLen  = 64
    23  	nepFlag = 0xe0
    24  )
    25  
    26  var nepHeader = []byte{0x01, 0x42}
    27  
// ScryptParams is a json-serializable container for scrypt KDF parameters.
// NEP2Encrypt and NEP2Decrypt forward the three fields, in order, as the
// N, r and p arguments of scrypt.Key.
type ScryptParams struct {
	// N is passed to scrypt.Key as its N argument (serialized as "n").
	N int `json:"n"`
	// R is passed to scrypt.Key as its r argument (serialized as "r").
	R int `json:"r"`
	// P is passed to scrypt.Key as its p argument (serialized as "p").
	P int `json:"p"`
}
    34  
    35  // NEP2ScryptParams returns scrypt parameters specified in the NEP-2.
    36  func NEP2ScryptParams() ScryptParams {
    37  	return ScryptParams{
    38  		N: n,
    39  		R: r,
    40  		P: p,
    41  	}
    42  }
    43  
    44  // NEP2Encrypt encrypts a the PrivateKey using the given passphrase
    45  // under the NEP-2 standard.
    46  func NEP2Encrypt(priv *PrivateKey, passphrase string, params ScryptParams) (s string, err error) {
    47  	address := priv.Address()
    48  
    49  	addrHash := hash.Checksum([]byte(address))
    50  	// Normalize the passphrase according to the NFC standard.
    51  	phraseNorm := norm.NFC.Bytes([]byte(passphrase))
    52  	derivedKey, err := scrypt.Key(phraseNorm, addrHash, params.N, params.R, params.P, keyLen)
    53  	if err != nil {
    54  		return s, err
    55  	}
    56  	defer slice.Clean(derivedKey)
    57  
    58  	derivedKey1 := derivedKey[:32]
    59  	derivedKey2 := derivedKey[32:]
    60  
    61  	privBytes := priv.Bytes()
    62  	defer slice.Clean(privBytes)
    63  	xr := xor(privBytes, derivedKey1)
    64  	defer slice.Clean(xr)
    65  
    66  	encrypted, err := aesEncrypt(xr, derivedKey2)
    67  	if err != nil {
    68  		return s, err
    69  	}
    70  
    71  	buf := new(bytes.Buffer)
    72  	buf.Write(nepHeader)
    73  	buf.WriteByte(nepFlag)
    74  	buf.Write(addrHash)
    75  	buf.Write(encrypted)
    76  
    77  	if buf.Len() != 39 {
    78  		return s, fmt.Errorf("invalid buffer length: expecting 39 bytes got %d", buf.Len())
    79  	}
    80  
    81  	return base58.CheckEncode(buf.Bytes()), nil
    82  }
    83  
    84  // NEP2Decrypt decrypts an encrypted key using the given passphrase
    85  // under the NEP-2 standard.
    86  func NEP2Decrypt(key, passphrase string, params ScryptParams) (*PrivateKey, error) {
    87  	b, err := base58.CheckDecode(key)
    88  	if err != nil {
    89  		return nil, err
    90  	}
    91  	if err := validateNEP2Format(b); err != nil {
    92  		return nil, err
    93  	}
    94  
    95  	addrHash := b[3:7]
    96  	// Normalize the passphrase according to the NFC standard.
    97  	phraseNorm := norm.NFC.Bytes([]byte(passphrase))
    98  	derivedKey, err := scrypt.Key(phraseNorm, addrHash, params.N, params.R, params.P, keyLen)
    99  	if err != nil {
   100  		return nil, err
   101  	}
   102  	defer slice.Clean(derivedKey)
   103  
   104  	derivedKey1 := derivedKey[:32]
   105  	derivedKey2 := derivedKey[32:]
   106  	encryptedBytes := b[7:]
   107  
   108  	decrypted, err := aesDecrypt(encryptedBytes, derivedKey2)
   109  	if err != nil {
   110  		return nil, err
   111  	}
   112  	defer slice.Clean(decrypted)
   113  
   114  	privBytes := xor(decrypted, derivedKey1)
   115  	defer slice.Clean(privBytes)
   116  
   117  	// Rebuild the private key.
   118  	privKey, err := NewPrivateKeyFromBytes(privBytes)
   119  	if err != nil {
   120  		return nil, err
   121  	}
   122  
   123  	if !compareAddressHash(privKey, addrHash) {
   124  		return nil, errors.New("password mismatch")
   125  	}
   126  
   127  	return privKey, nil
   128  }
   129  
   130  func compareAddressHash(priv *PrivateKey, inhash []byte) bool {
   131  	address := priv.Address()
   132  	addrHash := hash.Checksum([]byte(address))
   133  	return bytes.Equal(addrHash, inhash)
   134  }
   135  
   136  func validateNEP2Format(b []byte) error {
   137  	if len(b) != 39 {
   138  		return fmt.Errorf("invalid length: expecting 39 got %d", len(b))
   139  	}
   140  	if b[0] != 0x01 {
   141  		return fmt.Errorf("invalid byte sequence: expecting 0x01 got 0x%02x", b[0])
   142  	}
   143  	if b[1] != 0x42 {
   144  		return fmt.Errorf("invalid byte sequence: expecting 0x42 got 0x%02x", b[1])
   145  	}
   146  	if b[2] != 0xe0 {
   147  		return fmt.Errorf("invalid byte sequence: expecting 0xe0 got 0x%02x", b[2])
   148  	}
   149  	return nil
   150  }
   151  
   152  func xor(a, b []byte) []byte {
   153  	if len(a) != len(b) {
   154  		panic("cannot XOR non equal length arrays")
   155  	}
   156  	dst := make([]byte, len(a))
   157  	for i := 0; i < len(dst); i++ {
   158  		dst[i] = a[i] ^ b[i]
   159  	}
   160  	return dst
   161  }