github.com/kbehouse/nsc@v0.0.6/cmd/subtool.go (about)

     1  /*
     2   * Copyright 2018-2019 The NATS Authors
     3   * Licensed under the Apache License, Version 2.0 (the "License");
     4   * you may not use this file except in compliance with the License.
     5   * You may obtain a copy of the License at
     6   *
     7   * http://www.apache.org/licenses/LICENSE-2.0
     8   *
     9   * Unless required by applicable law or agreed to in writing, software
    10   * distributed under the License is distributed on an "AS IS" BASIS,
    11   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12   * See the License for the specific language governing permissions and
    13   * limitations under the License.
    14   */
    15  
    16  package cmd
    17  
    18  import (
    19  	"errors"
    20  	"fmt"
    21  	"os"
    22  	"strings"
    23  	"time"
    24  
    25  	"github.com/kbehouse/nsc/cmd/store"
    26  
    27  	nats "github.com/nats-io/nats.go"
    28  	"github.com/spf13/cobra"
    29  )
    30  
    31  func createSubCmd() *cobra.Command {
    32  	var params SubParams
    33  	var cmd = &cobra.Command{
    34  		Use:     "sub",
    35  		Short:   "Subscribe to a subject on a NATS account",
    36  		Example: "nsc tool sub <subject>\nnsc tool --queue <name> subject",
    37  		Args:    cobra.MinimumNArgs(1),
    38  		RunE: func(cmd *cobra.Command, args []string) error {
    39  			return RunAction(cmd, args, &params)
    40  		},
    41  	}
    42  	cmd.Flags().StringVarP(&params.queue, "queue", "q", "", "subscription queue name")
    43  	cmd.Flags().IntVarP(&params.maxMessages, "max-messages", "", -1, "max messages")
    44  	cmd.Flags().BoolVarP(&encryptFlag, "encrypt", "E", false, "encrypted payload")
    45  	cmd.Flags().MarkHidden("max-messages")
    46  	cmd.Flags().MarkHidden("decrypt")
    47  
    48  	params.BindFlags(cmd)
    49  	return cmd
    50  }
    51  
    52  func init() {
    53  	toolCmd.AddCommand(createSubCmd())
    54  	hidden := createSubCmd()
    55  	hidden.Hidden = true
    56  	hidden.Example = "nsc sub <subject>\nnsc --queue <name> subject"
    57  	GetRootCmd().AddCommand(hidden)
    58  }
    59  
// SubParams holds the inputs of the "sub" command: the embedded account/user
// context plus values filled in from flags (queue, maxMessages) or resolved
// in Load (credsPath, natsURLs).
type SubParams struct {
	AccountUserContextParams
	credsPath   string   // user creds file path, computed by Load from the account/user names
	natsURLs    []string // servers to connect to: natsURLFlag if set, else the operator's service URLs
	queue       string   // queue group name; empty means a plain (non-queue) subscription
	maxMessages int      // auto-unsubscribe after this many messages when > 0; flag default is -1 (no limit)
}
    67  
    68  func (p *SubParams) SetDefaults(ctx ActionCtx) error {
    69  	return p.AccountUserContextParams.SetDefaults(ctx)
    70  }
    71  
    72  func (p *SubParams) PreInteractive(ctx ActionCtx) error {
    73  	return p.AccountUserContextParams.Edit(ctx)
    74  }
    75  
    76  func (p *SubParams) Load(ctx ActionCtx) error {
    77  	p.credsPath = ctx.StoreCtx().KeyStore.CalcUserCredsPath(p.AccountContextParams.Name, p.UserContextParams.Name)
    78  	if natsURLFlag != "" {
    79  		p.natsURLs = []string{natsURLFlag}
    80  		return nil
    81  	}
    82  
    83  	oc, err := ctx.StoreCtx().Store.ReadOperatorClaim()
    84  	if err != nil {
    85  		return err
    86  	}
    87  	p.natsURLs = oc.OperatorServiceURLs
    88  	return nil
    89  }
    90  
    91  func (p *SubParams) PostInteractive(ctx ActionCtx) error {
    92  	return nil
    93  }
    94  
    95  func (p *SubParams) Validate(ctx ActionCtx) error {
    96  	if err := p.AccountUserContextParams.Validate(ctx); err != nil {
    97  		return err
    98  	}
    99  
   100  	if p.maxMessages == 0 {
   101  		return errors.New("max-messages must be greater than zero")
   102  	}
   103  
   104  	if p.credsPath == "" {
   105  		return fmt.Errorf("a creds file for account %q/%q was not found", p.AccountContextParams.Name, p.UserContextParams.Name)
   106  	}
   107  	_, err := os.Stat(p.credsPath)
   108  	if os.IsNotExist(err) {
   109  		return err
   110  	}
   111  	if len(p.natsURLs) == 0 {
   112  		return fmt.Errorf("operator %q doesn't have operator_service_urls set", ctx.StoreCtx().Operator.Name)
   113  	}
   114  	return nil
   115  }
   116  
   117  func (p *SubParams) Run(ctx ActionCtx) (store.Status, error) {
   118  	nc, err := nats.Connect(strings.Join(p.natsURLs, ", "),
   119  		createDefaultToolOptions("nsc_sub", ctx, nats.UserCredentials(p.credsPath))...)
   120  	if err != nil {
   121  		return nil, err
   122  	}
   123  	defer nc.Close()
   124  
   125  	subj := ctx.Args()[0]
   126  	// we are doing sync subs because we want the cli to cleanup properly
   127  	// when the command returns
   128  	var sub *nats.Subscription
   129  	if p.queue != "" {
   130  		sub, err = nc.QueueSubscribeSync(subj, p.queue)
   131  		if err != nil {
   132  			return nil, err
   133  		}
   134  	} else {
   135  		sub, err = nc.SubscribeSync(subj)
   136  		if err != nil {
   137  			return nil, err
   138  		}
   139  	}
   140  	if p.maxMessages > 0 {
   141  		if err := sub.AutoUnsubscribe(p.maxMessages); err != nil {
   142  			return nil, err
   143  		}
   144  		ctx.CurrentCmd().Printf("Listening on [%s] for %d messages\n", subj, p.maxMessages)
   145  	} else {
   146  		ctx.CurrentCmd().Printf("Listening on [%s]\n", subj)
   147  	}
   148  
   149  	if err := nc.Flush(); err != nil {
   150  		return nil, err
   151  	}
   152  
   153  	var seed string
   154  	if encryptFlag {
   155  		// cannot fail if we are here
   156  		seed, err = ctx.StoreCtx().KeyStore.GetSeed(ctx.StoreCtx().Account.PublicKey)
   157  		if err != nil {
   158  			return nil, fmt.Errorf("unable to get the account private key to encrypt/decrypt the payload: %v", err)
   159  		}
   160  	}
   161  
   162  	i := 0
   163  	for {
   164  		msg, err := sub.NextMsg(10 * time.Second)
   165  		if err == nats.ErrTimeout {
   166  			continue
   167  		}
   168  		if err == nats.ErrMaxMessages {
   169  			break
   170  		}
   171  		if err == nats.ErrConnectionClosed {
   172  			break
   173  		}
   174  		if err != nil {
   175  			return nil, err
   176  		}
   177  
   178  		i++
   179  		if encryptFlag {
   180  			msg = maybeDecryptMessage(seed, msg)
   181  		}
   182  		ctx.CurrentCmd().Printf("[#%d] received on [%s]: '%s'\n", i, msg.Subject, string(msg.Data))
   183  	}
   184  
   185  	return nil, nil
   186  }
   187  
   188  func maybeDecryptMessage(seed string, msg *nats.Msg) *nats.Msg {
   189  	var dmsg nats.Msg
   190  	// last part of the subject will be encrypted
   191  	tokens := strings.Split(msg.Subject, ".")
   192  	k := tokens[len(tokens)-1]
   193  	kk, err := Decrypt(seed, []byte(k))
   194  	if err != nil {
   195  		dmsg.Subject = msg.Subject
   196  	} else {
   197  		tokens[len(tokens)-1] = string(kk)
   198  		dmsg.Subject = strings.Join(tokens, ".")
   199  	}
   200  
   201  	dmsg.Data, err = Decrypt(seed, msg.Data)
   202  	if err != nil {
   203  		dmsg.Data = msg.Data
   204  	}
   205  	return &dmsg
   206  }