github.com/nats-io/nsc@v0.0.0-20221206222106-35db9400b257/cmd/tool.go (about)

     1  /*
     2   * Copyright 2018-2019 The NATS Authors
     3   * Licensed under the Apache License, Version 2.0 (the "License");
     4   * you may not use this file except in compliance with the License.
     5   * You may obtain a copy of the License at
     6   *
     7   * http://www.apache.org/licenses/LICENSE-2.0
     8   *
     9   * Unless required by applicable law or agreed to in writing, software
    10   * distributed under the License is distributed on an "AS IS" BASIS,
    11   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12   * See the License for the specific language governing permissions and
    13   * limitations under the License.
    14   */
    15  
    16  package cmd
    17  
    18  import (
    19  	"bytes"
    20  	"crypto/aes"
    21  	"crypto/cipher"
    22  	"crypto/sha256"
    23  	"encoding/base64"
    24  	"errors"
    25  	"fmt"
    26  	"io"
    27  	"strings"
    28  	"time"
    29  
    30  	nats "github.com/nats-io/nats.go"
    31  	"github.com/spf13/cobra"
    32  )
    33  
    34  var toolCmd = &cobra.Command{
    35  	Use:   "tool",
    36  	Short: "NATS tools: pub, sub, req, rep, rtt",
    37  }
    38  
    39  var natsURLFlag = ""
    40  var encryptFlag bool
    41  
    42  func init() {
    43  	toolCmd.PersistentFlags().StringVarP(&natsURLFlag, "nats", "", "", "nats url, defaults to the operator's service URLs")
    44  	GetRootCmd().AddCommand(toolCmd)
    45  }
    46  
    47  func createDefaultToolOptions(name string, ctx ActionCtx, o ...nats.Option) []nats.Option {
    48  	connectTimeout := 5 * time.Second
    49  	totalWait := 10 * time.Minute
    50  	reconnectDelay := 2 * time.Second
    51  
    52  	opts := []nats.Option{nats.Name(name)}
    53  	opts = append(opts, nats.Timeout(connectTimeout))
    54  	opts = append(opts, rootCAsNats)
    55  	opts = append(opts, tlsKeyNats)
    56  	opts = append(opts, tlsCertNats)
    57  	opts = append(opts, nats.ReconnectWait(reconnectDelay))
    58  	opts = append(opts, nats.MaxReconnects(int(totalWait/reconnectDelay)))
    59  	opts = append(opts, nats.DisconnectErrHandler(func(nc *nats.Conn, err error) {
    60  		if err != nil {
    61  			ctx.CurrentCmd().Printf("Disconnected: error: %v\n", err)
    62  		}
    63  		if nc.Status() == nats.CLOSED {
    64  			return
    65  		}
    66  		ctx.CurrentCmd().Printf("Disconnected: will attempt reconnects for %.0fm", totalWait.Minutes())
    67  	}))
    68  	opts = append(opts, nats.ReconnectHandler(func(nc *nats.Conn) {
    69  		ctx.CurrentCmd().Printf("Reconnected [%s]", nc.ConnectedUrl())
    70  	}))
    71  	opts = append(opts, nats.ClosedHandler(func(nc *nats.Conn) {
    72  		if nc.Status() == nats.CLOSED {
    73  			return
    74  		}
    75  		ctx.CurrentCmd().Printf("Exiting, no servers available, or connection closed")
    76  	}))
    77  	opts = append(opts, o...)
    78  	return opts
    79  }
    80  
    81  func createCypher(pk string) (cipher.AEAD, error) {
    82  	// hash the provided private nkey into 32 bytes
    83  	hash := sha256.Sum256([]byte(pk))
    84  	c, err := aes.NewCipher(hash[:32])
    85  	if err != nil {
    86  		return nil, fmt.Errorf("unable to generate cypher: %v", err)
    87  	}
    88  
    89  	// create the symmetric key cipher
    90  	return cipher.NewGCM(c)
    91  }
    92  
    93  func EncryptKV(pk string, data []byte) ([]byte, error) {
    94  	// source data is <k><space><v>
    95  	i := bytes.IndexByte(data, ' ')
    96  	if i == -1 {
    97  		k, err := Encrypt(pk, data)
    98  		if err != nil {
    99  			return nil, err
   100  		}
   101  		return k, nil
   102  	}
   103  	// kv pair
   104  	k, err := Encrypt(pk, data[:i])
   105  	if err != nil {
   106  		return nil, err
   107  	}
   108  	v, err := Encrypt(pk, data[i+1:])
   109  	if err != nil {
   110  		return nil, err
   111  	}
   112  	return bytes.Join([][]byte{k, v}, []byte(" ")), nil
   113  }
   114  
   115  func Encrypt(pk string, data []byte) ([]byte, error) {
   116  	g, err := createCypher(pk)
   117  	if err != nil {
   118  		return nil, err
   119  	}
   120  	// creates a byte array the size of the nonce required
   121  	nonce := make([]byte, g.NonceSize())
   122  	// seed the nonce with the same seed so that we have predictable encryption
   123  	if _, err = io.ReadFull(strings.NewReader(pk), nonce); err != nil {
   124  		return nil, fmt.Errorf("error generating random sequence: %v", err)
   125  	}
   126  	// encrypt the data
   127  	raw := g.Seal(nonce, nonce, data, nil)
   128  
   129  	// encode the data
   130  	var codec = base64.StdEncoding.WithPadding(base64.NoPadding)
   131  	buf := make([]byte, codec.EncodedLen(len(raw)))
   132  	codec.Encode(buf, raw)
   133  	return buf[:], nil
   134  }
   135  
   136  func Decrypt(pk string, data []byte) ([]byte, error) {
   137  	// response payloads may be encrypted or may be lists of values separated by a space
   138  	if bytes.IndexByte(data, ' ') != -1 {
   139  		var decoded [][]byte
   140  		for _, a := range bytes.Split(data, []byte(" ")) {
   141  			d, err := decrypt(pk, a)
   142  			if err != nil {
   143  				return nil, err
   144  			}
   145  			decoded = append(decoded, d)
   146  		}
   147  		return bytes.Join(decoded, []byte(" ")), nil
   148  	} else {
   149  		return decrypt(pk, data)
   150  	}
   151  }
   152  
   153  func decrypt(pk string, data []byte) ([]byte, error) {
   154  	var codec = base64.StdEncoding.WithPadding(base64.NoPadding)
   155  	raw := make([]byte, codec.DecodedLen(len(data)))
   156  	n, err := codec.Decode(raw, data)
   157  	if err != nil {
   158  		if _, ok := err.(base64.CorruptInputError); ok {
   159  			// possibly not encrypted - so just return what we got
   160  			return data, nil
   161  		}
   162  		return nil, err
   163  	}
   164  	raw = raw[:n]
   165  
   166  	g, err := createCypher(pk)
   167  	if err != nil {
   168  		return nil, err
   169  	}
   170  	nonceLen := g.NonceSize()
   171  	if nonceLen > len(raw) {
   172  		return nil, errors.New("unexpected data length")
   173  	}
   174  	nonce, cypher := raw[:nonceLen], raw[nonceLen:]
   175  	return g.Open(nil, nonce, cypher, nil)
   176  }