github.com/blend/go-sdk@v1.20220411.3/shamir/shamir.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package shamir
     9  
    10  import (
    11  	"crypto/rand"
    12  	"crypto/subtle"
    13  	"fmt"
    14  	mathrand "math/rand"
    15  	"time"
    16  
    17  	"github.com/blend/go-sdk/ex"
    18  )
    19  
    20  const (
    21  	// ShareOverhead is the byte size overhead of each share
    22  	// when using Split on a secret. This is caused by appending
    23  	// a one byte tag to the share.
    24  	ShareOverhead = 1
    25  )
    26  
    27  // Split takes an arbitrarily long secret and generates a `parts`
    28  // number of shares, `threshold` of which are required to reconstruct
    29  // the secret. The parts and threshold must be at least 2, and less
    30  // than 256. The returned shares are each one byte longer than the secret
    31  // as they attach a tag used to reconstruct the secret.
    32  func Split(secret []byte, parts, threshold int) ([][]byte, error) {
    33  	// Sanity check the input
    34  	if parts < threshold {
    35  		return nil, fmt.Errorf("parts cannot be less than threshold")
    36  	}
    37  	if parts > 255 {
    38  		return nil, fmt.Errorf("parts cannot exceed 255")
    39  	}
    40  	if threshold < 2 {
    41  		return nil, fmt.Errorf("threshold must be at least 2")
    42  	}
    43  	if threshold > 255 {
    44  		return nil, fmt.Errorf("threshold cannot exceed 255")
    45  	}
    46  	if len(secret) == 0 {
    47  		return nil, fmt.Errorf("cannot split an empty secret")
    48  	}
    49  
    50  	// Generate random list of x coordinates
    51  	mathrand.Seed(time.Now().UnixNano())
    52  	xCoordinates := mathrand.Perm(255)
    53  
    54  	// Allocate the output array, initialize the final byte
    55  	// of the output with the offset. The representation of each
    56  	// output is {y1, y2, .., yN, x}.
    57  	out := make([][]byte, parts)
    58  	for idx := range out {
    59  		out[idx] = make([]byte, len(secret)+1)
    60  		out[idx][len(secret)] = uint8(xCoordinates[idx]) + 1
    61  	}
    62  
    63  	// Construct a random polynomial for each byte of the secret.
    64  	// Because we are using a field of size 256, we can only represent
    65  	// a single byte as the intercept of the polynomial, so we must
    66  	// use a new polynomial for each byte.
    67  	for idx, val := range secret {
    68  		p, err := makePolynomial(val, uint8(threshold-1))
    69  		if err != nil {
    70  			return nil, ex.New("failed to generate polynomial", ex.OptInner(err))
    71  		}
    72  
    73  		// Generate a `parts` number of (x,y) pairs
    74  		// We cheat by encoding the x value once as the final index,
    75  		// so that it only needs to be stored once.
    76  		for i := 0; i < parts; i++ {
    77  			x := uint8(xCoordinates[i]) + 1
    78  			y := p.evaluate(x)
    79  			out[i][idx] = y
    80  		}
    81  	}
    82  
    83  	// Return the encoded secrets
    84  	return out, nil
    85  }
    86  
    87  // Combine is used to reverse a Split and reconstruct a secret
    88  // once a `threshold` number of parts are available.
    89  func Combine(parts [][]byte) ([]byte, error) {
    90  	// Verify enough parts provided
    91  	if len(parts) < 2 {
    92  		return nil, fmt.Errorf("less than two parts cannot be used to reconstruct the secret")
    93  	}
    94  
    95  	// Verify the parts are all the same length
    96  	firstPartLen := len(parts[0])
    97  	if firstPartLen < 2 {
    98  		return nil, fmt.Errorf("parts must be at least two bytes")
    99  	}
   100  	for i := 1; i < len(parts); i++ {
   101  		if len(parts[i]) != firstPartLen {
   102  			return nil, fmt.Errorf("all parts must be the same length")
   103  		}
   104  	}
   105  
   106  	// Create a buffer to store the reconstructed secret
   107  	secret := make([]byte, firstPartLen-1)
   108  
   109  	// Buffer to store the samples
   110  	xSamples := make([]uint8, len(parts))
   111  	ySamples := make([]uint8, len(parts))
   112  
   113  	// Set the x value for each sample and ensure no x_sample values are the same,
   114  	// otherwise div() can be unhappy
   115  	checkMap := map[byte]bool{}
   116  	for i, part := range parts {
   117  		samp := part[firstPartLen-1]
   118  		if exists := checkMap[samp]; exists {
   119  			return nil, fmt.Errorf("duplicate part detected")
   120  		}
   121  		checkMap[samp] = true
   122  		xSamples[i] = samp
   123  	}
   124  
   125  	// Reconstruct each byte
   126  	for idx := range secret {
   127  		// Set the y value for each sample
   128  		for i, part := range parts {
   129  			ySamples[i] = part[idx]
   130  		}
   131  
   132  		// Interpolate the polynomial and compute the value at 0
   133  		val := interpolatePolynomial(xSamples, ySamples, 0)
   134  
   135  		// Evaluate the 0th value to get the intercept
   136  		secret[idx] = val
   137  	}
   138  	return secret, nil
   139  }
   140  
   141  // polynomial represents a polynomial of arbitrary degree
   142  type polynomial struct {
   143  	coefficients []uint8
   144  }
   145  
   146  // makePolynomial constructs a random polynomial of the given
   147  // degree but with the provided intercept value.
   148  func makePolynomial(intercept, degree uint8) (polynomial, error) {
   149  	// Create a wrapper
   150  	p := polynomial{
   151  		coefficients: make([]byte, degree+1),
   152  	}
   153  
   154  	// Ensure the intercept is set
   155  	p.coefficients[0] = intercept
   156  
   157  	// Assign random co-efficients to the polynomial
   158  	if _, err := rand.Read(p.coefficients[1:]); err != nil {
   159  		return p, err
   160  	}
   161  
   162  	return p, nil
   163  }
   164  
   165  // evaluate returns the value of the polynomial for the given x
   166  func (p *polynomial) evaluate(x uint8) uint8 {
   167  	// Special case the origin
   168  	if x == 0 {
   169  		return p.coefficients[0]
   170  	}
   171  
   172  	// Compute the polynomial value using Horner's method.
   173  	degree := len(p.coefficients) - 1
   174  	out := p.coefficients[degree]
   175  	for i := degree - 1; i >= 0; i-- {
   176  		coeff := p.coefficients[i]
   177  		out = add(mult(out, x), coeff)
   178  	}
   179  	return out
   180  }
   181  
   182  // interpolatePolynomial takes N sample points and returns
   183  // the value at a given x using a lagrange interpolation.
   184  func interpolatePolynomial(xSamples, ySamples []uint8, x uint8) uint8 {
   185  	limit := len(xSamples)
   186  	var result, basis uint8
   187  	for i := 0; i < limit; i++ {
   188  		basis = 1
   189  		for j := 0; j < limit; j++ {
   190  			if i == j {
   191  				continue
   192  			}
   193  			num := add(x, xSamples[j])
   194  			denom := add(xSamples[i], xSamples[j])
   195  			term := div(num, denom)
   196  			basis = mult(basis, term)
   197  		}
   198  		group := mult(ySamples[i], basis)
   199  		result = add(result, group)
   200  	}
   201  	return result
   202  }
   203  
   204  // div divides two numbers in GF(2^8)
   205  func div(a, b uint8) uint8 {
   206  	if b == 0 {
   207  		// leaks some timing information but we don't care anyways as this
   208  		// should never happen, hence the panic
   209  		panic("divide by zero")
   210  	}
   211  
   212  	var goodVal, zero uint8
   213  	logA := logTable[a]
   214  	logB := logTable[b]
   215  	diff := (int(logA) - int(logB)) % 255
   216  	if diff < 0 {
   217  		diff += 255
   218  	}
   219  
   220  	ret := expTable[diff]
   221  
   222  	// Ensure we return zero if a is zero but aren't subject to timing attacks
   223  	goodVal = ret
   224  
   225  	if subtle.ConstantTimeByteEq(a, 0) == 1 {
   226  		ret = zero
   227  	} else {
   228  		ret = goodVal
   229  	}
   230  
   231  	return ret
   232  }
   233  
   234  // mult multiplies two numbers in GF(2^8)
   235  func mult(a, b uint8) (out uint8) {
   236  	var goodVal, zero uint8
   237  	logA := logTable[a]
   238  	logB := logTable[b]
   239  	sum := (int(logA) + int(logB)) % 255
   240  
   241  	ret := expTable[sum]
   242  
   243  	// Ensure we return zero if either a or b are zero but aren't subject to
   244  	// timing attacks
   245  	goodVal = ret
   246  
   247  	if subtle.ConstantTimeByteEq(a, 0) == 1 {
   248  		ret = zero
   249  	} else {
   250  		ret = goodVal
   251  	}
   252  
   253  	if subtle.ConstantTimeByteEq(b, 0) == 1 {
   254  		ret = zero
   255  	} else {
   256  		// This operation does not do anything logically useful. It
   257  		// only ensures a constant number of assignments to thwart
   258  		// timing attacks.
   259  
   260  		//nolint:ineffassign
   261  		goodVal = zero
   262  	}
   263  
   264  	return ret
   265  }
   266  
   267  // add combines two numbers in GF(2^8)
   268  // This can also be used for subtraction since it is symmetric.
   269  func add(a, b uint8) uint8 {
   270  	return a ^ b
   271  }