go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/cmd/shamir/main.go (about)

     1  /*
     2  
     3  Copyright (c) 2023 - Present. Will Charczuk. All rights reserved.
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file at the root of the repository.
     5  
     6  */
     7  
     8  package main
     9  
    10  import (
    11  	"encoding/hex"
    12  	"fmt"
    13  	"io"
    14  	"os"
    15  	"strings"
    16  
    17  	"github.com/urfave/cli/v2"
    18  
    19  	"go.charczuk.com/sdk/cliutil"
    20  	"go.charczuk.com/sdk/shamir"
    21  	"go.charczuk.com/sdk/stringutil"
    22  )
    23  
    24  func main() {
    25  	root := &cli.App{
    26  		Name:  "shamir",
    27  		Usage: "shamir splits and combines secrets into a configurable number of parts",
    28  		Commands: []*cli.Command{
    29  			split,
    30  			combine,
    31  		},
    32  	}
    33  	if err := root.Run(os.Args); err != nil {
    34  		cliutil.Fatal(err)
    35  	}
    36  }
    37  
    38  var split = &cli.Command{
    39  	Name:  "split",
    40  	Usage: "split takes a given input from stdin or a file and separates it into a configurable number of sections",
    41  	Flags: []cli.Flag{
    42  		&cli.StringFlag{Name: "secret", Aliases: []string{"s"}, Usage: "the input secret"},
    43  		&cli.StringFlag{Name: "file", Aliases: []string{"f"}, Value: "-", Usage: "the input file ('-' instructs to read from stdin)"},
    44  		&cli.IntFlag{Name: "parts", Aliases: []string{"p"}, Value: 5, Usage: "the number of parts to split the secret into"},
    45  		&cli.IntFlag{Name: "threshold", Aliases: []string{"t"}, Value: 2, Usage: "the number of parts required to form the original secret"},
    46  	},
    47  	Action: func(ctx *cli.Context) error {
    48  		var secret = ctx.String("secret")
    49  		var threshold = ctx.Int("threshold")
    50  		var parts = ctx.Int("parts")
    51  		var inputFile = ctx.String("file")
    52  
    53  		var contents []byte
    54  		var err error
    55  		if len(secret) > 0 {
    56  			contents = []byte(strings.TrimSpace(secret))
    57  		} else if strings.TrimSpace(inputFile) == "-" {
    58  			contents, err = io.ReadAll(os.Stdin)
    59  		} else {
    60  			contents, err = os.ReadFile(strings.TrimSpace(inputFile))
    61  		}
    62  		if err != nil {
    63  			return err
    64  		}
    65  		if len(contents) == 0 {
    66  			return fmt.Errorf("invalid input; is empty")
    67  		}
    68  
    69  		shares, err := shamir.Split(contents, parts, threshold)
    70  		if err != nil {
    71  			return err
    72  		}
    73  		for _, share := range shares {
    74  			fmt.Fprintln(os.Stdout, hex.EncodeToString(share))
    75  		}
    76  		return nil
    77  	},
    78  }
    79  
    80  var combine = &cli.Command{
    81  	Name:  "combine",
    82  	Usage: "combine takes a shard share and combines it into the final output",
    83  	Flags: []cli.Flag{
    84  		&cli.StringFlag{Name: "input", Aliases: []string{"i"}, Value: "-", Usage: "the input file ('-' instructs to read from stdin)"},
    85  		&cli.StringSliceFlag{Name: "part", Aliases: []string{"p"}, Value: nil, Usage: "individual parts to combine (must include the threshold amount)"},
    86  	},
    87  	Action: func(ctx *cli.Context) error {
    88  		var input = ctx.String("input")
    89  		var parts = ctx.StringSlice("parts")
    90  
    91  		var inputParts [][]byte
    92  		var inputPartsEncoded []string
    93  		if len(parts) > 0 {
    94  			inputPartsEncoded = parts
    95  		} else {
    96  			var contents []byte
    97  			var err error
    98  			if strings.TrimSpace(input) == "-" {
    99  				contents, err = io.ReadAll(os.Stdin)
   100  			} else {
   101  				contents, err = os.ReadFile(strings.TrimSpace(input))
   102  			}
   103  			if err != nil {
   104  				return err
   105  			}
   106  			inputPartsEncoded = stringutil.SplitLines(string(contents))
   107  		}
   108  
   109  		for _, part := range inputPartsEncoded {
   110  			decoded, err := hex.DecodeString(strings.TrimSpace(part))
   111  			if err != nil {
   112  				return err
   113  			}
   114  			inputParts = append(inputParts, decoded)
   115  		}
   116  
   117  		original, err := shamir.Combine(inputParts)
   118  		if err != nil {
   119  			return err
   120  		}
   121  		fmt.Println(string(original))
   122  		return nil
   123  	},
   124  }