github.com/nats-io/nsc/v2@v2.8.7-0.20240307184528-efd7023c6896/cmd/common_test.go (about)

     1  /*
     2   * Copyright 2018-2023 The NATS Authors
     3   * Licensed under the Apache License, Version 2.0 (the "License");
     4   * you may not use this file except in compliance with the License.
     5   * You may obtain a copy of the License at
     6   *
     7   * http://www.apache.org/licenses/LICENSE-2.0
     8   *
     9   * Unless required by applicable law or agreed to in writing, software
    10   * distributed under the License is distributed on an "AS IS" BASIS,
    11   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12   * See the License for the specific language governing permissions and
    13   * limitations under the License.
    14   */
    15  
    16  package cmd
    17  
    18  import (
    19  	"fmt"
    20  	"io"
    21  	"net/http"
    22  	"net/http/httptest"
    23  	"net/url"
    24  	"os"
    25  	"path/filepath"
    26  	"testing"
    27  	"time"
    28  
    29  	"github.com/mitchellh/go-homedir"
    30  	"github.com/nats-io/jwt/v2"
    31  	"github.com/nats-io/nkeys"
    32  	"github.com/spf13/cobra"
    33  	"github.com/stretchr/testify/require"
    34  )
    35  
    36  func TestCommon_ResolvePath(t *testing.T) {
    37  	v := ResolvePath("bar", "foo")
    38  	require.Equal(t, v, "bar", "non defined variable")
    39  
    40  	v = ResolvePath("bar", "")
    41  	require.Equal(t, v, "bar", "empty variable")
    42  
    43  	os.Setenv("foo", "foobar")
    44  	v = ResolvePath("bar", "foo")
    45  	require.Equal(t, v, "foobar", "env set")
    46  }
    47  
    48  func TestCommon_GetOutput(t *testing.T) {
    49  	dir := MakeTempDir(t)
    50  	defer os.RemoveAll(dir)
    51  
    52  	type testd struct {
    53  		fp      string
    54  		create  bool
    55  		isError bool
    56  		isDir   bool
    57  	}
    58  	tests := []testd{
    59  		{"--", false, false, false},
    60  		{filepath.Join(dir, "dir"), true, true, true},
    61  		{filepath.Join(dir, "nonExisting"), false, false, false},
    62  		{filepath.Join(dir, "existing"), false, false, false},
    63  	}
    64  	for _, d := range tests {
    65  		if d.isDir {
    66  			os.MkdirAll(d.fp, 0777)
    67  		} else if d.create {
    68  			os.Create(d.fp)
    69  		}
    70  		file, err := GetOutput(d.fp)
    71  		if file != nil && d.fp != "--" {
    72  			file.Close()
    73  		}
    74  		if d.isError && err == nil {
    75  			t.Errorf("expected error creating %#q, but didn't", d.fp)
    76  		}
    77  		if !d.isError && err != nil {
    78  			t.Errorf("unexpected error creating %#q: %v", d.fp, err)
    79  		}
    80  	}
    81  }
    82  
    83  func createWriteCmd(t *testing.T) *cobra.Command {
    84  	var out string
    85  	cmd := &cobra.Command{
    86  		Use: "test",
    87  		RunE: func(cmd *cobra.Command, args []string) error {
    88  			return Write(out, []byte("hello"))
    89  		},
    90  	}
    91  	cmd.Flags().StringVarP(&out, "out", "", "--", "")
    92  	return cmd
    93  }
    94  
    95  func Test_WriteDestinations(t *testing.T) {
    96  	stdout, _, err := ExecuteCmd(createWriteCmd(t), "--out", "--")
    97  	require.NoError(t, err)
    98  	require.Contains(t, stdout, "hello")
    99  	dir := MakeTempDir(t)
   100  	fn := filepath.Join(dir, "test.txt")
   101  	_, _, err = ExecuteCmd(createWriteCmd(t), "--out", fn)
   102  	require.NoError(t, err)
   103  	require.FileExists(t, fn)
   104  	d, err := os.ReadFile(fn)
   105  	require.NoError(t, err)
   106  	require.Contains(t, string(d), "hello")
   107  }
   108  
   109  func TestCommon_IsStdOut(t *testing.T) {
   110  	require.True(t, IsStdOut("--"))
   111  	require.False(t, IsStdOut("/tmp/foo.txt"))
   112  }
   113  
   114  func TestCommon_ResolveKeyEmpty(t *testing.T) {
   115  	old := KeyPathFlag
   116  	KeyPathFlag = ""
   117  
   118  	rkp, err := ResolveKeyFlag()
   119  	KeyPathFlag = old
   120  
   121  	require.NoError(t, err)
   122  	require.Nil(t, rkp)
   123  }
   124  
   125  func TestCommon_ResolveKeyFromSeed(t *testing.T) {
   126  	seed, p, _ := CreateAccountKey(t)
   127  	old := KeyPathFlag
   128  	KeyPathFlag = string(seed)
   129  
   130  	rkp, err := ResolveKeyFlag()
   131  	KeyPathFlag = old
   132  
   133  	require.NoError(t, err)
   134  
   135  	pp, err := rkp.PublicKey()
   136  	require.NoError(t, err)
   137  
   138  	require.Equal(t, pp, p)
   139  }
   140  
   141  func TestCommon_ResolveKeyFromFile(t *testing.T) {
   142  	dir := MakeTempDir(t)
   143  	_, p, kp := CreateAccountKey(t)
   144  	old := KeyPathFlag
   145  	KeyPathFlag = StoreKey(t, kp, dir)
   146  	rkp, err := ResolveKeyFlag()
   147  	KeyPathFlag = old
   148  
   149  	require.NoError(t, err)
   150  
   151  	pp, err := rkp.PublicKey()
   152  	require.NoError(t, err)
   153  
   154  	require.Equal(t, pp, p)
   155  }
   156  
   157  func TestCommon_ParseNumber(t *testing.T) {
   158  	type testd struct {
   159  		input   string
   160  		output  int64
   161  		isError bool
   162  	}
   163  	tests := []testd{
   164  		{"", 0, false},
   165  		{"0", 0, false},
   166  		{"1000", 1000, false},
   167  		{"1K", 1000, false},
   168  		{"1k", 1000, false},
   169  		{"1M", 1000 * 1000, false},
   170  		{"1m", 1000 * 1000, false},
   171  		{"1G", 1000 * 1000 * 1000, false},
   172  		{"1g", 1000 * 1000 * 1000, false},
   173  		{"1KIB", 1024, false},
   174  		{"1kib", 1024, false},
   175  		{"1MIB", 1024 * 1024, false},
   176  		{"1mib", 1024 * 1024, false},
   177  		{"1GIB", 1024 * 1024 * 1024, false},
   178  		{"1gib", 1024 * 1024 * 1024, false},
   179  		{"32a", 0, true},
   180  	}
   181  	for _, d := range tests {
   182  		v, err := ParseNumber(d.input)
   183  		if err != nil && !d.isError {
   184  			t.Errorf("%s didn't expect error: %v", d.input, err)
   185  			continue
   186  		}
   187  		if err == nil && d.isError {
   188  			t.Errorf("expected error from %s", d.input)
   189  			continue
   190  		}
   191  		if v != d.output {
   192  			t.Errorf("%s expected %d but got %d", d.input, d.output, v)
   193  		}
   194  	}
   195  }
   196  
   197  func TestCommon_NKeyValidatorActualKey(t *testing.T) {
   198  	aSeed, _, _ := CreateAccountKey(t)
   199  	fn := NKeyValidator(nkeys.PrefixByteAccount)
   200  	require.NoError(t, fn(string(aSeed)))
   201  
   202  	oSeed, _, _ := CreateOperatorKey(t)
   203  	require.Error(t, fn(string(oSeed)))
   204  }
   205  
   206  func TestCommon_NKeyValidatorKeyInFile(t *testing.T) {
   207  	dir := MakeTempDir(t)
   208  	aSeed, _, _ := CreateAccountKey(t)
   209  	oSeed, _, _ := CreateOperatorKey(t)
   210  
   211  	require.NoError(t, Write(filepath.Join(dir, "as.nk"), aSeed))
   212  	require.NoError(t, Write(filepath.Join(dir, "os.nk"), oSeed))
   213  
   214  	fn := NKeyValidator(nkeys.PrefixByteAccount)
   215  	require.NoError(t, fn(filepath.Join(dir, "as.nk")))
   216  
   217  	require.Error(t, fn(filepath.Join(dir, "os.nk")))
   218  }
   219  
   220  func TestCommon_LoadFromURL(t *testing.T) {
   221  	v := "1,2,3"
   222  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   223  		fmt.Fprint(w, v)
   224  	}))
   225  	defer ts.Close()
   226  
   227  	d, err := LoadFromURL(ts.URL)
   228  	require.NoError(t, err)
   229  	require.Equal(t, v, string(d))
   230  }
   231  
   232  func TestCommon_LoadFromURLTimeout(t *testing.T) {
   233  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   234  		time.Sleep(time.Second * 6)
   235  	}))
   236  	defer ts.Close()
   237  
   238  	_, err := LoadFromURL(ts.URL)
   239  	require.Error(t, err)
   240  	require.Contains(t, err.Error(), "Timeout exceeded")
   241  }
   242  
   243  func TestCommon_IsValidDir(t *testing.T) {
   244  	d := MakeTempDir(t)
   245  	require.NoError(t, IsValidDir(d))
   246  
   247  	tp := filepath.Join(d, "foo")
   248  	err := IsValidDir(tp)
   249  	require.Error(t, err)
   250  	require.True(t, os.IsNotExist(err))
   251  
   252  	err = os.WriteFile(tp, []byte("hello"), 0600)
   253  	require.NoError(t, err)
   254  	err = IsValidDir(tp)
   255  	require.Error(t, err)
   256  	require.Equal(t, "not a directory", err.Error())
   257  }
   258  
   259  func TestCommon_MaybeMakeDir(t *testing.T) {
   260  	d := MakeTempDir(t)
   261  	dir := filepath.Join(d, "foo")
   262  	_, err := os.Stat(dir)
   263  	require.True(t, os.IsNotExist(err))
   264  	err = MaybeMakeDir(dir)
   265  	require.NoError(t, err)
   266  	require.DirExists(t, dir)
   267  
   268  	// test no fail if exists
   269  	err = MaybeMakeDir(dir)
   270  	require.NoError(t, err)
   271  }
   272  
   273  func TestCommon_MaybeMakeDir_FileExists(t *testing.T) {
   274  	d := MakeTempDir(t)
   275  	fp := filepath.Join(d, "foo")
   276  	err := Write(fp, []byte("hello"))
   277  	require.NoError(t, err)
   278  
   279  	err = MaybeMakeDir(fp)
   280  	require.Error(t, err)
   281  	require.Contains(t, err.Error(), "is not a dir")
   282  }
   283  
   284  func TestCommon_Read(t *testing.T) {
   285  	d := MakeTempDir(t)
   286  	dir := filepath.Join(d, "foo", "bar", "baz")
   287  	err := MaybeMakeDir(dir)
   288  	require.NoError(t, err)
   289  
   290  	fp := filepath.Join(dir, "..", "..", "foo.txt")
   291  	err = Write(fp, []byte("hello"))
   292  	require.NoError(t, err)
   293  
   294  	require.DirExists(t, dir)
   295  	require.FileExists(t, filepath.Join(d, "foo", "foo.txt"))
   296  	data, err := Read(fp)
   297  	require.NoError(t, err)
   298  	require.Equal(t, "hello", string(data))
   299  }
   300  
   301  func TestCommon_WriteJSON(t *testing.T) {
   302  	d := MakeTempDir(t)
   303  	fp := filepath.Join(d, "foo")
   304  
   305  	n := struct {
   306  		Name string `json:"name"`
   307  	}{}
   308  	n.Name = "test"
   309  
   310  	err := WriteJson(fp, n)
   311  	require.NoError(t, err)
   312  
   313  	v, err := Read(fp)
   314  	require.NoError(t, err)
   315  	require.JSONEq(t, `{"name": "test"}`, string(v))
   316  }
   317  
   318  func TestCommon_ReadJSON(t *testing.T) {
   319  	d := MakeTempDir(t)
   320  	fp := filepath.Join(d, "foo")
   321  	err := Write(fp, []byte(`{"name": "test"}`))
   322  	require.NoError(t, err)
   323  
   324  	n := struct {
   325  		Name string `json:"name"`
   326  	}{}
   327  
   328  	err = ReadJson(fp, &n)
   329  	require.NoError(t, err)
   330  	require.Equal(t, "test", n.Name)
   331  }
   332  
   333  func TestCommon_AbbrevHomePaths(t *testing.T) {
   334  	require.Equal(t, "", AbbrevHomePaths(""))
   335  	require.Equal(t, "/foo/bar", AbbrevHomePaths("/foo/bar"))
   336  	v, err := homedir.Dir()
   337  	if err != nil {
   338  		require.Equal(t, "~/bar", AbbrevHomePaths(filepath.Join(v, "bar")))
   339  	}
   340  }
   341  
   342  func Test_NKeyValidator(t *testing.T) {
   343  	ts := NewTestStore(t, "O")
   344  	defer ts.Done(t)
   345  
   346  	oSeed, opk, _ := CreateOperatorKey(t)
   347  	aSeed, pk, _ := CreateAccountKey(t)
   348  	asf := filepath.Join(ts.Dir, "account_seed_file.nk")
   349  	require.NoError(t, os.WriteFile(asf, aSeed, 0700))
   350  	pkf := filepath.Join(ts.Dir, "account_public_file.nk")
   351  	require.NoError(t, os.WriteFile(pkf, []byte(pk), 0700))
   352  	nff := filepath.Join(ts.Dir, "not_exist.nk")
   353  
   354  	var keyTests = []struct {
   355  		arg string
   356  		ok  bool
   357  	}{
   358  		{asf, true},
   359  		{pkf, true},
   360  		{nff, false},
   361  		{ts.Dir, false},
   362  		{string(aSeed), true},
   363  		{string(aSeed), true},
   364  		{pk, true},
   365  		{string(oSeed), false},
   366  		{opk, false},
   367  		{"", false},
   368  		{"foo", false},
   369  	}
   370  
   371  	fun := NKeyValidator(nkeys.PrefixByteAccount)
   372  	for i, kt := range keyTests {
   373  		err := fun(kt.arg)
   374  		var failed bool
   375  		message := fmt.Sprintf("unexpected error on test %q (%d): %v", kt.arg, i, err)
   376  		if err != nil {
   377  			failed = true
   378  		}
   379  		require.Equal(t, !kt.ok, failed, message)
   380  	}
   381  }
   382  
   383  func Test_SeedNKeyValidatorMatching(t *testing.T) {
   384  	ts := NewTestStore(t, "O")
   385  	defer ts.Done(t)
   386  
   387  	oSeed, opk, _ := CreateOperatorKey(t)
   388  	aSeed, pk, _ := CreateAccountKey(t)
   389  	as1, pk1, _ := CreateAccountKey(t)
   390  	as2, pk2, _ := CreateAccountKey(t)
   391  
   392  	validPubs := []string{string(pk), string(pk1)}
   393  
   394  	var keyTests = []struct {
   395  		arg string
   396  		ok  bool
   397  	}{
   398  		{string(oSeed), false},
   399  		{"", false},
   400  		{"foo", false},
   401  		{pk, false},
   402  		{pk2, false},
   403  		{opk, false},
   404  		{filepath.Join(ts.Dir, "notexist.nk"), false},
   405  		{string(as2), false},
   406  		{string(aSeed), true},
   407  		{string(as1), true},
   408  	}
   409  
   410  	fun := SeedNKeyValidatorMatching(validPubs, nkeys.PrefixByteAccount)
   411  	for i, kt := range keyTests {
   412  		err := fun(kt.arg)
   413  		var failed bool
   414  		message := fmt.Sprintf("unexpected error on test %q (%d): %v", kt.arg, i, err)
   415  		if err != nil {
   416  			failed = true
   417  		}
   418  		require.Equal(t, !kt.ok, failed, message)
   419  	}
   420  }
   421  
   422  func TestPushAccount(t *testing.T) {
   423  	_, opk, okp := CreateOperatorKey(t)
   424  
   425  	// create an http server to accept the request
   426  	hts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   427  		defer r.Body.Close()
   428  
   429  		body, err := io.ReadAll(r.Body)
   430  		if err != nil {
   431  			require.NoError(t, err)
   432  		}
   433  		ac, err := jwt.DecodeAccountClaims(string(body))
   434  		require.NoError(t, err)
   435  
   436  		token, err := ac.Encode(okp)
   437  		require.NoError(t, err)
   438  
   439  		w.Header().Add("Content-Type", "application/jwt")
   440  		w.WriteHeader(200)
   441  		w.Write([]byte(token))
   442  	}))
   443  	defer hts.Close()
   444  
   445  	// self sign a jwt
   446  	_, apk, akp := CreateAccountKey(t)
   447  	ac := jwt.NewAccountClaims(apk)
   448  	araw, err := ac.Encode(akp)
   449  	require.NoError(t, err)
   450  
   451  	u, err := url.Parse(hts.URL)
   452  	require.NoError(t, err)
   453  	_, sraw, err := PushAccount(u.String(), []byte(araw))
   454  	require.NoError(t, err)
   455  	require.NotNil(t, sraw)
   456  
   457  	ac, err = jwt.DecodeAccountClaims(string(sraw))
   458  	require.NoError(t, err)
   459  	require.Equal(t, apk, ac.Subject)
   460  	require.Equal(t, opk, ac.Issuer)
   461  }
   462  
   463  func Test_NameFlagArgOnlyOnEmpty(t *testing.T) {
   464  	var tests = []struct {
   465  		n   string
   466  		a   []string
   467  		out string
   468  	}{
   469  		{"", nil, ""},
   470  		{"a", nil, "a"},
   471  		{"a", []string{}, "a"},
   472  		{"a", []string{"b", "c"}, "a"},
   473  		{"", []string{"b", "c"}, "b"},
   474  	}
   475  
   476  	for i, v := range tests {
   477  		r := nameFlagOrArgument(v.n, v.a)
   478  		require.Equal(t, v.out, r, "failed test %d", i)
   479  	}
   480  }