github.com/Ilhicas/nomad@v1.0.4-0.20210304152020-e86851182bc3/client/allocrunner/taskrunner/getter/getter_test.go (about)

     1  package getter
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"io/ioutil"
     7  	"mime"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"os"
    11  	"path/filepath"
    12  	"reflect"
    13  	"runtime"
    14  	"strings"
    15  	"testing"
    16  
    17  	"github.com/hashicorp/nomad/client/taskenv"
    18  	"github.com/hashicorp/nomad/helper"
    19  	"github.com/hashicorp/nomad/nomad/mock"
    20  	"github.com/hashicorp/nomad/nomad/structs"
    21  	"github.com/stretchr/testify/require"
    22  )
    23  
    24  // noopReplacer is a noop version of taskenv.TaskEnv.ReplaceEnv.
    25  type noopReplacer struct {
    26  	taskDir string
    27  }
    28  
    29  func clientPath(taskDir, path string, join bool) (string, bool) {
    30  	if !filepath.IsAbs(path) || (helper.PathEscapesSandbox(taskDir, path) && join) {
    31  		path = filepath.Join(taskDir, path)
    32  	}
    33  	path = filepath.Clean(path)
    34  	if taskDir != "" && !helper.PathEscapesSandbox(taskDir, path) {
    35  		return path, false
    36  	}
    37  	return path, true
    38  }
    39  
    40  func (noopReplacer) ReplaceEnv(s string) string {
    41  	return s
    42  }
    43  
    44  func (r noopReplacer) ClientPath(p string, join bool) (string, bool) {
    45  	path, escapes := clientPath(r.taskDir, r.ReplaceEnv(p), join)
    46  	return path, escapes
    47  }
    48  
    49  func noopTaskEnv(taskDir string) EnvReplacer {
    50  	return noopReplacer{
    51  		taskDir: taskDir,
    52  	}
    53  }
    54  
    55  // upperReplacer is a version of taskenv.TaskEnv.ReplaceEnv that upper-cases
    56  // the given input.
    57  type upperReplacer struct {
    58  	taskDir string
    59  }
    60  
    61  func (upperReplacer) ReplaceEnv(s string) string {
    62  	return strings.ToUpper(s)
    63  }
    64  
    65  func (u upperReplacer) ClientPath(p string, join bool) (string, bool) {
    66  	path, escapes := clientPath(u.taskDir, u.ReplaceEnv(p), join)
    67  	return path, escapes
    68  }
    69  
    70  func removeAllT(t *testing.T, path string) {
    71  	require.NoError(t, os.RemoveAll(path))
    72  }
    73  
    74  func TestGetArtifact_getHeaders(t *testing.T) {
    75  	t.Run("nil", func(t *testing.T) {
    76  		require.Nil(t, getHeaders(noopTaskEnv(""), nil))
    77  	})
    78  
    79  	t.Run("empty", func(t *testing.T) {
    80  		require.Nil(t, getHeaders(noopTaskEnv(""), make(map[string]string)))
    81  	})
    82  
    83  	t.Run("set", func(t *testing.T) {
    84  		upperTaskEnv := new(upperReplacer)
    85  		expected := make(http.Header)
    86  		expected.Set("foo", "BAR")
    87  		result := getHeaders(upperTaskEnv, map[string]string{
    88  			"foo": "bar",
    89  		})
    90  		require.Equal(t, expected, result)
    91  	})
    92  }
    93  
    94  func TestGetArtifact_Headers(t *testing.T) {
    95  	file := "output.txt"
    96  
    97  	// Create the test server with a handler that will validate headers are set.
    98  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    99  		// Validate the expected value for our header.
   100  		value := r.Header.Get("X-Some-Value")
   101  		require.Equal(t, "FOOBAR", value)
   102  
   103  		// Write the value to the file that is our artifact, for fun.
   104  		w.Header().Set("Content-Type", mime.TypeByExtension(filepath.Ext(file)))
   105  		w.WriteHeader(http.StatusOK)
   106  		_, err := io.Copy(w, strings.NewReader(value))
   107  		require.NoError(t, err)
   108  	}))
   109  	defer ts.Close()
   110  
   111  	// Create a temp directory to download into.
   112  	taskDir, err := ioutil.TempDir("", "nomad-test")
   113  	require.NoError(t, err)
   114  	defer removeAllT(t, taskDir)
   115  
   116  	// Create the artifact.
   117  	artifact := &structs.TaskArtifact{
   118  		GetterSource: fmt.Sprintf("%s/%s", ts.URL, file),
   119  		GetterHeaders: map[string]string{
   120  			"X-Some-Value": "foobar",
   121  		},
   122  		RelativeDest: file,
   123  		GetterMode:   "file",
   124  	}
   125  
   126  	// Download the artifact.
   127  	taskEnv := upperReplacer{
   128  		taskDir: taskDir,
   129  	}
   130  	err = GetArtifact(taskEnv, artifact)
   131  	require.NoError(t, err)
   132  
   133  	// Verify artifact exists.
   134  	b, err := ioutil.ReadFile(filepath.Join(taskDir, taskEnv.ReplaceEnv(file)))
   135  	require.NoError(t, err)
   136  
   137  	// Verify we wrote the interpolated header value into the file that is our
   138  	// artifact.
   139  	require.Equal(t, "FOOBAR", string(b))
   140  }
   141  
   142  func TestGetArtifact_FileAndChecksum(t *testing.T) {
   143  	// Create the test server hosting the file to download
   144  	ts := httptest.NewServer(http.FileServer(http.Dir(filepath.Dir("./test-fixtures/"))))
   145  	defer ts.Close()
   146  
   147  	// Create a temp directory to download into
   148  	taskDir, err := ioutil.TempDir("", "nomad-test")
   149  	if err != nil {
   150  		t.Fatalf("failed to make temp directory: %v", err)
   151  	}
   152  	defer removeAllT(t, taskDir)
   153  
   154  	// Create the artifact
   155  	file := "test.sh"
   156  	artifact := &structs.TaskArtifact{
   157  		GetterSource: fmt.Sprintf("%s/%s", ts.URL, file),
   158  		GetterOptions: map[string]string{
   159  			"checksum": "md5:bce963762aa2dbfed13caf492a45fb72",
   160  		},
   161  	}
   162  
   163  	// Download the artifact
   164  	if err := GetArtifact(noopTaskEnv(taskDir), artifact); err != nil {
   165  		t.Fatalf("GetArtifact failed: %v", err)
   166  	}
   167  
   168  	// Verify artifact exists
   169  	if _, err := os.Stat(filepath.Join(taskDir, file)); err != nil {
   170  		t.Fatalf("file not found: %s", err)
   171  	}
   172  }
   173  
   174  func TestGetArtifact_File_RelativeDest(t *testing.T) {
   175  	// Create the test server hosting the file to download
   176  	ts := httptest.NewServer(http.FileServer(http.Dir(filepath.Dir("./test-fixtures/"))))
   177  	defer ts.Close()
   178  
   179  	// Create a temp directory to download into
   180  	taskDir, err := ioutil.TempDir("", "nomad-test")
   181  	if err != nil {
   182  		t.Fatalf("failed to make temp directory: %v", err)
   183  	}
   184  	defer removeAllT(t, taskDir)
   185  
   186  	// Create the artifact
   187  	file := "test.sh"
   188  	relative := "foo/"
   189  	artifact := &structs.TaskArtifact{
   190  		GetterSource: fmt.Sprintf("%s/%s", ts.URL, file),
   191  		GetterOptions: map[string]string{
   192  			"checksum": "md5:bce963762aa2dbfed13caf492a45fb72",
   193  		},
   194  		RelativeDest: relative,
   195  	}
   196  
   197  	// Download the artifact
   198  	if err := GetArtifact(noopTaskEnv(taskDir), artifact); err != nil {
   199  		t.Fatalf("GetArtifact failed: %v", err)
   200  	}
   201  
   202  	// Verify artifact was downloaded to the correct path
   203  	if _, err := os.Stat(filepath.Join(taskDir, relative, file)); err != nil {
   204  		t.Fatalf("file not found: %s", err)
   205  	}
   206  }
   207  
   208  func TestGetArtifact_File_EscapeDest(t *testing.T) {
   209  	// Create the test server hosting the file to download
   210  	ts := httptest.NewServer(http.FileServer(http.Dir(filepath.Dir("./test-fixtures/"))))
   211  	defer ts.Close()
   212  
   213  	// Create a temp directory to download into
   214  	taskDir, err := ioutil.TempDir("", "nomad-test")
   215  	if err != nil {
   216  		t.Fatalf("failed to make temp directory: %v", err)
   217  	}
   218  	defer removeAllT(t, taskDir)
   219  
   220  	// Create the artifact
   221  	file := "test.sh"
   222  	relative := "../../../../foo/"
   223  	artifact := &structs.TaskArtifact{
   224  		GetterSource: fmt.Sprintf("%s/%s", ts.URL, file),
   225  		GetterOptions: map[string]string{
   226  			"checksum": "md5:bce963762aa2dbfed13caf492a45fb72",
   227  		},
   228  		RelativeDest: relative,
   229  	}
   230  
   231  	// attempt to download the artifact
   232  	err = GetArtifact(noopTaskEnv(taskDir), artifact)
   233  	if err == nil || !strings.Contains(err.Error(), "escapes") {
   234  		t.Fatalf("expected GetArtifact to disallow sandbox escape: %v", err)
   235  	}
   236  }
   237  
   238  func TestGetGetterUrl_Interpolation(t *testing.T) {
   239  	// Create the artifact
   240  	artifact := &structs.TaskArtifact{
   241  		GetterSource: "${NOMAD_META_ARTIFACT}",
   242  	}
   243  
   244  	url := "foo.com"
   245  	alloc := mock.Alloc()
   246  	task := alloc.Job.TaskGroups[0].Tasks[0]
   247  	task.Meta = map[string]string{"artifact": url}
   248  	taskEnv := taskenv.NewBuilder(mock.Node(), alloc, task, "global").Build()
   249  
   250  	act, err := getGetterUrl(taskEnv, artifact)
   251  	if err != nil {
   252  		t.Fatalf("getGetterUrl() failed: %v", err)
   253  	}
   254  
   255  	if act != url {
   256  		t.Fatalf("getGetterUrl() returned %q; want %q", act, url)
   257  	}
   258  }
   259  
   260  func TestGetArtifact_InvalidChecksum(t *testing.T) {
   261  	// Create the test server hosting the file to download
   262  	ts := httptest.NewServer(http.FileServer(http.Dir(filepath.Dir("./test-fixtures/"))))
   263  	defer ts.Close()
   264  
   265  	// Create a temp directory to download into
   266  	taskDir, err := ioutil.TempDir("", "nomad-test")
   267  	if err != nil {
   268  		t.Fatalf("failed to make temp directory: %v", err)
   269  	}
   270  	defer removeAllT(t, taskDir)
   271  
   272  	// Create the artifact with an incorrect checksum
   273  	file := "test.sh"
   274  	artifact := &structs.TaskArtifact{
   275  		GetterSource: fmt.Sprintf("%s/%s", ts.URL, file),
   276  		GetterOptions: map[string]string{
   277  			"checksum": "md5:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
   278  		},
   279  	}
   280  
   281  	// Download the artifact and expect an error
   282  	if err := GetArtifact(noopTaskEnv(taskDir), artifact); err == nil {
   283  		t.Fatalf("GetArtifact should have failed")
   284  	}
   285  }
   286  
   287  func createContents(basedir string, fileContents map[string]string, t *testing.T) {
   288  	for relPath, content := range fileContents {
   289  		folder := basedir
   290  		if strings.Index(relPath, "/") != -1 {
   291  			// Create the folder.
   292  			folder = filepath.Join(basedir, filepath.Dir(relPath))
   293  			if err := os.Mkdir(folder, 0777); err != nil {
   294  				t.Fatalf("failed to make directory: %v", err)
   295  			}
   296  		}
   297  
   298  		// Create a file in the existing folder.
   299  		file := filepath.Join(folder, filepath.Base(relPath))
   300  		if err := ioutil.WriteFile(file, []byte(content), 0777); err != nil {
   301  			t.Fatalf("failed to write data to file %v: %v", file, err)
   302  		}
   303  	}
   304  }
   305  
   306  func checkContents(basedir string, fileContents map[string]string, t *testing.T) {
   307  	for relPath, content := range fileContents {
   308  		path := filepath.Join(basedir, relPath)
   309  		actual, err := ioutil.ReadFile(path)
   310  		if err != nil {
   311  			t.Fatalf("failed to read file %q: %v", path, err)
   312  		}
   313  
   314  		if !reflect.DeepEqual(actual, []byte(content)) {
   315  			t.Fatalf("%q: expected %q; got %q", path, content, string(actual))
   316  		}
   317  	}
   318  }
   319  
   320  func TestGetArtifact_Archive(t *testing.T) {
   321  	// Create the test server hosting the file to download
   322  	ts := httptest.NewServer(http.FileServer(http.Dir(filepath.Dir("./test-fixtures/"))))
   323  	defer ts.Close()
   324  
   325  	// Create a temp directory to download into and create some of the same
   326  	// files that exist in the artifact to ensure they are overridden
   327  	taskDir, err := ioutil.TempDir("", "nomad-test")
   328  	if err != nil {
   329  		t.Fatalf("failed to make temp directory: %v", err)
   330  	}
   331  	defer removeAllT(t, taskDir)
   332  
   333  	create := map[string]string{
   334  		"exist/my.config": "to be replaced",
   335  		"untouched":       "existing top-level",
   336  	}
   337  	createContents(taskDir, create, t)
   338  
   339  	file := "archive.tar.gz"
   340  	artifact := &structs.TaskArtifact{
   341  		GetterSource: fmt.Sprintf("%s/%s", ts.URL, file),
   342  		GetterOptions: map[string]string{
   343  			"checksum": "sha1:20bab73c72c56490856f913cf594bad9a4d730f6",
   344  		},
   345  	}
   346  
   347  	if err := GetArtifact(noopTaskEnv(taskDir), artifact); err != nil {
   348  		t.Fatalf("GetArtifact failed: %v", err)
   349  	}
   350  
   351  	// Verify the unarchiving overrode files properly.
   352  	expected := map[string]string{
   353  		"untouched":       "existing top-level",
   354  		"exist/my.config": "hello world\n",
   355  		"new/my.config":   "hello world\n",
   356  		"test.sh":         "sleep 1\n",
   357  	}
   358  	checkContents(taskDir, expected, t)
   359  }
   360  
   361  func TestGetArtifact_Setuid(t *testing.T) {
   362  	// Create the test server hosting the file to download
   363  	ts := httptest.NewServer(http.FileServer(http.Dir(filepath.Dir("./test-fixtures/"))))
   364  	defer ts.Close()
   365  
   366  	// Create a temp directory to download into and create some of the same
   367  	// files that exist in the artifact to ensure they are overridden
   368  	taskDir, err := ioutil.TempDir("", "nomad-test")
   369  	require.NoError(t, err)
   370  	defer removeAllT(t, taskDir)
   371  
   372  	file := "setuid.tgz"
   373  	artifact := &structs.TaskArtifact{
   374  		GetterSource: fmt.Sprintf("%s/%s", ts.URL, file),
   375  		GetterOptions: map[string]string{
   376  			"checksum": "sha1:e892194748ecbad5d0f60c6c6b2db2bdaa384a90",
   377  		},
   378  	}
   379  
   380  	require.NoError(t, GetArtifact(noopTaskEnv(taskDir), artifact))
   381  
   382  	var expected map[string]int
   383  
   384  	if runtime.GOOS == "windows" {
   385  		// windows doesn't support Chmod changing file permissions.
   386  		expected = map[string]int{
   387  			"public":  0666,
   388  			"private": 0666,
   389  			"setuid":  0666,
   390  		}
   391  	} else {
   392  		// Verify the unarchiving masked files properly.
   393  		expected = map[string]int{
   394  			"public":  0666,
   395  			"private": 0600,
   396  			"setuid":  0755,
   397  		}
   398  	}
   399  
   400  	for file, perm := range expected {
   401  		path := filepath.Join(taskDir, "setuid", file)
   402  		s, err := os.Stat(path)
   403  		require.NoError(t, err)
   404  		p := os.FileMode(perm)
   405  		o := s.Mode()
   406  		require.Equalf(t, p, o, "%s expected %o found %o", file, p, o)
   407  	}
   408  }
   409  
   410  func TestGetGetterUrl_Queries(t *testing.T) {
   411  	cases := []struct {
   412  		name     string
   413  		artifact *structs.TaskArtifact
   414  		output   string
   415  	}{
   416  		{
   417  			name: "adds query parameters",
   418  			artifact: &structs.TaskArtifact{
   419  				GetterSource: "https://foo.com?test=1",
   420  				GetterOptions: map[string]string{
   421  					"foo": "bar",
   422  					"bam": "boom",
   423  				},
   424  			},
   425  			output: "https://foo.com?bam=boom&foo=bar&test=1",
   426  		},
   427  		{
   428  			name: "git without http",
   429  			artifact: &structs.TaskArtifact{
   430  				GetterSource: "github.com/hashicorp/nomad",
   431  				GetterOptions: map[string]string{
   432  					"ref": "abcd1234",
   433  				},
   434  			},
   435  			output: "github.com/hashicorp/nomad?ref=abcd1234",
   436  		},
   437  		{
   438  			name: "git using ssh",
   439  			artifact: &structs.TaskArtifact{
   440  				GetterSource: "git@github.com:hashicorp/nomad?sshkey=1",
   441  				GetterOptions: map[string]string{
   442  					"ref": "abcd1234",
   443  				},
   444  			},
   445  			output: "git@github.com:hashicorp/nomad?ref=abcd1234&sshkey=1",
   446  		},
   447  		{
   448  			name: "s3 scheme 1",
   449  			artifact: &structs.TaskArtifact{
   450  				GetterSource: "s3::https://s3.amazonaws.com/bucket/foo",
   451  				GetterOptions: map[string]string{
   452  					"aws_access_key_id": "abcd1234",
   453  				},
   454  			},
   455  			output: "s3::https://s3.amazonaws.com/bucket/foo?aws_access_key_id=abcd1234",
   456  		},
   457  		{
   458  			name: "s3 scheme 2",
   459  			artifact: &structs.TaskArtifact{
   460  				GetterSource: "s3::https://s3-eu-west-1.amazonaws.com/bucket/foo",
   461  				GetterOptions: map[string]string{
   462  					"aws_access_key_id": "abcd1234",
   463  				},
   464  			},
   465  			output: "s3::https://s3-eu-west-1.amazonaws.com/bucket/foo?aws_access_key_id=abcd1234",
   466  		},
   467  		{
   468  			name: "s3 scheme 3",
   469  			artifact: &structs.TaskArtifact{
   470  				GetterSource: "bucket.s3.amazonaws.com/foo",
   471  				GetterOptions: map[string]string{
   472  					"aws_access_key_id": "abcd1234",
   473  				},
   474  			},
   475  			output: "bucket.s3.amazonaws.com/foo?aws_access_key_id=abcd1234",
   476  		},
   477  		{
   478  			name: "s3 scheme 4",
   479  			artifact: &structs.TaskArtifact{
   480  				GetterSource: "bucket.s3-eu-west-1.amazonaws.com/foo/bar",
   481  				GetterOptions: map[string]string{
   482  					"aws_access_key_id": "abcd1234",
   483  				},
   484  			},
   485  			output: "bucket.s3-eu-west-1.amazonaws.com/foo/bar?aws_access_key_id=abcd1234",
   486  		},
   487  		{
   488  			name: "gcs",
   489  			artifact: &structs.TaskArtifact{
   490  				GetterSource: "gcs::https://www.googleapis.com/storage/v1/b/d/f",
   491  			},
   492  			output: "gcs::https://www.googleapis.com/storage/v1/b/d/f",
   493  		},
   494  		{
   495  			name: "local file",
   496  			artifact: &structs.TaskArtifact{
   497  				GetterSource: "/foo/bar",
   498  			},
   499  			output: "/foo/bar",
   500  		},
   501  	}
   502  
   503  	for _, c := range cases {
   504  		t.Run(c.name, func(t *testing.T) {
   505  			act, err := getGetterUrl(noopTaskEnv(""), c.artifact)
   506  			if err != nil {
   507  				t.Fatalf("want %q; got err %v", c.output, err)
   508  			} else if act != c.output {
   509  				t.Fatalf("want %q; got %q", c.output, act)
   510  			}
   511  		})
   512  	}
   513  }