github.com/blend/go-sdk@v1.20220411.3/shamir/shamir_test.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package shamir
     9  
    10  import (
    11  	"bytes"
    12  	"fmt"
    13  	"testing"
    14  
    15  	"github.com/blend/go-sdk/assert"
    16  )
    17  
    18  func TestSplitInvalid(t *testing.T) {
    19  	assert := assert.New(t)
    20  
    21  	secret := []byte("test")
    22  
    23  	_, err := Split(secret, 0, 0)
    24  	assert.NotNil(err)
    25  	_, err = Split(secret, 2, 3)
    26  	assert.NotNil(err)
    27  	_, err = Split(secret, 2, 3)
    28  	assert.NotNil(err)
    29  	_, err = Split(secret, 1000, 3)
    30  	assert.NotNil(err)
    31  	_, err = Split(secret, 10, 1)
    32  	assert.NotNil(err)
    33  	_, err = Split(nil, 3, 2)
    34  	assert.NotNil(err)
    35  }
    36  
    37  func TestSplit(t *testing.T) {
    38  	assert := assert.New(t)
    39  
    40  	secret := []byte("test")
    41  
    42  	out, err := Split(secret, 5, 3)
    43  	assert.Nil(err)
    44  	assert.Len(out, 5)
    45  
    46  	for _, share := range out {
    47  		assert.Equal(len(share), len(secret)+1)
    48  	}
    49  }
    50  
    51  func TestCombineInvalid(t *testing.T) {
    52  	assert := assert.New(t)
    53  
    54  	_, err := Combine(nil)
    55  	assert.NotNil(err)
    56  
    57  	// Mis-match in length
    58  	parts := [][]byte{
    59  		[]byte("foo"),
    60  		[]byte("ba"),
    61  	}
    62  	_, err = Combine(parts)
    63  	assert.NotNil(err)
    64  
    65  	//Too short
    66  	parts = [][]byte{
    67  		[]byte("f"),
    68  		[]byte("b"),
    69  	}
    70  	_, err = Combine(parts)
    71  	assert.NotNil(err)
    72  
    73  	parts = [][]byte{
    74  		[]byte("foo"),
    75  		[]byte("foo"),
    76  	}
    77  	if _, err := Combine(parts); err == nil {
    78  		t.Fatalf("should err")
    79  	}
    80  }
    81  
    82  func TestCombine(t *testing.T) {
    83  	assert := assert.New(t)
    84  
    85  	secret := []byte("test")
    86  
    87  	out, err := Split(secret, 5, 3)
    88  	assert.Nil(err)
    89  
    90  	// There is 5*4*3 possible choices,
    91  	// we will just brute force try them all
    92  	for i := 0; i < 5; i++ {
    93  		for j := 0; j < 5; j++ {
    94  			if j == i {
    95  				continue
    96  			}
    97  			for k := 0; k < 5; k++ {
    98  				if k == i || k == j {
    99  					continue
   100  				}
   101  				parts := [][]byte{out[i], out[j], out[k]}
   102  				recomb, err := Combine(parts)
   103  				assert.Nil(err)
   104  				assert.True(bytes.Equal(recomb, secret), fmt.Sprintf("parts: (i:%d, j:%d, k:%d) %v", i, j, k, parts))
   105  			}
   106  		}
   107  	}
   108  }
   109  
   110  func TestFieldAdd(t *testing.T) {
   111  	if out := add(16, 16); out != 0 {
   112  		t.Fatalf("Bad: %v 16", out)
   113  	}
   114  
   115  	if out := add(3, 4); out != 7 {
   116  		t.Fatalf("Bad: %v 7", out)
   117  	}
   118  }
   119  
   120  func TestFieldMult(t *testing.T) {
   121  	if out := mult(3, 7); out != 9 {
   122  		t.Fatalf("Bad: %v 9", out)
   123  	}
   124  
   125  	if out := mult(3, 0); out != 0 {
   126  		t.Fatalf("Bad: %v 0", out)
   127  	}
   128  
   129  	if out := mult(0, 3); out != 0 {
   130  		t.Fatalf("Bad: %v 0", out)
   131  	}
   132  }
   133  
   134  func TestFieldDivide(t *testing.T) {
   135  	if out := div(0, 7); out != 0 {
   136  		t.Fatalf("Bad: %v 0", out)
   137  	}
   138  
   139  	if out := div(3, 3); out != 1 {
   140  		t.Fatalf("Bad: %v 1", out)
   141  	}
   142  
   143  	if out := div(6, 3); out != 2 {
   144  		t.Fatalf("Bad: %v 2", out)
   145  	}
   146  }
   147  
   148  func TestPolynomialRandom(t *testing.T) {
   149  	p, err := makePolynomial(42, 2)
   150  	if err != nil {
   151  		t.Fatalf("err: %v", err)
   152  	}
   153  
   154  	if p.coefficients[0] != 42 {
   155  		t.Fatalf("bad: %v", p.coefficients)
   156  	}
   157  }
   158  
   159  func TestPolynomialEval(t *testing.T) {
   160  	p, err := makePolynomial(42, 1)
   161  	if err != nil {
   162  		t.Fatalf("err: %v", err)
   163  	}
   164  
   165  	if out := p.evaluate(0); out != 42 {
   166  		t.Fatalf("bad: %v", out)
   167  	}
   168  
   169  	out := p.evaluate(1)
   170  	exp := add(42, mult(1, p.coefficients[1]))
   171  	if out != exp {
   172  		t.Fatalf("bad: %v %v %v", out, exp, p.coefficients)
   173  	}
   174  }
   175  
   176  func TestInterpolateRand(t *testing.T) {
   177  	for i := 0; i < 256; i++ {
   178  		p, err := makePolynomial(uint8(i), 2)
   179  		if err != nil {
   180  			t.Fatalf("err: %v", err)
   181  		}
   182  
   183  		xVals := []uint8{1, 2, 3}
   184  		yVals := []uint8{p.evaluate(1), p.evaluate(2), p.evaluate(3)}
   185  		out := interpolatePolynomial(xVals, yVals, 0)
   186  		if out != uint8(i) {
   187  			t.Fatalf("Bad: %v %d", out, i)
   188  		}
   189  	}
   190  }