github.com/zaolin/u-root@v0.0.0-20200428085104-64aaafd46c6d/pkg/curl/schemes_test.go (about)

     1  // Copyright 2017-2020 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package curl
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"io/ioutil"
    12  	"net/url"
    13  	"testing"
    14  
    15  	"github.com/cenkalti/backoff"
    16  	"github.com/u-root/u-root/pkg/uio"
    17  )
    18  
    19  var (
    20  	errTest = errors.New("Test error")
    21  	testURL = &url.URL{
    22  		Scheme: "fooftp",
    23  		Host:   "192.168.0.1",
    24  		Path:   "/foo/pxelinux.cfg/default",
    25  	}
    26  )
    27  
    28  var tests = []struct {
    29  	name string
    30  	// scheme returns a scheme for testing and a MockScheme to
    31  	// confirm number of calls to Fetch. The distinction is useful
    32  	// when MockScheme is decorated by a SchemeWithRetries. In many
    33  	// cases, the same value is returned twice.
    34  	scheme         func() (FileScheme, *MockScheme)
    35  	url            *url.URL
    36  	err            error
    37  	want           string
    38  	wantFetchCount uint
    39  }{
    40  	{
    41  		name: "successful fetch",
    42  		scheme: func() (FileScheme, *MockScheme) {
    43  			s := NewMockScheme("fooftp")
    44  			s.Add("192.168.0.1", "/foo/pxelinux.cfg/default", "haha")
    45  			return s, s
    46  		},
    47  		url:            testURL,
    48  		want:           "haha",
    49  		wantFetchCount: 1,
    50  	},
    51  	{
    52  		name: "scheme does not exist",
    53  		scheme: func() (FileScheme, *MockScheme) {
    54  			s := NewMockScheme("fooftp")
    55  			return s, s
    56  		},
    57  		url: &url.URL{
    58  			Scheme: "nosuch",
    59  		},
    60  		err:            ErrNoSuchScheme,
    61  		wantFetchCount: 0,
    62  	},
    63  	{
    64  		name: "host does not exist",
    65  		scheme: func() (FileScheme, *MockScheme) {
    66  			s := NewMockScheme("fooftp")
    67  			return s, s
    68  		},
    69  		url: &url.URL{
    70  			Scheme: "fooftp",
    71  			Host:   "someotherplace",
    72  		},
    73  		err:            ErrNoSuchHost,
    74  		wantFetchCount: 1,
    75  	},
    76  	{
    77  		name: "file does not exist",
    78  		scheme: func() (FileScheme, *MockScheme) {
    79  			s := NewMockScheme("fooftp")
    80  			s.Add("somehost", "somefile", "somecontent")
    81  			return s, s
    82  		},
    83  		url: &url.URL{
    84  			Scheme: "fooftp",
    85  			Host:   "somehost",
    86  			Path:   "/someotherfile",
    87  		},
    88  		err:            ErrNoSuchFile,
    89  		wantFetchCount: 1,
    90  	},
    91  	{
    92  		name: "always err",
    93  		scheme: func() (FileScheme, *MockScheme) {
    94  			s := NewMockScheme("fooftp")
    95  			s.Add("192.168.0.1", "/foo/pxelinux.cfg/default", "haha")
    96  			s.SetErr(errTest, 9999)
    97  			return s, s
    98  		},
    99  		url:            testURL,
   100  		err:            errTest,
   101  		wantFetchCount: 1,
   102  	},
   103  	{
   104  		name: "retries but not necessary",
   105  		scheme: func() (FileScheme, *MockScheme) {
   106  			s := NewMockScheme("fooftp")
   107  			s.Add("192.168.0.1", "/foo/pxelinux.cfg/default", "haha")
   108  			r := &SchemeWithRetries{
   109  				Scheme: s,
   110  				// backoff.ZeroBackOff so unit tests run fast.
   111  				BackOff: backoff.WithMaxRetries(&backoff.ZeroBackOff{}, 10),
   112  			}
   113  			return r, s
   114  		},
   115  		url:            testURL,
   116  		want:           "haha",
   117  		wantFetchCount: 1,
   118  	},
   119  	{
   120  		name: "not enough retries",
   121  		scheme: func() (FileScheme, *MockScheme) {
   122  			s := NewMockScheme("fooftp")
   123  			s.Add("192.168.0.1", "/foo/pxelinux.cfg/default", "haha")
   124  			s.SetErr(errTest, 9999)
   125  			r := &SchemeWithRetries{
   126  				Scheme: s,
   127  				// backoff.ZeroBackOff so unit tests run fast.
   128  				BackOff: backoff.WithMaxRetries(&backoff.ZeroBackOff{}, 10),
   129  			}
   130  			return r, s
   131  		},
   132  		url:            testURL,
   133  		err:            errTest,
   134  		wantFetchCount: 11,
   135  	},
   136  	{
   137  		name: "sufficient retries",
   138  		scheme: func() (FileScheme, *MockScheme) {
   139  			s := NewMockScheme("fooftp")
   140  			s.Add("192.168.0.1", "/foo/pxelinux.cfg/default", "haha")
   141  			s.SetErr(errTest, 5)
   142  			r := &SchemeWithRetries{
   143  				Scheme: s,
   144  				// backoff.ZeroBackOff so unit tests run fast.
   145  				BackOff: backoff.WithMaxRetries(&backoff.ZeroBackOff{}, 10),
   146  			}
   147  			return r, s
   148  		},
   149  		url:            testURL,
   150  		want:           "haha",
   151  		wantFetchCount: 6,
   152  	},
   153  	{
   154  		name: "retry filter",
   155  		scheme: func() (FileScheme, *MockScheme) {
   156  			s := NewMockSchemeRetryFilter("fooftp")
   157  			s.Add("192.168.0.1", "/foo/pxelinux.cfg/default", "haha")
   158  			s.SetErr(errTest, 5)
   159  			s.SetRetryFilter(func(u *url.URL, err error) bool {
   160  				return err != errTest
   161  			})
   162  			r := &SchemeWithRetries{
   163  				Scheme: s,
   164  				// backoff.ZeroBackOff so unit tests run fast.
   165  				BackOff: backoff.WithMaxRetries(&backoff.ZeroBackOff{}, 10),
   166  			}
   167  			return r, s.MockScheme
   168  		},
   169  		url:            testURL,
   170  		err:            errTest,
   171  		wantFetchCount: 1,
   172  	},
   173  }
   174  
   175  func TestFetch(t *testing.T) {
   176  	for i, tt := range tests {
   177  		t.Run(fmt.Sprintf("Test #%02d: %s", i, tt.name), func(t *testing.T) {
   178  			var r io.ReaderAt
   179  			var err error
   180  
   181  			fs, ms := tt.scheme()
   182  			s := make(Schemes)
   183  			s.Register(ms.Scheme, fs)
   184  
   185  			r, err = s.Fetch(tt.url)
   186  			if uErr, ok := err.(*URLError); ok && uErr.Err != tt.err {
   187  				t.Errorf("Fetch() = %v, want %v", uErr.Err, tt.err)
   188  			} else if !ok && err != tt.err {
   189  				t.Errorf("Fetch() = %v, want %v", err, tt.err)
   190  			}
   191  
   192  			// Check number of calls before reading the file.
   193  			numCalled := ms.NumCalled(tt.url)
   194  			if numCalled != tt.wantFetchCount {
   195  				t.Errorf("number times Fetch() called = %v, want %v",
   196  					ms.NumCalled(tt.url), tt.wantFetchCount)
   197  			}
   198  			if err != nil {
   199  				return
   200  			}
   201  
   202  			// Read the entire file.
   203  			content, err := ioutil.ReadAll(uio.Reader(r))
   204  			if err != nil {
   205  				t.Errorf("bytes.Buffer read returned an error? %v", err)
   206  			}
   207  			if got, want := string(content), tt.want; got != want {
   208  				t.Errorf("Fetch() = %v, want %v", got, want)
   209  			}
   210  
   211  			// Check number of calls after reading the file.
   212  			numCalled = ms.NumCalled(tt.url)
   213  			if numCalled != tt.wantFetchCount {
   214  				t.Errorf("number times Fetch() called = %v, want %v",
   215  					ms.NumCalled(tt.url), tt.wantFetchCount)
   216  			}
   217  		})
   218  	}
   219  }
   220  
   221  func TestLazyFetch(t *testing.T) {
   222  	for i, tt := range tests {
   223  		t.Run(fmt.Sprintf("Test #%02d: %s", i, tt.name), func(t *testing.T) {
   224  			var r io.ReaderAt
   225  			var err error
   226  
   227  			fs, ms := tt.scheme()
   228  			s := make(Schemes)
   229  			s.Register(ms.Scheme, fs)
   230  
   231  			r, err = s.LazyFetch(tt.url)
   232  			// Errors are deferred to when file is read except for ErrNoSuchScheme.
   233  			if tt.err == ErrNoSuchScheme {
   234  				if uErr, ok := err.(*URLError); ok && uErr.Err != ErrNoSuchScheme {
   235  					t.Errorf("LazyFetch() = %v, want %v", uErr.Err, tt.err)
   236  				}
   237  			} else if err != nil {
   238  				t.Errorf("LazyFetch() = %v, want nil", err)
   239  			}
   240  
   241  			// Check number of calls before reading the file.
   242  			numCalled := ms.NumCalled(tt.url)
   243  			if numCalled != 0 {
   244  				t.Errorf("number times Fetch() called = %v, want 0", numCalled)
   245  			}
   246  			if err != nil {
   247  				return
   248  			}
   249  
   250  			// Read the entire file.
   251  			content, err := ioutil.ReadAll(uio.Reader(r))
   252  			if uErr, ok := err.(*URLError); ok && uErr.Err != tt.err {
   253  				t.Errorf("ReadAll() = %v, want %v", uErr.Err, tt.err)
   254  			} else if !ok && err != tt.err {
   255  				t.Errorf("ReadAll() = %v, want %v", err, tt.err)
   256  			}
   257  			if got, want := string(content), tt.want; got != want {
   258  				t.Errorf("ReadAll() = %v, want %v", got, want)
   259  			}
   260  
   261  			// Check number of calls after reading the file.
   262  			numCalled = ms.NumCalled(tt.url)
   263  			if numCalled != tt.wantFetchCount {
   264  				t.Errorf("number times Fetch() called = %v, want %v",
   265  					ms.NumCalled(tt.url), tt.wantFetchCount)
   266  			}
   267  		})
   268  	}
   269  }