github.com/mvdan/u-root-coreutils@v0.0.0-20230122170626-c2eef2898555/pkg/curl/schemes_test.go (about)

     1  // Copyright 2017-2020 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package curl
     6  
     7  import (
     8  	"context"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"net/http"
    13  	"net/http/httptest"
    14  	"net/url"
    15  	"testing"
    16  
    17  	"github.com/cenkalti/backoff/v4"
    18  	"github.com/mvdan/u-root-coreutils/pkg/uio"
    19  )
    20  
    21  var (
    22  	errTest = errors.New("Test error")
    23  	testURL = &url.URL{
    24  		Scheme: "fooftp",
    25  		Host:   "192.168.0.1",
    26  		Path:   "/foo/pxelinux.cfg/default",
    27  	}
    28  )
    29  
    30  var tests = []struct {
    31  	name string
    32  	// scheme returns a scheme for testing and a MockScheme to
    33  	// confirm number of calls to Fetch. The distinction is useful
    34  	// when MockScheme is decorated by a SchemeWithRetries. In many
    35  	// cases, the same value is returned twice.
    36  	scheme         func() (FileScheme, *MockScheme)
    37  	url            *url.URL
    38  	err            error
    39  	want           string
    40  	wantFetchCount uint
    41  }{
    42  	{
    43  		name: "successful fetch",
    44  		scheme: func() (FileScheme, *MockScheme) {
    45  			s := NewMockScheme("fooftp")
    46  			s.Add("192.168.0.1", "/foo/pxelinux.cfg/default", "haha")
    47  			return s, s
    48  		},
    49  		url:            testURL,
    50  		want:           "haha",
    51  		wantFetchCount: 1,
    52  	},
    53  	{
    54  		name: "scheme does not exist",
    55  		scheme: func() (FileScheme, *MockScheme) {
    56  			s := NewMockScheme("fooftp")
    57  			return s, s
    58  		},
    59  		url: &url.URL{
    60  			Scheme: "nosuch",
    61  		},
    62  		err:            ErrNoSuchScheme,
    63  		wantFetchCount: 0,
    64  	},
    65  	{
    66  		name: "host does not exist",
    67  		scheme: func() (FileScheme, *MockScheme) {
    68  			s := NewMockScheme("fooftp")
    69  			return s, s
    70  		},
    71  		url: &url.URL{
    72  			Scheme: "fooftp",
    73  			Host:   "someotherplace",
    74  		},
    75  		err:            ErrNoSuchHost,
    76  		wantFetchCount: 1,
    77  	},
    78  	{
    79  		name: "file does not exist",
    80  		scheme: func() (FileScheme, *MockScheme) {
    81  			s := NewMockScheme("fooftp")
    82  			s.Add("somehost", "somefile", "somecontent")
    83  			return s, s
    84  		},
    85  		url: &url.URL{
    86  			Scheme: "fooftp",
    87  			Host:   "somehost",
    88  			Path:   "/someotherfile",
    89  		},
    90  		err:            ErrNoSuchFile,
    91  		wantFetchCount: 1,
    92  	},
    93  	{
    94  		name: "always err",
    95  		scheme: func() (FileScheme, *MockScheme) {
    96  			s := NewMockScheme("fooftp")
    97  			s.Add("192.168.0.1", "/foo/pxelinux.cfg/default", "haha")
    98  			s.SetErr(errTest, 9999)
    99  			return s, s
   100  		},
   101  		url:            testURL,
   102  		err:            errTest,
   103  		wantFetchCount: 1,
   104  	},
   105  	{
   106  		name: "retries but not necessary",
   107  		scheme: func() (FileScheme, *MockScheme) {
   108  			s := NewMockScheme("fooftp")
   109  			s.Add("192.168.0.1", "/foo/pxelinux.cfg/default", "haha")
   110  			r := &SchemeWithRetries{
   111  				Scheme: s,
   112  				// backoff.ZeroBackOff so unit tests run fast.
   113  				BackOff: backoff.WithMaxRetries(&backoff.ZeroBackOff{}, 10),
   114  			}
   115  			return r, s
   116  		},
   117  		url:            testURL,
   118  		want:           "haha",
   119  		wantFetchCount: 1,
   120  	},
   121  	{
   122  		name: "not enough retries",
   123  		scheme: func() (FileScheme, *MockScheme) {
   124  			s := NewMockScheme("fooftp")
   125  			s.Add("192.168.0.1", "/foo/pxelinux.cfg/default", "haha")
   126  			s.SetErr(errTest, 9999)
   127  			r := &SchemeWithRetries{
   128  				Scheme: s,
   129  				// backoff.ZeroBackOff so unit tests run fast.
   130  				BackOff: backoff.WithMaxRetries(&backoff.ZeroBackOff{}, 10),
   131  			}
   132  			return r, s
   133  		},
   134  		url:            testURL,
   135  		err:            errTest,
   136  		wantFetchCount: 11,
   137  	},
   138  	{
   139  		name: "sufficient retries",
   140  		scheme: func() (FileScheme, *MockScheme) {
   141  			s := NewMockScheme("fooftp")
   142  			s.Add("192.168.0.1", "/foo/pxelinux.cfg/default", "haha")
   143  			s.SetErr(errTest, 5)
   144  			r := &SchemeWithRetries{
   145  				Scheme: s,
   146  				// backoff.ZeroBackOff so unit tests run fast.
   147  				BackOff: backoff.WithMaxRetries(&backoff.ZeroBackOff{}, 10),
   148  			}
   149  			return r, s
   150  		},
   151  		url:            testURL,
   152  		want:           "haha",
   153  		wantFetchCount: 6,
   154  	},
   155  	{
   156  		name: "retry filter",
   157  		scheme: func() (FileScheme, *MockScheme) {
   158  			s := NewMockScheme("fooftp")
   159  			s.Add("192.168.0.1", "/foo/pxelinux.cfg/default", "haha")
   160  			s.SetErr(errTest, 5)
   161  			r := &SchemeWithRetries{
   162  				DoRetry: func(u *url.URL, err error) bool {
   163  					return err != errTest
   164  				},
   165  				Scheme: s,
   166  				// backoff.ZeroBackOff so unit tests run fast.
   167  				BackOff: backoff.WithMaxRetries(&backoff.ZeroBackOff{}, 10),
   168  			}
   169  			return r, s
   170  		},
   171  		url:            testURL,
   172  		err:            errTest,
   173  		wantFetchCount: 1,
   174  	},
   175  }
   176  
   177  func TestFetchWithoutCache(t *testing.T) {
   178  	for i, tt := range tests {
   179  		t.Run(fmt.Sprintf("Test #%02d: %s", i, tt.name), func(t *testing.T) {
   180  			var r io.Reader
   181  			var err error
   182  
   183  			fs, ms := tt.scheme()
   184  			s := make(Schemes)
   185  			s.Register(ms.Scheme, fs)
   186  
   187  			r, err = s.FetchWithoutCache(context.TODO(), tt.url)
   188  			if uErr, ok := err.(*URLError); ok && uErr.Err != tt.err {
   189  				t.Errorf("FetchWithoutCache() = %v, want %v", uErr.Err, tt.err)
   190  			} else if !ok && err != tt.err {
   191  				t.Errorf("FetchWithoutCache() = %v, want %v", err, tt.err)
   192  			}
   193  
   194  			// Check number of calls before reading the file.
   195  			numCalled := ms.NumCalled(tt.url)
   196  			if numCalled != tt.wantFetchCount {
   197  				t.Errorf("number times Fetch() called = %v, want %v",
   198  					ms.NumCalled(tt.url), tt.wantFetchCount)
   199  			}
   200  			if err != nil {
   201  				return
   202  			}
   203  
   204  			// Read the entire file.
   205  			content, err := io.ReadAll(r)
   206  			if err != nil {
   207  				t.Errorf("bytes.Buffer read returned an error? %v", err)
   208  			}
   209  			if got, want := string(content), tt.want; got != want {
   210  				t.Errorf("Fetch() = %v, want %v", got, want)
   211  			}
   212  
   213  			// Check number of calls after reading the file.
   214  			numCalled = ms.NumCalled(tt.url)
   215  			if numCalled != tt.wantFetchCount {
   216  				t.Errorf("number times Fetch() called = %v, want %v",
   217  					ms.NumCalled(tt.url), tt.wantFetchCount)
   218  			}
   219  		})
   220  	}
   221  }
   222  
   223  func TestFetch(t *testing.T) {
   224  	for i, tt := range tests {
   225  		t.Run(fmt.Sprintf("Test #%02d: %s", i, tt.name), func(t *testing.T) {
   226  			var r io.ReaderAt
   227  			var err error
   228  
   229  			fs, ms := tt.scheme()
   230  			s := make(Schemes)
   231  			s.Register(ms.Scheme, fs)
   232  
   233  			r, err = s.Fetch(context.TODO(), tt.url)
   234  			if uErr, ok := err.(*URLError); ok && uErr.Err != tt.err {
   235  				t.Errorf("Fetch() = %v, want %v", uErr.Err, tt.err)
   236  			} else if !ok && err != tt.err {
   237  				t.Errorf("Fetch() = %v, want %v", err, tt.err)
   238  			}
   239  
   240  			// Check number of calls before reading the file.
   241  			numCalled := ms.NumCalled(tt.url)
   242  			if numCalled != tt.wantFetchCount {
   243  				t.Errorf("number times Fetch() called = %v, want %v",
   244  					ms.NumCalled(tt.url), tt.wantFetchCount)
   245  			}
   246  			if err != nil {
   247  				return
   248  			}
   249  
   250  			// Read the entire file.
   251  			content, err := io.ReadAll(uio.Reader(r))
   252  			if err != nil {
   253  				t.Errorf("bytes.Buffer read returned an error? %v", err)
   254  			}
   255  			if got, want := string(content), tt.want; got != want {
   256  				t.Errorf("Fetch() = %v, want %v", got, want)
   257  			}
   258  
   259  			// Check number of calls after reading the file.
   260  			numCalled = ms.NumCalled(tt.url)
   261  			if numCalled != tt.wantFetchCount {
   262  				t.Errorf("number times Fetch() called = %v, want %v",
   263  					ms.NumCalled(tt.url), tt.wantFetchCount)
   264  			}
   265  		})
   266  	}
   267  }
   268  
   269  func TestLazyFetch(t *testing.T) {
   270  	for i, tt := range tests {
   271  		t.Run(fmt.Sprintf("Test #%02d: %s", i, tt.name), func(t *testing.T) {
   272  			var r io.ReaderAt
   273  			var err error
   274  
   275  			fs, ms := tt.scheme()
   276  			s := make(Schemes)
   277  			s.Register(ms.Scheme, fs)
   278  
   279  			r, err = s.LazyFetch(tt.url)
   280  			// Errors are deferred to when file is read except for ErrNoSuchScheme.
   281  			if tt.err == ErrNoSuchScheme {
   282  				if uErr, ok := err.(*URLError); ok && uErr.Err != ErrNoSuchScheme {
   283  					t.Errorf("LazyFetch() = %v, want %v", uErr.Err, tt.err)
   284  				}
   285  			} else if err != nil {
   286  				t.Errorf("LazyFetch() = %v, want nil", err)
   287  			}
   288  
   289  			// Check number of calls before reading the file.
   290  			numCalled := ms.NumCalled(tt.url)
   291  			if numCalled != 0 {
   292  				t.Errorf("number times Fetch() called = %v, want 0", numCalled)
   293  			}
   294  			if err != nil {
   295  				return
   296  			}
   297  
   298  			// Read the entire file.
   299  			content, err := io.ReadAll(uio.Reader(r))
   300  			if uErr, ok := err.(*URLError); ok && uErr.Err != tt.err {
   301  				t.Errorf("ReadAll() = %v, want %v", uErr.Err, tt.err)
   302  			} else if !ok && err != tt.err {
   303  				t.Errorf("ReadAll() = %v, want %v", err, tt.err)
   304  			}
   305  			if got, want := string(content), tt.want; got != want {
   306  				t.Errorf("ReadAll() = %v, want %v", got, want)
   307  			}
   308  
   309  			// Check number of calls after reading the file.
   310  			numCalled = ms.NumCalled(tt.url)
   311  			if numCalled != tt.wantFetchCount {
   312  				t.Errorf("number times Fetch() called = %v, want %v",
   313  					ms.NumCalled(tt.url), tt.wantFetchCount)
   314  			}
   315  		})
   316  	}
   317  }
   318  
   319  func TestHttpFetches(t *testing.T) {
   320  	c := "fetch content"
   321  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   322  		fmt.Fprint(w, c)
   323  	}))
   324  	defer ts.Close()
   325  
   326  	fURL, err := url.Parse(ts.URL)
   327  	if err != nil {
   328  		t.Fatalf("url.Parse(%s) = %v, want no error", ts.URL, err)
   329  	}
   330  
   331  	// Fetch need to fetch the content as is.
   332  	fetchFile, err := Fetch(context.Background(), fURL)
   333  	if err != nil {
   334  		t.Errorf("Fetch(context.Background(), %s) = %v, want no error", fURL, err)
   335  	}
   336  	got, err := io.ReadAll(io.NewSectionReader(fetchFile, 0, int64(len(c))))
   337  	if err != nil {
   338  		t.Errorf("io.ReadAll(%v) = %v, want no error", fetchFile, err)
   339  	}
   340  	if string(got) != c {
   341  		t.Errorf("got %s, want %s", got, c)
   342  	}
   343  
   344  	// FetchWithoutCache need to fetch the content as is.
   345  	fetchFileNoCache, err := FetchWithoutCache(context.Background(), fURL)
   346  	if err != nil {
   347  		t.Errorf("FetchWithoutCache(context.Background(), %s) = %v, want no error", fURL, err)
   348  	}
   349  	got, err = io.ReadAll(fetchFileNoCache)
   350  	if err != nil {
   351  		t.Errorf("io.ReadAll(%s) = %v, want no error", fetchFileNoCache, err)
   352  	}
   353  	if string(got) != c {
   354  		t.Errorf("got %s, want %s", got, c)
   355  	}
   356  }