github.com/Ilhicas/nomad@v1.0.4-0.20210304152020-e86851182bc3/helper/funcs_test.go (about)

     1  package helper
     2  
     3  import (
     4  	"fmt"
     5  	"path/filepath"
     6  	"reflect"
     7  	"sort"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/stretchr/testify/require"
    12  )
    13  
    14  func TestSliceStringIsSubset(t *testing.T) {
    15  	l := []string{"a", "b", "c"}
    16  	s := []string{"d"}
    17  
    18  	sub, offending := SliceStringIsSubset(l, l[:1])
    19  	if !sub || len(offending) != 0 {
    20  		t.Fatalf("bad %v %v", sub, offending)
    21  	}
    22  
    23  	sub, offending = SliceStringIsSubset(l, s)
    24  	if sub || len(offending) == 0 || offending[0] != "d" {
    25  		t.Fatalf("bad %v %v", sub, offending)
    26  	}
    27  }
    28  
    29  func TestSliceStringContains(t *testing.T) {
    30  	list := []string{"a", "b", "c"}
    31  	require.True(t, SliceStringContains(list, "a"))
    32  	require.True(t, SliceStringContains(list, "b"))
    33  	require.True(t, SliceStringContains(list, "c"))
    34  	require.False(t, SliceStringContains(list, "d"))
    35  }
    36  
    37  func TestCompareTimePtrs(t *testing.T) {
    38  	t.Run("nil", func(t *testing.T) {
    39  		a := (*time.Duration)(nil)
    40  		b := (*time.Duration)(nil)
    41  		require.True(t, CompareTimePtrs(a, b))
    42  		c := TimeToPtr(3 * time.Second)
    43  		require.False(t, CompareTimePtrs(a, c))
    44  		require.False(t, CompareTimePtrs(c, a))
    45  	})
    46  
    47  	t.Run("not nil", func(t *testing.T) {
    48  		a := TimeToPtr(1 * time.Second)
    49  		b := TimeToPtr(1 * time.Second)
    50  		c := TimeToPtr(2 * time.Second)
    51  		require.True(t, CompareTimePtrs(a, b))
    52  		require.False(t, CompareTimePtrs(a, c))
    53  	})
    54  }
    55  
    56  func TestCompareSliceSetString(t *testing.T) {
    57  	cases := []struct {
    58  		A      []string
    59  		B      []string
    60  		Result bool
    61  	}{
    62  		{
    63  			A:      []string{},
    64  			B:      []string{},
    65  			Result: true,
    66  		},
    67  		{
    68  			A:      []string{},
    69  			B:      []string{"a"},
    70  			Result: false,
    71  		},
    72  		{
    73  			A:      []string{"a"},
    74  			B:      []string{"a"},
    75  			Result: true,
    76  		},
    77  		{
    78  			A:      []string{"a"},
    79  			B:      []string{"b"},
    80  			Result: false,
    81  		},
    82  		{
    83  			A:      []string{"a", "b"},
    84  			B:      []string{"b"},
    85  			Result: false,
    86  		},
    87  		{
    88  			A:      []string{"a", "b"},
    89  			B:      []string{"a"},
    90  			Result: false,
    91  		},
    92  		{
    93  			A:      []string{"a", "b"},
    94  			B:      []string{"a", "b"},
    95  			Result: true,
    96  		},
    97  		{
    98  			A:      []string{"a", "b"},
    99  			B:      []string{"b", "a"},
   100  			Result: true,
   101  		},
   102  	}
   103  
   104  	for i, tc := range cases {
   105  		tc := tc
   106  		t.Run(fmt.Sprintf("case-%da", i), func(t *testing.T) {
   107  			if res := CompareSliceSetString(tc.A, tc.B); res != tc.Result {
   108  				t.Fatalf("expected %t but CompareSliceSetString(%v, %v) -> %t",
   109  					tc.Result, tc.A, tc.B, res,
   110  				)
   111  			}
   112  		})
   113  
   114  		// Function is commutative so compare B and A
   115  		t.Run(fmt.Sprintf("case-%db", i), func(t *testing.T) {
   116  			if res := CompareSliceSetString(tc.B, tc.A); res != tc.Result {
   117  				t.Fatalf("expected %t but CompareSliceSetString(%v, %v) -> %t",
   118  					tc.Result, tc.B, tc.A, res,
   119  				)
   120  			}
   121  		})
   122  	}
   123  }
   124  
   125  func TestMapStringStringSliceValueSet(t *testing.T) {
   126  	m := map[string][]string{
   127  		"foo": {"1", "2"},
   128  		"bar": {"3"},
   129  		"baz": nil,
   130  	}
   131  
   132  	act := MapStringStringSliceValueSet(m)
   133  	exp := []string{"1", "2", "3"}
   134  	sort.Strings(act)
   135  	if !reflect.DeepEqual(act, exp) {
   136  		t.Fatalf("Bad; got %v; want %v", act, exp)
   137  	}
   138  }
   139  
   140  func TestCopyMapStringSliceString(t *testing.T) {
   141  	m := map[string][]string{
   142  		"x": {"a", "b", "c"},
   143  		"y": {"1", "2", "3"},
   144  		"z": nil,
   145  	}
   146  
   147  	c := CopyMapStringSliceString(m)
   148  	if !reflect.DeepEqual(c, m) {
   149  		t.Fatalf("%#v != %#v", m, c)
   150  	}
   151  
   152  	c["x"][1] = "---"
   153  	if reflect.DeepEqual(c, m) {
   154  		t.Fatalf("Shared slices: %#v == %#v", m["x"], c["x"])
   155  	}
   156  }
   157  
   158  func TestCopyMapSliceInterface(t *testing.T) {
   159  	m := map[string]interface{}{
   160  		"foo": "bar",
   161  		"baz": 2,
   162  	}
   163  
   164  	c := CopyMapStringInterface(m)
   165  	require.True(t, reflect.DeepEqual(m, c))
   166  
   167  	m["foo"] = "zzz"
   168  	require.False(t, reflect.DeepEqual(m, c))
   169  }
   170  
   171  func TestCleanEnvVar(t *testing.T) {
   172  	type testCase struct {
   173  		input    string
   174  		expected string
   175  	}
   176  	cases := []testCase{
   177  		{"asdf", "asdf"},
   178  		{"ASDF", "ASDF"},
   179  		{"0sdf", "_sdf"},
   180  		{"asd0", "asd0"},
   181  		{"_asd", "_asd"},
   182  		{"-asd", "_asd"},
   183  		{"asd.fgh", "asd.fgh"},
   184  		{"A~!@#$%^&*()_+-={}[]|\\;:'\"<,>?/Z", "A______________________________Z"},
   185  		{"A\U0001f4a9Z", "A____Z"},
   186  	}
   187  	for _, c := range cases {
   188  		if output := CleanEnvVar(c.input, '_'); output != c.expected {
   189  			t.Errorf("CleanEnvVar(%q, '_') -> %q != %q", c.input, output, c.expected)
   190  		}
   191  	}
   192  }
   193  
   194  func BenchmarkCleanEnvVar(b *testing.B) {
   195  	in := "NOMAD_ADDR_redis-cache"
   196  	replacement := byte('_')
   197  	b.SetBytes(int64(len(in)))
   198  	b.ReportAllocs()
   199  	b.ResetTimer()
   200  	for i := 0; i < b.N; i++ {
   201  		CleanEnvVar(in, replacement)
   202  	}
   203  }
   204  
   205  type testCase struct {
   206  	input    string
   207  	expected string
   208  }
   209  
   210  func commonCleanFilenameCases() (cases []testCase) {
   211  	// Common set of test cases for all 3 TestCleanFilenameX functions
   212  	cases = []testCase{
   213  		{"asdf", "asdf"},
   214  		{"ASDF", "ASDF"},
   215  		{"0sdf", "0sdf"},
   216  		{"asd0", "asd0"},
   217  		{"_asd", "_asd"},
   218  		{"-asd", "-asd"},
   219  		{"asd.fgh", "asd.fgh"},
   220  		{"Linux/Forbidden", "Linux_Forbidden"},
   221  		{"Windows<>:\"/\\|?*Forbidden", "Windows_________Forbidden"},
   222  		{`Windows<>:"/\|?*Forbidden_StringLiteral`, "Windows_________Forbidden_StringLiteral"},
   223  	}
   224  	return cases
   225  }
   226  
   227  func TestCleanFilename(t *testing.T) {
   228  	cases := append(
   229  		[]testCase{
   230  			{"A\U0001f4a9Z", "A💩Z"}, // CleanFilename allows unicode
   231  			{"A💩Z", "A💩Z"},
   232  			{"A~!@#$%^&*()_+-={}[]|\\;:'\"<,>?/Z", "A~!@#$%^&_()_+-={}[]__;_'__,___Z"},
   233  		}, commonCleanFilenameCases()...)
   234  
   235  	for i, c := range cases {
   236  		t.Run(fmt.Sprintf("case-%d", i), func(t *testing.T) {
   237  			output := CleanFilename(c.input, "_")
   238  			failMsg := fmt.Sprintf("CleanFilename(%q, '_') -> %q != %q", c.input, output, c.expected)
   239  			require.Equal(t, c.expected, output, failMsg)
   240  		})
   241  	}
   242  }
   243  
   244  func TestCleanFilenameASCIIOnly(t *testing.T) {
   245  	ASCIIOnlyCases := append(
   246  		[]testCase{
   247  			{"A\U0001f4a9Z", "A_Z"}, // CleanFilenameASCIIOnly does not allow unicode
   248  			{"A💩Z", "A_Z"},
   249  			{"A~!@#$%^&*()_+-={}[]|\\;:'\"<,>?/Z", "A~!@#$%^&_()_+-={}[]__;_'__,___Z"},
   250  		}, commonCleanFilenameCases()...)
   251  
   252  	for i, c := range ASCIIOnlyCases {
   253  		t.Run(fmt.Sprintf("case-%d", i), func(t *testing.T) {
   254  			output := CleanFilenameASCIIOnly(c.input, "_")
   255  			failMsg := fmt.Sprintf("CleanFilenameASCIIOnly(%q, '_') -> %q != %q", c.input, output, c.expected)
   256  			require.Equal(t, c.expected, output, failMsg)
   257  		})
   258  	}
   259  }
   260  
   261  func TestCleanFilenameStrict(t *testing.T) {
   262  	strictCases := append(
   263  		[]testCase{
   264  			{"A\U0001f4a9Z", "A💩Z"}, // CleanFilenameStrict allows unicode
   265  			{"A💩Z", "A💩Z"},
   266  			{"A~!@#$%^&*()_+-={}[]|\\;:'\"<,>?/Z", "A_!___%^______-_{}_____________Z"},
   267  		}, commonCleanFilenameCases()...)
   268  
   269  	for i, c := range strictCases {
   270  		t.Run(fmt.Sprintf("case-%d", i), func(t *testing.T) {
   271  			output := CleanFilenameStrict(c.input, "_")
   272  			failMsg := fmt.Sprintf("CleanFilenameStrict(%q, '_') -> %q != %q", c.input, output, c.expected)
   273  			require.Equal(t, c.expected, output, failMsg)
   274  		})
   275  	}
   276  }
   277  
   278  func TestCheckNamespaceScope(t *testing.T) {
   279  	cases := []struct {
   280  		desc      string
   281  		provided  string
   282  		requested []string
   283  		offending []string
   284  	}{
   285  		{
   286  			desc:      "root ns requesting namespace",
   287  			provided:  "",
   288  			requested: []string{"engineering"},
   289  		},
   290  		{
   291  			desc:      "matching parent ns with child",
   292  			provided:  "engineering",
   293  			requested: []string{"engineering", "engineering/sub-team"},
   294  		},
   295  		{
   296  			desc:      "mismatch ns",
   297  			provided:  "engineering",
   298  			requested: []string{"finance", "engineering/sub-team", "eng"},
   299  			offending: []string{"finance", "eng"},
   300  		},
   301  		{
   302  			desc:      "mismatch child",
   303  			provided:  "engineering/sub-team",
   304  			requested: []string{"engineering/new-team", "engineering/sub-team", "engineering/sub-team/child"},
   305  			offending: []string{"engineering/new-team"},
   306  		},
   307  		{
   308  			desc:      "matching prefix",
   309  			provided:  "engineering",
   310  			requested: []string{"engineering/new-team", "engineering/new-team/sub-team"},
   311  		},
   312  	}
   313  
   314  	for _, tc := range cases {
   315  		t.Run(tc.desc, func(t *testing.T) {
   316  			offending := CheckNamespaceScope(tc.provided, tc.requested)
   317  			require.Equal(t, offending, tc.offending)
   318  		})
   319  	}
   320  }
   321  
   322  func TestPathEscapesSandbox(t *testing.T) {
   323  	cases := []struct {
   324  		name     string
   325  		path     string
   326  		dir      string
   327  		expected bool
   328  	}{
   329  		{
   330  			// this is the ${NOMAD_SECRETS_DIR} case
   331  			name:     "ok joined absolute path inside sandbox",
   332  			path:     filepath.Join("/alloc", "/secrets"),
   333  			dir:      "/alloc",
   334  			expected: false,
   335  		},
   336  		{
   337  			name:     "fail unjoined absolute path outside sandbox",
   338  			path:     "/secrets",
   339  			dir:      "/alloc",
   340  			expected: true,
   341  		},
   342  		{
   343  			name:     "ok joined relative path inside sandbox",
   344  			path:     filepath.Join("/alloc", "./safe"),
   345  			dir:      "/alloc",
   346  			expected: false,
   347  		},
   348  		{
   349  			name:     "fail unjoined relative path outside sandbox",
   350  			path:     "./safe",
   351  			dir:      "/alloc",
   352  			expected: true,
   353  		},
   354  		{
   355  			name:     "ok relative path traversal constrained to sandbox",
   356  			path:     filepath.Join("/alloc", "../../alloc/safe"),
   357  			dir:      "/alloc",
   358  			expected: false,
   359  		},
   360  		{
   361  			name:     "ok unjoined absolute path traversal constrained to sandbox",
   362  			path:     filepath.Join("/alloc", "/../alloc/safe"),
   363  			dir:      "/alloc",
   364  			expected: false,
   365  		},
   366  		{
   367  			name:     "ok unjoined absolute path traversal constrained to sandbox",
   368  			path:     "/../alloc/safe",
   369  			dir:      "/alloc",
   370  			expected: false,
   371  		},
   372  		{
   373  			name:     "fail joined relative path traverses outside sandbox",
   374  			path:     filepath.Join("/alloc", "../../../unsafe"),
   375  			dir:      "/alloc",
   376  			expected: true,
   377  		},
   378  		{
   379  			name:     "fail unjoined relative path traverses outside sandbox",
   380  			path:     "../../../unsafe",
   381  			dir:      "/alloc",
   382  			expected: true,
   383  		},
   384  		{
   385  			name:     "fail joined absolute path tries to transverse outside sandbox",
   386  			path:     filepath.Join("/alloc", "/alloc/../../unsafe"),
   387  			dir:      "/alloc",
   388  			expected: true,
   389  		},
   390  		{
   391  			name:     "fail unjoined absolute path tries to transverse outside sandbox",
   392  			path:     "/alloc/../../unsafe",
   393  			dir:      "/alloc",
   394  			expected: true,
   395  		},
   396  	}
   397  
   398  	for _, tc := range cases {
   399  		t.Run(tc.name, func(t *testing.T) {
   400  			caseMsg := fmt.Sprintf("path: %v\ndir: %v", tc.path, tc.dir)
   401  			escapes := PathEscapesSandbox(tc.dir, tc.path)
   402  			require.Equal(t, tc.expected, escapes, caseMsg)
   403  		})
   404  	}
   405  }