github.com/pelicanplatform/pelican@v1.0.5/client/handle_http_test.go (about)

     1  /***************************************************************
     2   *
     3   * Copyright (C) 2023, University of Nebraska-Lincoln
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License"); you
     6   * may not use this file except in compliance with the License.  You may
     7   * obtain a copy of the License at
     8   *
     9   *    http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   ***************************************************************/
    18  
    19  package client
    20  
    21  import (
    22  	"bytes"
    23  	"net"
    24  	"net/http"
    25  	"net/http/httptest"
    26  	"net/http/httputil"
    27  	"net/url"
    28  	"os"
    29  	"path/filepath"
    30  	"strings"
    31  	"testing"
    32  	"time"
    33  
    34  	"github.com/spf13/viper"
    35  	"github.com/stretchr/testify/assert"
    36  
    37  	"github.com/pelicanplatform/pelican/config"
    38  	"github.com/pelicanplatform/pelican/namespaces"
    39  )
    40  
    41  func TestMain(m *testing.M) {
    42  	if err := config.InitClient(); err != nil {
    43  		os.Exit(1)
    44  	}
    45  	os.Exit(m.Run())
    46  }
    47  
    48  // TestIsPort calls main.hasPort with a hostname, checking
    49  // for a valid return value.
    50  func TestIsPort(t *testing.T) {
    51  
    52  	if HasPort("blah.not.port:") {
    53  		t.Fatal("Failed to parse port when : at end")
    54  	}
    55  
    56  	if !HasPort("host:1") {
    57  		t.Fatal("Failed to parse with port = 1")
    58  	}
    59  
    60  	if HasPort("https://example.com") {
    61  		t.Fatal("Failed when scheme is specified")
    62  	}
    63  }
    64  
    65  // TestNewTransferDetails checks the creation of transfer details
    66  func TestNewTransferDetails(t *testing.T) {
    67  	os.Setenv("http_proxy", "http://proxy.edu:3128")
    68  
    69  	// Case 1: cache with http
    70  	testCache := namespaces.Cache{
    71  		AuthEndpoint: "cache.edu:8443",
    72  		Endpoint:     "cache.edu:8000",
    73  		Resource:     "Cache",
    74  	}
    75  	transfers := NewTransferDetails(testCache, TransferDetailsOptions{false, ""})
    76  	assert.Equal(t, 2, len(transfers))
    77  	assert.Equal(t, "cache.edu:8000", transfers[0].Url.Host)
    78  	assert.Equal(t, "http", transfers[0].Url.Scheme)
    79  	assert.Equal(t, true, transfers[0].Proxy)
    80  	assert.Equal(t, "cache.edu:8000", transfers[1].Url.Host)
    81  	assert.Equal(t, "http", transfers[1].Url.Scheme)
    82  	assert.Equal(t, false, transfers[1].Proxy)
    83  
    84  	// Case 2: cache with https
    85  	transfers = NewTransferDetails(testCache, TransferDetailsOptions{true, ""})
    86  	assert.Equal(t, 1, len(transfers))
    87  	assert.Equal(t, "cache.edu:8443", transfers[0].Url.Host)
    88  	assert.Equal(t, "https", transfers[0].Url.Scheme)
    89  	assert.Equal(t, false, transfers[0].Proxy)
    90  
    91  	testCache.Endpoint = "cache.edu"
    92  	// Case 3: cache without port with http
    93  	transfers = NewTransferDetails(testCache, TransferDetailsOptions{false, ""})
    94  	assert.Equal(t, 2, len(transfers))
    95  	assert.Equal(t, "cache.edu:8000", transfers[0].Url.Host)
    96  	assert.Equal(t, "http", transfers[0].Url.Scheme)
    97  	assert.Equal(t, true, transfers[0].Proxy)
    98  	assert.Equal(t, "cache.edu:8000", transfers[1].Url.Host)
    99  	assert.Equal(t, "http", transfers[1].Url.Scheme)
   100  	assert.Equal(t, false, transfers[1].Proxy)
   101  
   102  	// Case 4. cache without port with https
   103  	testCache.AuthEndpoint = "cache.edu"
   104  	transfers = NewTransferDetails(testCache, TransferDetailsOptions{true, ""})
   105  	assert.Equal(t, 2, len(transfers))
   106  	assert.Equal(t, "cache.edu:8444", transfers[0].Url.Host)
   107  	assert.Equal(t, "https", transfers[0].Url.Scheme)
   108  	assert.Equal(t, false, transfers[0].Proxy)
   109  	assert.Equal(t, "cache.edu:8443", transfers[1].Url.Host)
   110  	assert.Equal(t, "https", transfers[1].Url.Scheme)
   111  	assert.Equal(t, false, transfers[1].Proxy)
   112  }
   113  
   114  func TestNewTransferDetailsEnv(t *testing.T) {
   115  
   116  	testCache := namespaces.Cache{
   117  		AuthEndpoint: "cache.edu:8443",
   118  		Endpoint:     "cache.edu:8000",
   119  		Resource:     "Cache",
   120  	}
   121  
   122  	os.Setenv("OSG_DISABLE_PROXY_FALLBACK", "")
   123  	err := config.InitClient()
   124  	assert.Nil(t, err)
   125  	transfers := NewTransferDetails(testCache, TransferDetailsOptions{false, ""})
   126  	assert.Equal(t, 1, len(transfers))
   127  	assert.Equal(t, true, transfers[0].Proxy)
   128  
   129  	transfers = NewTransferDetails(testCache, TransferDetailsOptions{true, ""})
   130  	assert.Equal(t, 1, len(transfers))
   131  	assert.Equal(t, "https", transfers[0].Url.Scheme)
   132  	assert.Equal(t, false, transfers[0].Proxy)
   133  	os.Unsetenv("OSG_DISABLE_PROXY_FALLBACK")
   134  	viper.Reset()
   135  	err = config.InitClient()
   136  	assert.Nil(t, err)
   137  }
   138  
   139  func TestSlowTransfers(t *testing.T) {
   140  	// Adjust down some timeouts to speed up the test
   141  	viper.Set("Client.SlowTransferWindow", 5)
   142  	viper.Set("Client.SlowTransferRampupTime", 10)
   143  
   144  	channel := make(chan bool)
   145  	slowDownload := 1024 * 10 // 10 KiB/s < 100 KiB/s
   146  	svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   147  		buffer := make([]byte, slowDownload)
   148  		for {
   149  			select {
   150  			case <-channel:
   151  				return
   152  			default:
   153  				_, err := w.Write(buffer)
   154  				if err != nil {
   155  					return
   156  				}
   157  				w.(http.Flusher).Flush()
   158  				time.Sleep(1 * time.Second)
   159  			}
   160  		}
   161  	}))
   162  
   163  	defer svr.CloseClientConnections()
   164  	defer svr.Close()
   165  
   166  	testCache := namespaces.Cache{
   167  		AuthEndpoint: svr.URL,
   168  		Endpoint:     svr.URL,
   169  		Resource:     "Cache",
   170  	}
   171  	transfers := NewTransferDetails(testCache, TransferDetailsOptions{false, ""})
   172  	assert.Equal(t, 2, len(transfers))
   173  	assert.Equal(t, svr.URL, transfers[0].Url.String())
   174  
   175  	finishedChannel := make(chan bool)
   176  	var err error
   177  	// Do a quick timeout
   178  	go func() {
   179  		_, err = DownloadHTTP(transfers[0], filepath.Join(t.TempDir(), "test.txt"), "")
   180  		finishedChannel <- true
   181  	}()
   182  
   183  	select {
   184  	case <-finishedChannel:
   185  		if err == nil {
   186  			t.Fatal("Error is nil, download should have failed")
   187  		}
   188  	case <-time.After(time.Second * 160):
   189  		// 120 seconds for warmup, 30 seconds for download
   190  		t.Fatal("Maximum downloading time reach, download should have failed")
   191  	}
   192  
   193  	// Close the channel to allow the download to complete
   194  	channel <- true
   195  
   196  	// Make sure the errors are correct
   197  	assert.NotNil(t, err)
   198  	assert.IsType(t, &SlowTransferError{}, err)
   199  }
   200  
   201  // Test stopped transfer
   202  func TestStoppedTransfer(t *testing.T) {
   203  	// Adjust down the timeouts
   204  	viper.Set("Client.StoppedTransferTimeout", 3)
   205  	viper.Set("Client.SlowTransferRampupTime", 100)
   206  
   207  	channel := make(chan bool)
   208  	svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   209  		buffer := make([]byte, 1024*100)
   210  		for {
   211  			select {
   212  			case <-channel:
   213  				return
   214  			default:
   215  				_, err := w.Write(buffer)
   216  				if err != nil {
   217  					return
   218  				}
   219  				w.(http.Flusher).Flush()
   220  				time.Sleep(1 * time.Second)
   221  				buffer = make([]byte, 0)
   222  			}
   223  		}
   224  	}))
   225  
   226  	defer svr.CloseClientConnections()
   227  	defer svr.Close()
   228  
   229  	testCache := namespaces.Cache{
   230  		AuthEndpoint: svr.URL,
   231  		Endpoint:     svr.URL,
   232  		Resource:     "Cache",
   233  	}
   234  	transfers := NewTransferDetails(testCache, TransferDetailsOptions{false, ""})
   235  	assert.Equal(t, 2, len(transfers))
   236  	assert.Equal(t, svr.URL, transfers[0].Url.String())
   237  
   238  	finishedChannel := make(chan bool)
   239  	var err error
   240  
   241  	go func() {
   242  		_, err = DownloadHTTP(transfers[0], filepath.Join(t.TempDir(), "test.txt"), "")
   243  		finishedChannel <- true
   244  	}()
   245  
   246  	select {
   247  	case <-finishedChannel:
   248  		if err == nil {
   249  			t.Fatal("Download should have failed")
   250  		}
   251  	case <-time.After(time.Second * 150):
   252  		t.Fatal("Download should have failed")
   253  	}
   254  
   255  	// Close the channel to allow the download to complete
   256  	channel <- true
   257  
   258  	// Make sure the errors are correct
   259  	assert.NotNil(t, err)
   260  	assert.IsType(t, &StoppedTransferError{}, err, err.Error())
   261  }
   262  
   263  // Test connection error
   264  func TestConnectionError(t *testing.T) {
   265  	l, err := net.Listen("tcp", "127.0.0.1:0")
   266  	if err != nil {
   267  		t.Fatalf("dialClosedPort: Listen failed: %v", err)
   268  	}
   269  	addr := l.Addr().String()
   270  	l.Close()
   271  
   272  	_, err = DownloadHTTP(TransferDetails{Url: url.URL{Host: addr, Scheme: "http"}, Proxy: false}, filepath.Join(t.TempDir(), "test.txt"), "")
   273  
   274  	assert.IsType(t, &ConnectionSetupError{}, err)
   275  
   276  }
   277  
   278  func TestTrailerError(t *testing.T) {
   279  	// Set up an HTTP server that returns an error trailer
   280  	svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   281  		w.Header().Set("Trailer", "X-Transfer-Status")
   282  		w.Header().Set("X-Transfer-Status", "500: Unable to read test.txt; input/output error")
   283  
   284  		chunkedWriter := httputil.NewChunkedWriter(w)
   285  		defer chunkedWriter.Close()
   286  
   287  		_, err := chunkedWriter.Write([]byte("Test data"))
   288  		if err != nil {
   289  			t.Fatalf("Error writing to chunked writer: %v", err)
   290  		}
   291  	}))
   292  
   293  	defer svr.Close()
   294  
   295  	testCache := namespaces.Cache{
   296  		AuthEndpoint: svr.URL,
   297  		Endpoint:     svr.URL,
   298  		Resource:     "Cache",
   299  	}
   300  	transfers := NewTransferDetails(testCache, TransferDetailsOptions{false, ""})
   301  	assert.Equal(t, 2, len(transfers))
   302  	assert.Equal(t, svr.URL, transfers[0].Url.String())
   303  
   304  	// Call DownloadHTTP and check if the error is returned correctly
   305  	_, err := DownloadHTTP(transfers[0], filepath.Join(t.TempDir(), "test.txt"), "")
   306  
   307  	assert.NotNil(t, err)
   308  	assert.EqualError(t, err, "transfer error: Unable to read test.txt; input/output error")
   309  }
   310  
   311  func TestUploadZeroLengthFile(t *testing.T) {
   312  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   313  
   314  		//t.Logf("%s", dump)
   315  		assert.Equal(t, "PUT", r.Method, "Not PUT Method")
   316  		assert.Equal(t, int64(0), r.ContentLength, "ContentLength should be 0")
   317  	}))
   318  	defer ts.Close()
   319  	reader := bytes.NewReader([]byte{})
   320  	request, err := http.NewRequest("PUT", ts.URL, reader)
   321  	if err != nil {
   322  		assert.NoError(t, err)
   323  	}
   324  
   325  	request.Header.Set("Authorization", "Bearer test")
   326  	errorChan := make(chan error, 1)
   327  	responseChan := make(chan *http.Response)
   328  	go doPut(request, responseChan, errorChan)
   329  	select {
   330  	case err := <-errorChan:
   331  		assert.NoError(t, err)
   332  	case response := <-responseChan:
   333  		assert.Equal(t, http.StatusOK, response.StatusCode)
   334  	case <-time.After(time.Second * 2):
   335  		assert.Fail(t, "Timeout while waiting for response")
   336  	}
   337  }
   338  
   339  func TestFailedUpload(t *testing.T) {
   340  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   341  
   342  		//t.Logf("%s", dump)
   343  		assert.Equal(t, "PUT", r.Method, "Not PUT Method")
   344  		w.WriteHeader(http.StatusInternalServerError)
   345  		_, err := w.Write([]byte("Error"))
   346  		assert.NoError(t, err)
   347  
   348  	}))
   349  	defer ts.Close()
   350  	reader := strings.NewReader("test")
   351  	request, err := http.NewRequest("PUT", ts.URL, reader)
   352  	if err != nil {
   353  		assert.NoError(t, err)
   354  	}
   355  	request.Header.Set("Authorization", "Bearer test")
   356  	errorChan := make(chan error, 1)
   357  	responseChan := make(chan *http.Response)
   358  	go doPut(request, responseChan, errorChan)
   359  	select {
   360  	case err := <-errorChan:
   361  		assert.Error(t, err)
   362  	case response := <-responseChan:
   363  		assert.Equal(t, http.StatusInternalServerError, response.StatusCode)
   364  	case <-time.After(time.Second * 2):
   365  		assert.Fail(t, "Timeout while waiting for response")
   366  	}
   367  }
   368  
   369  func TestFullUpload(t *testing.T) {
   370  	testFileContent := "test file content"
   371  	ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   372  
   373  		//t.Logf("%s", dump)
   374  		assert.Equal(t, "PUT", r.Method, "Not PUT Method")
   375  		_, err := w.Write([]byte(":)"))
   376  		assert.NoError(t, err)
   377  	}))
   378  	defer ts.Close()
   379  
   380  	// Create the temporary file to upload
   381  	tempFile, err := os.CreateTemp(t.TempDir(), "test")
   382  	assert.NoError(t, err, "Error creating temp file")
   383  	defer os.Remove(tempFile.Name())
   384  	_, err = tempFile.WriteString(testFileContent)
   385  	assert.NoError(t, err, "Error writing to temp file")
   386  	tempFile.Close()
   387  
   388  	// Create the namespace (only the write back host is read)
   389  	testURL, err := url.Parse(ts.URL)
   390  	assert.NoError(t, err, "Error parsing test URL")
   391  	testNamespace := namespaces.Namespace{
   392  		WriteBackHost: "https://" + testURL.Host,
   393  	}
   394  
   395  	// Upload the file
   396  	uploadURL, err := url.Parse("stash:///test/stuff/blah.txt")
   397  	assert.NoError(t, err, "Error parsing upload URL")
   398  	// Set the upload client to trust the server
   399  	UploadClient = ts.Client()
   400  	uploaded, err := UploadFile(tempFile.Name(), uploadURL, "Bearer test", testNamespace)
   401  	assert.NoError(t, err, "Error uploading file")
   402  	assert.Equal(t, int64(len(testFileContent)), uploaded, "Uploaded file size does not match")
   403  
   404  	// Upload an osdf file
   405  	uploadURL, err = url.Parse("osdf:///test/stuff/blah.txt")
   406  	assert.NoError(t, err, "Error parsing upload URL")
   407  	// Set the upload client to trust the server
   408  	UploadClient = ts.Client()
   409  	uploaded, err = UploadFile(tempFile.Name(), uploadURL, "Bearer test", testNamespace)
   410  	assert.NoError(t, err, "Error uploading file")
   411  	assert.Equal(t, int64(len(testFileContent)), uploaded, "Uploaded file size does not match")
   412  }