github.com/blend/go-sdk@v1.20220411.3/cmd/shamir/main.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package main
     9  
    10  import (
    11  	"encoding/hex"
    12  	"fmt"
    13  	"io"
    14  	"os"
    15  	"strings"
    16  
    17  	"github.com/spf13/cobra"
    18  
    19  	"github.com/blend/go-sdk/shamir"
    20  	"github.com/blend/go-sdk/stringutil"
    21  )
    22  
    23  func main() {
    24  	root := &cobra.Command{
    25  		Use:   "shamir",
    26  		Short: "shamir splits and combines secrets into a configurable number of parts",
    27  	}
    28  	root.AddCommand(NewSplitCommand(), NewCombineCommand())
    29  	if err := root.Execute(); err != nil {
    30  		fmt.Fprintln(os.Stderr, err)
    31  		os.Exit(1)
    32  	}
    33  }
    34  
    35  // NewSplitCommand returns a new split command.
    36  func NewSplitCommand() *cobra.Command {
    37  	var parts *int
    38  	var threshold *int
    39  	var inputFile *string
    40  	var secret *string
    41  	split := &cobra.Command{
    42  		Use:   "split",
    43  		Short: "split takes a given input from stdin or a file and separates it into a configurable number of sections",
    44  		Run: func(cmd *cobra.Command, args []string) {
    45  			var contents []byte
    46  			var err error
    47  			if len(*secret) > 0 {
    48  				contents = []byte(strings.TrimSpace(*secret))
    49  			} else if strings.TrimSpace(*inputFile) == "-" {
    50  				contents, err = io.ReadAll(os.Stdin)
    51  			} else {
    52  				contents, err = os.ReadFile(strings.TrimSpace(*inputFile))
    53  			}
    54  			if err != nil {
    55  				fmt.Fprintln(os.Stderr, err)
    56  				os.Exit(1)
    57  			}
    58  			if len(contents) == 0 {
    59  				fmt.Fprintln(os.Stderr, fmt.Errorf("invalid input; is empty"))
    60  				os.Exit(1)
    61  			}
    62  
    63  			shares, err := shamir.Split(contents, *parts, *threshold)
    64  			if err != nil {
    65  				fmt.Fprintln(os.Stderr, err)
    66  				os.Exit(1)
    67  			}
    68  
    69  			for _, share := range shares {
    70  				fmt.Println(hex.EncodeToString(share))
    71  			}
    72  		},
    73  	}
    74  	secret = split.Flags().StringP("secret", "s", "", "the input secret")
    75  	inputFile = split.Flags().StringP("file", "f", "-", "the input file ('-' instructs to read from stdin)")
    76  	parts = split.Flags().IntP("parts", "p", 5, "the number of parts to split the secret into")
    77  	threshold = split.Flags().IntP("threshold", "t", 2, "the number of parts required to form the original secret")
    78  	return split
    79  }
    80  
    81  // NewCombineCommand returns a new combine command.
    82  func NewCombineCommand() *cobra.Command {
    83  	var input *string
    84  	var parts *[]string
    85  	combine := &cobra.Command{
    86  		Use:   "combine",
    87  		Short: "combine takes a shard share and combines it into the final output",
    88  		Run: func(cmd *cobra.Command, args []string) {
    89  			var inputParts [][]byte
    90  			var inputPartsEncoded []string
    91  			if len(*parts) > 0 {
    92  				inputPartsEncoded = *parts
    93  			} else {
    94  				var contents []byte
    95  				var err error
    96  				if strings.TrimSpace(*input) == "-" {
    97  					contents, err = io.ReadAll(os.Stdin)
    98  				} else {
    99  					contents, err = os.ReadFile(strings.TrimSpace(*input))
   100  				}
   101  				if err != nil {
   102  					fmt.Fprintln(os.Stderr, err)
   103  					os.Exit(1)
   104  				}
   105  				inputPartsEncoded = stringutil.SplitLines(string(contents))
   106  			}
   107  
   108  			for _, part := range inputPartsEncoded {
   109  				decoded, err := hex.DecodeString(strings.TrimSpace(part))
   110  				if err != nil {
   111  					fmt.Fprintln(os.Stderr, err)
   112  					os.Exit(1)
   113  				}
   114  				inputParts = append(inputParts, decoded)
   115  			}
   116  
   117  			original, err := shamir.Combine(inputParts)
   118  			if err != nil {
   119  				fmt.Fprintln(os.Stderr, err)
   120  				os.Exit(1)
   121  			}
   122  			fmt.Println(string(original))
   123  		},
   124  	}
   125  	input = combine.Flags().StringP("input", "i", "-", "the input file ('-' instructs to read from stdin)")
   126  	parts = combine.Flags().StringArrayP("part", "p", nil, "individual parts to combine (must include the threshold amount")
   127  	return combine
   128  }