github.com/nats-io/nsc/v2@v2.8.7-0.20240307184528-efd7023c6896/cmd/pull.go (about)

     1  /*
     2   * Copyright 2018-2021 The NATS Authors
     3   * Licensed under the Apache License, Version 2.0 (the "License");
     4   * you may not use this file except in compliance with the License.
     5   * You may obtain a copy of the License at
     6   *
     7   * http://www.apache.org/licenses/LICENSE-2.0
     8   *
     9   * Unless required by applicable law or agreed to in writing, software
    10   * distributed under the License is distributed on an "AS IS" BASIS,
    11   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12   * See the License for the specific language governing permissions and
    13   * limitations under the License.
    14   */
    15  
    16  package cmd
    17  
    18  import (
    19  	"errors"
    20  	"fmt"
    21  	"strings"
    22  	"sync"
    23  	"time"
    24  
    25  	cli "github.com/nats-io/cliprompts/v2"
    26  	"github.com/nats-io/jwt/v2"
    27  	"github.com/nats-io/nats.go"
    28  	"github.com/nats-io/nsc/v2/cmd/store"
    29  	"github.com/spf13/cobra"
    30  )
    31  
    32  func createPullCmd() *cobra.Command {
    33  	var params PullParams
    34  	cmd := &cobra.Command{
    35  		Use:   "pull",
    36  		Short: "Pull an operator or account jwt replacing the local jwt with the server's version",
    37  		RunE: func(cmd *cobra.Command, args []string) error {
    38  			return RunAction(cmd, args, &params)
    39  		},
    40  	}
    41  
    42  	cmd.Flags().BoolVarP(&params.All, "all", "A", false, "operator and all accounts under the operator")
    43  	cmd.Flags().BoolVarP(&params.Overwrite, "overwrite-newer", "F", false, "overwrite local JWTs that are newer than remote")
    44  	cmd.Flags().StringVarP(&params.sysAcc, "system-account", "", "", "System account for use with nats-resolver enabled nats-server. (Default is system account specified by operator)")
    45  	cmd.Flags().StringVarP(&params.sysAccUser, "system-user", "", "", "System account user for use with nats-resolver enabled nats-server. (Default to temporarily generated user)")
    46  	params.AccountContextParams.BindFlags(cmd)
    47  	return cmd
    48  }
    49  
    50  func init() {
    51  	rootCmd.AddCommand(createPullCmd())
    52  }
    53  
// PullParams holds the options for the pull command and the jobs built from
// them.
type PullParams struct {
	// AccountContextParams supplies the account selection; its Name is the
	// account pulled when --all is not given.
	AccountContextParams

	// All pulls the operator and every account under it (--all).
	All bool

	// Jobs is filled by setupJobs with one entry per JWT to fetch from an
	// HTTP account server. It is not used on the nats-resolver path.
	Jobs PullJobs

	// Overwrite allows replacing a local JWT whose issue time is newer than
	// the remote copy (--overwrite-newer).
	Overwrite bool

	// sysAccUser, when set, is the system account user used to connect to a
	// nats-resolver (--system-user). Per the flag help, a temporary user is
	// generated otherwise.
	sysAccUser string

	// sysAcc, when set, selects the system account (--system-account). Per
	// the flag help, the operator's system account is used otherwise.
	sysAcc string
}
    62  
// PullJob describes one JWT to fetch from an HTTP account server and, after
// Run, the outcome of that fetch.
type PullJob struct {
	// Name is the operator or account name, used to label the report.
	Name string

	// ASU is the URL the JWT is fetched from (passed to store.PullAccount).
	ASU string

	// Err records a failure from Run.
	Err error

	// StoreErr is not read or written in this file.
	// NOTE(review): confirm whether it is still used elsewhere.
	StoreErr error

	// LocalClaim is the local operator/account claim captured by setupJobs.
	// It is not read elsewhere in this file.
	LocalClaim jwt.Claims

	// PullStatus is the report produced by the pull. Its Data holds the
	// fetched JWT. A non-nil value marks the job as already run.
	PullStatus *store.Report
}
    72  
    73  func (j *PullJob) Token() (string, error) {
    74  	if len(j.PullStatus.Data) == 0 {
    75  		return "", errors.New("no data")
    76  	}
    77  	token, err := jwt.ParseDecoratedJWT(j.PullStatus.Data)
    78  	if err != nil {
    79  		return "", err
    80  	}
    81  	j.PullStatus.Data = []byte(token)
    82  	gc, err := jwt.DecodeGeneric(token)
    83  	if err != nil {
    84  		return "", err
    85  	}
    86  	switch gc.ClaimType() {
    87  	case jwt.AccountClaim:
    88  		_, err := jwt.DecodeAccountClaims(token)
    89  		if err != nil {
    90  			return "", err
    91  		}
    92  		return token, nil
    93  	case jwt.OperatorClaim:
    94  		_, err := jwt.DecodeOperatorClaims(token)
    95  		if err != nil {
    96  			return "", err
    97  		}
    98  		return token, nil
    99  	default:
   100  		return "", fmt.Errorf("unsupported token type: %q", gc.ClaimType())
   101  	}
   102  }
   103  
   104  func (j *PullJob) Run() {
   105  	if j.PullStatus != nil {
   106  		// already ran
   107  		return
   108  	}
   109  	s, err := store.PullAccount(j.ASU)
   110  	if err != nil {
   111  		j.Err = err
   112  		return
   113  	}
   114  	ps, ok := s.(*store.Report)
   115  	if !ok {
   116  		j.Err = errors.New("unable to convert pull status")
   117  		return
   118  	}
   119  	j.PullStatus = ps
   120  }
   121  
// PullJobs is the list of jobs built by setupJobs and executed by Run.
type PullJobs []*PullJob
   123  
   124  func (p *PullParams) SetDefaults(ctx ActionCtx) error {
   125  	return p.AccountContextParams.SetDefaults(ctx)
   126  }
   127  
   128  func (p *PullParams) PreInteractive(ctx ActionCtx) error {
   129  	var err error
   130  	tc := GetConfig()
   131  	p.All, err = cli.Confirm(fmt.Sprintf("pull operator %q and all accounts", tc.Operator), true)
   132  	if err != nil {
   133  		return err
   134  	}
   135  	if !p.All {
   136  		if err := p.AccountContextParams.Edit(ctx); err != nil {
   137  			return err
   138  		}
   139  	}
   140  	return nil
   141  }
   142  
   143  func (p *PullParams) Load(ctx ActionCtx) error {
   144  	return nil
   145  }
   146  
   147  func (p *PullParams) PostInteractive(ctx ActionCtx) error {
   148  	return nil
   149  }
   150  
   151  func (p *PullParams) Validate(ctx ActionCtx) error {
   152  	if !p.All && p.Name == "" {
   153  		return errors.New("specify --all or --account")
   154  	}
   155  	oc, err := ctx.StoreCtx().Store.ReadOperatorClaim()
   156  	if err != nil {
   157  		return err
   158  	}
   159  	if oc.AccountServerURL == "" {
   160  		return fmt.Errorf("operator %q doesn't set account server url - unable to pull", ctx.StoreCtx().Operator.Name)
   161  	}
   162  	if IsResolverURL(oc.AccountServerURL) && !p.All {
   163  		return fmt.Errorf("operator %q can only pull all jwt - unable to pull by account", ctx.StoreCtx().Operator.Name)
   164  	}
   165  	return nil
   166  }
   167  
   168  func (p *PullParams) setupJobs(ctx ActionCtx) error {
   169  	s := ctx.StoreCtx().Store
   170  	oc, err := s.ReadOperatorClaim()
   171  	if err != nil {
   172  		return err
   173  	}
   174  	if p.All {
   175  		u, err := OperatorJwtURL(oc)
   176  		if err != nil {
   177  			return err
   178  		}
   179  		// in ngs we are in v2 operator, so lets try /jwt/v2/operator first
   180  		uv2 := strings.Replace(u, "/jwt/v1/operator", "/jwt/v2/operator", 1)
   181  		j := PullJob{ASU: uv2, Name: oc.Name, LocalClaim: oc}
   182  		j.Run()
   183  		if j.Err == nil {
   184  			p.Jobs = append(p.Jobs, &j)
   185  		} else {
   186  			j = PullJob{ASU: u, Name: oc.Name, LocalClaim: oc}
   187  			p.Jobs = append(p.Jobs, &j)
   188  		}
   189  		tc := GetConfig()
   190  		accounts, err := tc.ListAccounts()
   191  		if err != nil {
   192  			return err
   193  		}
   194  		for _, v := range accounts {
   195  			ac, err := s.ReadAccountClaim(v)
   196  			if err != nil {
   197  				return err
   198  			}
   199  			u, err := AccountJwtURL(oc, ac)
   200  			if err != nil {
   201  				return err
   202  			}
   203  			j := PullJob{ASU: u, Name: ac.Name, LocalClaim: ac}
   204  			p.Jobs = append(p.Jobs, &j)
   205  		}
   206  	} else {
   207  		ac, err := s.ReadAccountClaim(p.Name)
   208  		if err != nil {
   209  			return err
   210  		}
   211  		u, err := AccountJwtURL(oc, ac)
   212  		if err != nil {
   213  			return err
   214  		}
   215  		j := PullJob{ASU: u, Name: ac.Name, LocalClaim: ac}
   216  		p.Jobs = append(p.Jobs, &j)
   217  	}
   218  
   219  	return nil
   220  }
   221  
   222  func (p *PullParams) maybeStoreJWT(ctx ActionCtx, sub *store.Report, token string) {
   223  	remoteClaim, err := jwt.DecodeGeneric(token)
   224  	if err != nil {
   225  		sub.AddError("error decoding remote token: %v %s", err, token)
   226  		return
   227  	}
   228  	orig := int64(0)
   229  	if localClaim, err := ctx.StoreCtx().Store.ReadAccountClaim(remoteClaim.Name); err == nil {
   230  		orig = localClaim.IssuedAt
   231  	}
   232  	remote := remoteClaim.IssuedAt
   233  	if (orig > remote) && !p.Overwrite {
   234  		sub.AddError("local jwt for %q is newer than remote version - specify --overwrite-newer to overwrite", remoteClaim.Name)
   235  		return
   236  	}
   237  	if err := ctx.StoreCtx().Store.StoreRaw([]byte(token)); err != nil {
   238  		sub.AddError("error storing %q: %v", remoteClaim.Name, err)
   239  		return
   240  	}
   241  	sub.AddOK("stored %s %q", remoteClaim.ClaimType(), remoteClaim.Name)
   242  	if sub.OK() {
   243  		sub.Label = fmt.Sprintf("pulled %q from the account server", remoteClaim.Name)
   244  	}
   245  }
   246  
   247  func (p *PullParams) Run(ctx ActionCtx) (store.Status, error) {
   248  	r := store.NewDetailedReport(true)
   249  	if op, err := ctx.StoreCtx().Store.ReadOperatorClaim(); err != nil {
   250  		r.AddError("could not read operator claim: %v", err)
   251  		return r, err
   252  	} else if op.AccountServerURL == "" {
   253  		err := fmt.Errorf("operator has no account server url")
   254  		r.AddError("operator %s: %v", op.Name, err)
   255  		return r, err
   256  	} else if url := op.AccountServerURL; IsResolverURL(url) {
   257  		subR := store.NewReport(store.OK, `pull from cluster using system account`)
   258  		r.Add(subR)
   259  		ib := nats.NewInbox()
   260  		_, opt, err := getSystemAccountUser(ctx, p.sysAcc, p.sysAccUser, ib, "$SYS.REQ.CLAIMS.PACK")
   261  		if err != nil {
   262  			subR.AddError("failed to obtain system user: %v", err)
   263  			return r, nil
   264  		}
   265  		nc, err := nats.Connect(url, createDefaultToolOptions("nsc_pull", ctx, opt)...)
   266  		if err != nil {
   267  			subR.AddError("failed to connect to %s: %v", url, err)
   268  			return r, nil
   269  		}
   270  		defer nc.Close()
   271  
   272  		sub, err := nc.SubscribeSync(ib)
   273  		if err != nil {
   274  			subR.AddError("failed to subscribe to response subject: %v", err)
   275  			return r, nil
   276  		}
   277  		if err := nc.PublishRequest("$SYS.REQ.CLAIMS.PACK", ib, nil); err != nil {
   278  			subR.AddError("failed to pull accounts: %v", err)
   279  			return r, nil
   280  		}
   281  		for {
   282  			if resp, err := sub.NextMsg(time.Second); err != nil {
   283  				subR.AddError("failed to get response to pull: %v", err)
   284  				break
   285  			} else if msg := string(resp.Data); msg == "" { // empty response means end
   286  				break
   287  			} else if tk := strings.Split(string(resp.Data), "|"); len(tk) != 2 {
   288  				subR.AddError("pull response bad")
   289  				break
   290  			} else {
   291  				p.maybeStoreJWT(ctx, subR, tk[1])
   292  			}
   293  		}
   294  		return r, nil
   295  	}
   296  
   297  	ctx.CurrentCmd().SilenceUsage = true
   298  	if err := p.setupJobs(ctx); err != nil {
   299  		return nil, err
   300  	}
   301  	var wg sync.WaitGroup
   302  	wg.Add(len(p.Jobs))
   303  	for _, j := range p.Jobs {
   304  		go func(j *PullJob) {
   305  			defer wg.Done()
   306  			j.Run()
   307  		}(j)
   308  	}
   309  	wg.Wait()
   310  
   311  	for _, j := range p.Jobs {
   312  		sub := store.NewReport(store.OK, "pull %q from the account server", j.Name)
   313  		sub.Opt = store.DetailsOnErrorOrWarning
   314  		r.Add(sub)
   315  		if j.PullStatus != nil {
   316  			sub.Add(store.HoistChildren(j.PullStatus)...)
   317  		}
   318  		if j.Err != nil {
   319  			sub.AddFromError(j.Err)
   320  			continue
   321  		}
   322  		if j.PullStatus.OK() {
   323  			token, _ := j.Token()
   324  			p.maybeStoreJWT(ctx, sub, token)
   325  		}
   326  	}
   327  	return r, nil
   328  }