git.gammaspectra.live/P2Pool/consensus/v3@v3.8.0/monero/crypto/crypto_test.go (about)

     1  package crypto
     2  
     3  import (
     4  	"git.gammaspectra.live/P2Pool/consensus/v3/types"
     5  	fasthex "github.com/tmthrgd/go-hex"
     6  	"os"
     7  	"path"
     8  	"runtime"
     9  	"strconv"
    10  	"strings"
    11  	"testing"
    12  )
    13  
    14  func init() {
    15  	_, filename, _, _ := runtime.Caller(0)
    16  	// The ".." may change depending on you folder structure
    17  	dir := path.Join(path.Dir(filename), "../..")
    18  	err := os.Chdir(dir)
    19  	if err != nil {
    20  		panic(err)
    21  	}
    22  
    23  }
    24  
    25  func GetTestEntries(name string, n int) chan []string {
    26  	buf, err := os.ReadFile("testdata/crypto_tests.txt")
    27  	if err != nil {
    28  		return nil
    29  	}
    30  	result := make(chan []string)
    31  	go func() {
    32  		defer close(result)
    33  		for _, line := range strings.Split(string(buf), "\n") {
    34  			entries := strings.Split(strings.TrimSpace(line), " ")
    35  			if entries[0] == name && len(entries) >= (n+1) {
    36  				result <- entries[1:]
    37  			}
    38  		}
    39  	}()
    40  	return result
    41  }
    42  
    43  func TestGenerateKeyDerivation(t *testing.T) {
    44  	results := GetTestEntries("generate_key_derivation", 3)
    45  	if results == nil {
    46  		t.Fatal()
    47  	}
    48  	for e := range results {
    49  		var expectedDerivation types.Hash
    50  
    51  		key1 := PublicKeyBytes(types.MustHashFromString(e[0]))
    52  		key2 := PrivateKeyBytes(types.MustHashFromString(e[1]))
    53  		result := e[2] == "true"
    54  		if result {
    55  			expectedDerivation = types.MustHashFromString(e[3])
    56  		}
    57  
    58  		point := key1.AsPoint()
    59  		scalar := key2.AsScalar()
    60  
    61  		if result == false && (point == nil || scalar == nil) {
    62  			//expected failure
    63  			continue
    64  		} else if point == nil || scalar == nil {
    65  			t.Fatalf("invalid point %s / scalar %s", key1.String(), key2.String())
    66  		}
    67  
    68  		derivation := scalar.GetDerivationCofactor(point)
    69  		if result {
    70  			if expectedDerivation.String() != derivation.String() {
    71  				t.Fatalf("expected %s, got %s", expectedDerivation.String(), derivation.String())
    72  			}
    73  		}
    74  	}
    75  }
    76  
    77  func TestDeriveViewTag(t *testing.T) {
    78  	results := GetTestEntries("derive_view_tag", 3)
    79  	if results == nil {
    80  		t.Fatal()
    81  	}
    82  
    83  	hasher := GetKeccak256Hasher()
    84  	defer PutKeccak256Hasher(hasher)
    85  
    86  	for e := range results {
    87  		derivation := PublicKeyBytes(types.MustHashFromString(e[0]))
    88  		outputIndex, _ := strconv.ParseUint(e[1], 10, 0)
    89  		result, _ := fasthex.DecodeString(e[2])
    90  
    91  		viewTag := GetDerivationViewTagForOutputIndex(&derivation, outputIndex)
    92  
    93  		_, viewTag2 := GetDerivationSharedDataAndViewTagForOutputIndexNoAllocate(derivation, outputIndex, hasher)
    94  
    95  		if viewTag != viewTag2 {
    96  			t.Errorf("derive_view_tag differs from no_allocate: %d != %d", viewTag, &viewTag2)
    97  		}
    98  
    99  		if result[0] != viewTag {
   100  			t.Errorf("expected %s, got %s", fasthex.EncodeToString(result), fasthex.EncodeToString([]byte{viewTag}))
   101  		}
   102  	}
   103  }