github.com/nats-io/nsc@v0.0.0-20221206222106-35db9400b257/cmd/tool.go (about) 1 /* 2 * Copyright 2018-2019 The NATS Authors 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15 16 package cmd 17 18 import ( 19 "bytes" 20 "crypto/aes" 21 "crypto/cipher" 22 "crypto/sha256" 23 "encoding/base64" 24 "errors" 25 "fmt" 26 "io" 27 "strings" 28 "time" 29 30 nats "github.com/nats-io/nats.go" 31 "github.com/spf13/cobra" 32 ) 33 34 var toolCmd = &cobra.Command{ 35 Use: "tool", 36 Short: "NATS tools: pub, sub, req, rep, rtt", 37 } 38 39 var natsURLFlag = "" 40 var encryptFlag bool 41 42 func init() { 43 toolCmd.PersistentFlags().StringVarP(&natsURLFlag, "nats", "", "", "nats url, defaults to the operator's service URLs") 44 GetRootCmd().AddCommand(toolCmd) 45 } 46 47 func createDefaultToolOptions(name string, ctx ActionCtx, o ...nats.Option) []nats.Option { 48 connectTimeout := 5 * time.Second 49 totalWait := 10 * time.Minute 50 reconnectDelay := 2 * time.Second 51 52 opts := []nats.Option{nats.Name(name)} 53 opts = append(opts, nats.Timeout(connectTimeout)) 54 opts = append(opts, rootCAsNats) 55 opts = append(opts, tlsKeyNats) 56 opts = append(opts, tlsCertNats) 57 opts = append(opts, nats.ReconnectWait(reconnectDelay)) 58 opts = append(opts, nats.MaxReconnects(int(totalWait/reconnectDelay))) 59 opts = append(opts, nats.DisconnectErrHandler(func(nc *nats.Conn, err error) { 60 if err != nil { 61 ctx.CurrentCmd().Printf("Disconnected: error: %v\n", err) 62 } 63 if nc.Status() == nats.CLOSED { 64 return 65 } 66 ctx.CurrentCmd().Printf("Disconnected: will attempt reconnects for %.0fm", totalWait.Minutes()) 67 })) 68 opts = append(opts, nats.ReconnectHandler(func(nc *nats.Conn) { 69 ctx.CurrentCmd().Printf("Reconnected [%s]", nc.ConnectedUrl()) 70 })) 71 opts = append(opts, nats.ClosedHandler(func(nc *nats.Conn) { 72 if nc.Status() == nats.CLOSED { 73 return 74 } 75 ctx.CurrentCmd().Printf("Exiting, no servers available, or connection closed") 76 })) 77 opts = append(opts, o...) 78 return opts 79 } 80 81 func createCypher(pk string) (cipher.AEAD, error) { 82 // hash the provided private nkey into 32 bytes 83 hash := sha256.Sum256([]byte(pk)) 84 c, err := aes.NewCipher(hash[:32]) 85 if err != nil { 86 return nil, fmt.Errorf("unable to generate cypher: %v", err) 87 } 88 89 // create the symmetric key cipher 90 return cipher.NewGCM(c) 91 } 92 93 func EncryptKV(pk string, data []byte) ([]byte, error) { 94 // source data is <k><space><v> 95 i := bytes.IndexByte(data, ' ') 96 if i == -1 { 97 k, err := Encrypt(pk, data) 98 if err != nil { 99 return nil, err 100 } 101 return k, nil 102 } 103 // kv pair 104 k, err := Encrypt(pk, data[:i]) 105 if err != nil { 106 return nil, err 107 } 108 v, err := Encrypt(pk, data[i+1:]) 109 if err != nil { 110 return nil, err 111 } 112 return bytes.Join([][]byte{k, v}, []byte(" ")), nil 113 } 114 115 func Encrypt(pk string, data []byte) ([]byte, error) { 116 g, err := createCypher(pk) 117 if err != nil { 118 return nil, err 119 } 120 // creates a byte array the size of the nonce required 121 nonce := make([]byte, g.NonceSize()) 122 // seed the nonce with the same seed so that we have predictable encryption 123 if _, err = io.ReadFull(strings.NewReader(pk), nonce); err != nil { 124 return nil, fmt.Errorf("error generating random sequence: %v", err) 125 } 126 // encrypt the data 127 raw := g.Seal(nonce, nonce, data, nil) 128 129 // encode the data 130 var codec = base64.StdEncoding.WithPadding(base64.NoPadding) 131 buf := make([]byte, codec.EncodedLen(len(raw))) 132 codec.Encode(buf, raw) 133 return buf[:], nil 134 } 135 136 func Decrypt(pk string, data []byte) ([]byte, error) { 137 // response payloads may be encrypted or may be lists of values separated by a space 138 if bytes.IndexByte(data, ' ') != -1 { 139 var decoded [][]byte 140 for _, a := range bytes.Split(data, []byte(" ")) { 141 d, err := decrypt(pk, a) 142 if err != nil { 143 return nil, err 144 } 145 decoded = append(decoded, d) 146 } 147 return bytes.Join(decoded, []byte(" ")), nil 148 } else { 149 return decrypt(pk, data) 150 } 151 } 152 153 func decrypt(pk string, data []byte) ([]byte, error) { 154 var codec = base64.StdEncoding.WithPadding(base64.NoPadding) 155 raw := make([]byte, codec.DecodedLen(len(data))) 156 n, err := codec.Decode(raw, data) 157 if err != nil { 158 if _, ok := err.(base64.CorruptInputError); ok { 159 // possibly not encrypted - so just return what we got 160 return data, nil 161 } 162 return nil, err 163 } 164 raw = raw[:n] 165 166 g, err := createCypher(pk) 167 if err != nil { 168 return nil, err 169 } 170 nonceLen := g.NonceSize() 171 if nonceLen > len(raw) { 172 return nil, errors.New("unexpected data length") 173 } 174 nonce, cypher := raw[:nonceLen], raw[nonceLen:] 175 return g.Open(nil, nonce, cypher, nil) 176 }