go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/sdk/shamir/shamir_test.go (about)

     1  /*
     2  
     3  Copyright (c) 2023 - Present. Will Charczuk. All rights reserved.
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file at the root of the repository.
     5  
     6  */
     7  
     8  package shamir
     9  
    10  import (
    11  	"bytes"
    12  	"fmt"
    13  	"testing"
    14  
    15  	"go.charczuk.com/sdk/assert"
    16  )
    17  
    18  func TestSplitInvalid(t *testing.T) {
    19  	secret := []byte("test")
    20  
    21  	_, err := Split(secret, 0, 0)
    22  	assert.ItsNotNil(t, err)
    23  	_, err = Split(secret, 2, 3)
    24  	assert.ItsNotNil(t, err)
    25  	_, err = Split(secret, 2, 3)
    26  	assert.ItsNotNil(t, err)
    27  	_, err = Split(secret, 1000, 3)
    28  	assert.ItsNotNil(t, err)
    29  	_, err = Split(secret, 10, 1)
    30  	assert.ItsNotNil(t, err)
    31  	_, err = Split(nil, 3, 2)
    32  	assert.ItsNotNil(t, err)
    33  }
    34  
    35  func TestSplit(t *testing.T) {
    36  	secret := []byte("test")
    37  
    38  	out, err := Split(secret, 5, 3)
    39  	assert.ItsNil(t, err)
    40  	assert.ItsLen(t, out, 5)
    41  
    42  	for _, share := range out {
    43  		assert.ItsEqual(t, len(share), len(secret)+1)
    44  	}
    45  }
    46  
    47  func TestCombineInvalid(t *testing.T) {
    48  	_, err := Combine(nil)
    49  	assert.ItsNotNil(t, err)
    50  
    51  	// Mis-match in length
    52  	parts := [][]byte{
    53  		[]byte("foo"),
    54  		[]byte("ba"),
    55  	}
    56  	_, err = Combine(parts)
    57  	assert.ItsNotNil(t, err)
    58  
    59  	//Too short
    60  	parts = [][]byte{
    61  		[]byte("f"),
    62  		[]byte("b"),
    63  	}
    64  	_, err = Combine(parts)
    65  	assert.ItsNotNil(t, err)
    66  
    67  	parts = [][]byte{
    68  		[]byte("foo"),
    69  		[]byte("foo"),
    70  	}
    71  	if _, err := Combine(parts); err == nil {
    72  		t.Fatalf("should err")
    73  	}
    74  }
    75  
    76  func TestCombine(t *testing.T) {
    77  	secret := []byte("test")
    78  
    79  	out, err := Split(secret, 5, 3)
    80  	assert.ItsNil(t, err)
    81  
    82  	// There is 5*4*3 possible choices,
    83  	// we will just brute force try them all
    84  	for i := 0; i < 5; i++ {
    85  		for j := 0; j < 5; j++ {
    86  			if j == i {
    87  				continue
    88  			}
    89  			for k := 0; k < 5; k++ {
    90  				if k == i || k == j {
    91  					continue
    92  				}
    93  				parts := [][]byte{out[i], out[j], out[k]}
    94  				recomb, err := Combine(parts)
    95  				assert.ItsNil(t, err)
    96  				assert.ItsTrue(t, bytes.Equal(recomb, secret), fmt.Sprintf("parts: (i:%d, j:%d, k:%d) %v", i, j, k, parts))
    97  			}
    98  		}
    99  	}
   100  }
   101  
   102  func TestFieldAdd(t *testing.T) {
   103  	if out := add(16, 16); out != 0 {
   104  		t.Fatalf("Bad: %v 16", out)
   105  	}
   106  
   107  	if out := add(3, 4); out != 7 {
   108  		t.Fatalf("Bad: %v 7", out)
   109  	}
   110  }
   111  
   112  func TestFieldMult(t *testing.T) {
   113  	if out := mult(3, 7); out != 9 {
   114  		t.Fatalf("Bad: %v 9", out)
   115  	}
   116  
   117  	if out := mult(3, 0); out != 0 {
   118  		t.Fatalf("Bad: %v 0", out)
   119  	}
   120  
   121  	if out := mult(0, 3); out != 0 {
   122  		t.Fatalf("Bad: %v 0", out)
   123  	}
   124  }
   125  
   126  func TestFieldDivide(t *testing.T) {
   127  	if out := div(0, 7); out != 0 {
   128  		t.Fatalf("Bad: %v 0", out)
   129  	}
   130  
   131  	if out := div(3, 3); out != 1 {
   132  		t.Fatalf("Bad: %v 1", out)
   133  	}
   134  
   135  	if out := div(6, 3); out != 2 {
   136  		t.Fatalf("Bad: %v 2", out)
   137  	}
   138  }
   139  
   140  func TestPolynomialRandom(t *testing.T) {
   141  	p, err := makePolynomial(42, 2)
   142  	if err != nil {
   143  		t.Fatalf("err: %v", err)
   144  	}
   145  
   146  	if p.coefficients[0] != 42 {
   147  		t.Fatalf("bad: %v", p.coefficients)
   148  	}
   149  }
   150  
   151  func TestPolynomialEval(t *testing.T) {
   152  	p, err := makePolynomial(42, 1)
   153  	if err != nil {
   154  		t.Fatalf("err: %v", err)
   155  	}
   156  
   157  	if out := p.evaluate(0); out != 42 {
   158  		t.Fatalf("bad: %v", out)
   159  	}
   160  
   161  	out := p.evaluate(1)
   162  	exp := add(42, mult(1, p.coefficients[1]))
   163  	if out != exp {
   164  		t.Fatalf("bad: %v %v %v", out, exp, p.coefficients)
   165  	}
   166  }
   167  
   168  func TestInterpolateRand(t *testing.T) {
   169  	for i := 0; i < 256; i++ {
   170  		p, err := makePolynomial(uint8(i), 2)
   171  		if err != nil {
   172  			t.Fatalf("err: %v", err)
   173  		}
   174  
   175  		xVals := []uint8{1, 2, 3}
   176  		yVals := []uint8{p.evaluate(1), p.evaluate(2), p.evaluate(3)}
   177  		out := interpolatePolynomial(xVals, yVals, 0)
   178  		if out != uint8(i) {
   179  			t.Fatalf("Bad: %v %d", out, i)
   180  		}
   181  	}
   182  }