github.com/rminnich/u-root@v7.0.0+incompatible/pkg/curl/schemes_test.go (about)

     1  // Copyright 2017-2020 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package curl
     6  
     7  import (
     8  	"context"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"io/ioutil"
    13  	"net/url"
    14  	"testing"
    15  
    16  	"github.com/cenkalti/backoff"
    17  	"github.com/u-root/u-root/pkg/uio"
    18  )
    19  
    20  var (
    21  	errTest = errors.New("Test error")
    22  	testURL = &url.URL{
    23  		Scheme: "fooftp",
    24  		Host:   "192.168.0.1",
    25  		Path:   "/foo/pxelinux.cfg/default",
    26  	}
    27  )
    28  
    29  var tests = []struct {
    30  	name string
    31  	// scheme returns a scheme for testing and a MockScheme to
    32  	// confirm number of calls to Fetch. The distinction is useful
    33  	// when MockScheme is decorated by a SchemeWithRetries. In many
    34  	// cases, the same value is returned twice.
    35  	scheme         func() (FileScheme, *MockScheme)
    36  	url            *url.URL
    37  	err            error
    38  	want           string
    39  	wantFetchCount uint
    40  }{
    41  	{
    42  		name: "successful fetch",
    43  		scheme: func() (FileScheme, *MockScheme) {
    44  			s := NewMockScheme("fooftp")
    45  			s.Add("192.168.0.1", "/foo/pxelinux.cfg/default", "haha")
    46  			return s, s
    47  		},
    48  		url:            testURL,
    49  		want:           "haha",
    50  		wantFetchCount: 1,
    51  	},
    52  	{
    53  		name: "scheme does not exist",
    54  		scheme: func() (FileScheme, *MockScheme) {
    55  			s := NewMockScheme("fooftp")
    56  			return s, s
    57  		},
    58  		url: &url.URL{
    59  			Scheme: "nosuch",
    60  		},
    61  		err:            ErrNoSuchScheme,
    62  		wantFetchCount: 0,
    63  	},
    64  	{
    65  		name: "host does not exist",
    66  		scheme: func() (FileScheme, *MockScheme) {
    67  			s := NewMockScheme("fooftp")
    68  			return s, s
    69  		},
    70  		url: &url.URL{
    71  			Scheme: "fooftp",
    72  			Host:   "someotherplace",
    73  		},
    74  		err:            ErrNoSuchHost,
    75  		wantFetchCount: 1,
    76  	},
    77  	{
    78  		name: "file does not exist",
    79  		scheme: func() (FileScheme, *MockScheme) {
    80  			s := NewMockScheme("fooftp")
    81  			s.Add("somehost", "somefile", "somecontent")
    82  			return s, s
    83  		},
    84  		url: &url.URL{
    85  			Scheme: "fooftp",
    86  			Host:   "somehost",
    87  			Path:   "/someotherfile",
    88  		},
    89  		err:            ErrNoSuchFile,
    90  		wantFetchCount: 1,
    91  	},
    92  	{
    93  		name: "always err",
    94  		scheme: func() (FileScheme, *MockScheme) {
    95  			s := NewMockScheme("fooftp")
    96  			s.Add("192.168.0.1", "/foo/pxelinux.cfg/default", "haha")
    97  			s.SetErr(errTest, 9999)
    98  			return s, s
    99  		},
   100  		url:            testURL,
   101  		err:            errTest,
   102  		wantFetchCount: 1,
   103  	},
   104  	{
   105  		name: "retries but not necessary",
   106  		scheme: func() (FileScheme, *MockScheme) {
   107  			s := NewMockScheme("fooftp")
   108  			s.Add("192.168.0.1", "/foo/pxelinux.cfg/default", "haha")
   109  			r := &SchemeWithRetries{
   110  				Scheme: s,
   111  				// backoff.ZeroBackOff so unit tests run fast.
   112  				BackOff: backoff.WithMaxRetries(&backoff.ZeroBackOff{}, 10),
   113  			}
   114  			return r, s
   115  		},
   116  		url:            testURL,
   117  		want:           "haha",
   118  		wantFetchCount: 1,
   119  	},
   120  	{
   121  		name: "not enough retries",
   122  		scheme: func() (FileScheme, *MockScheme) {
   123  			s := NewMockScheme("fooftp")
   124  			s.Add("192.168.0.1", "/foo/pxelinux.cfg/default", "haha")
   125  			s.SetErr(errTest, 9999)
   126  			r := &SchemeWithRetries{
   127  				Scheme: s,
   128  				// backoff.ZeroBackOff so unit tests run fast.
   129  				BackOff: backoff.WithMaxRetries(&backoff.ZeroBackOff{}, 10),
   130  			}
   131  			return r, s
   132  		},
   133  		url:            testURL,
   134  		err:            errTest,
   135  		wantFetchCount: 11,
   136  	},
   137  	{
   138  		name: "sufficient retries",
   139  		scheme: func() (FileScheme, *MockScheme) {
   140  			s := NewMockScheme("fooftp")
   141  			s.Add("192.168.0.1", "/foo/pxelinux.cfg/default", "haha")
   142  			s.SetErr(errTest, 5)
   143  			r := &SchemeWithRetries{
   144  				Scheme: s,
   145  				// backoff.ZeroBackOff so unit tests run fast.
   146  				BackOff: backoff.WithMaxRetries(&backoff.ZeroBackOff{}, 10),
   147  			}
   148  			return r, s
   149  		},
   150  		url:            testURL,
   151  		want:           "haha",
   152  		wantFetchCount: 6,
   153  	},
   154  	{
   155  		name: "retry filter",
   156  		scheme: func() (FileScheme, *MockScheme) {
   157  			s := NewMockScheme("fooftp")
   158  			s.Add("192.168.0.1", "/foo/pxelinux.cfg/default", "haha")
   159  			s.SetErr(errTest, 5)
   160  			r := &SchemeWithRetries{
   161  				DoRetry: func(u *url.URL, err error) bool {
   162  					return err != errTest
   163  				},
   164  				Scheme: s,
   165  				// backoff.ZeroBackOff so unit tests run fast.
   166  				BackOff: backoff.WithMaxRetries(&backoff.ZeroBackOff{}, 10),
   167  			}
   168  			return r, s
   169  		},
   170  		url:            testURL,
   171  		err:            errTest,
   172  		wantFetchCount: 1,
   173  	},
   174  }
   175  
   176  func TestFetch(t *testing.T) {
   177  	for i, tt := range tests {
   178  		t.Run(fmt.Sprintf("Test #%02d: %s", i, tt.name), func(t *testing.T) {
   179  			var r io.ReaderAt
   180  			var err error
   181  
   182  			fs, ms := tt.scheme()
   183  			s := make(Schemes)
   184  			s.Register(ms.Scheme, fs)
   185  
   186  			r, err = s.Fetch(context.TODO(), tt.url)
   187  			if uErr, ok := err.(*URLError); ok && uErr.Err != tt.err {
   188  				t.Errorf("Fetch() = %v, want %v", uErr.Err, tt.err)
   189  			} else if !ok && err != tt.err {
   190  				t.Errorf("Fetch() = %v, want %v", err, tt.err)
   191  			}
   192  
   193  			// Check number of calls before reading the file.
   194  			numCalled := ms.NumCalled(tt.url)
   195  			if numCalled != tt.wantFetchCount {
   196  				t.Errorf("number times Fetch() called = %v, want %v",
   197  					ms.NumCalled(tt.url), tt.wantFetchCount)
   198  			}
   199  			if err != nil {
   200  				return
   201  			}
   202  
   203  			// Read the entire file.
   204  			content, err := ioutil.ReadAll(uio.Reader(r))
   205  			if err != nil {
   206  				t.Errorf("bytes.Buffer read returned an error? %v", err)
   207  			}
   208  			if got, want := string(content), tt.want; got != want {
   209  				t.Errorf("Fetch() = %v, want %v", got, want)
   210  			}
   211  
   212  			// Check number of calls after reading the file.
   213  			numCalled = ms.NumCalled(tt.url)
   214  			if numCalled != tt.wantFetchCount {
   215  				t.Errorf("number times Fetch() called = %v, want %v",
   216  					ms.NumCalled(tt.url), tt.wantFetchCount)
   217  			}
   218  		})
   219  	}
   220  }
   221  
   222  func TestLazyFetch(t *testing.T) {
   223  	for i, tt := range tests {
   224  		t.Run(fmt.Sprintf("Test #%02d: %s", i, tt.name), func(t *testing.T) {
   225  			var r io.ReaderAt
   226  			var err error
   227  
   228  			fs, ms := tt.scheme()
   229  			s := make(Schemes)
   230  			s.Register(ms.Scheme, fs)
   231  
   232  			r, err = s.LazyFetch(tt.url)
   233  			// Errors are deferred to when file is read except for ErrNoSuchScheme.
   234  			if tt.err == ErrNoSuchScheme {
   235  				if uErr, ok := err.(*URLError); ok && uErr.Err != ErrNoSuchScheme {
   236  					t.Errorf("LazyFetch() = %v, want %v", uErr.Err, tt.err)
   237  				}
   238  			} else if err != nil {
   239  				t.Errorf("LazyFetch() = %v, want nil", err)
   240  			}
   241  
   242  			// Check number of calls before reading the file.
   243  			numCalled := ms.NumCalled(tt.url)
   244  			if numCalled != 0 {
   245  				t.Errorf("number times Fetch() called = %v, want 0", numCalled)
   246  			}
   247  			if err != nil {
   248  				return
   249  			}
   250  
   251  			// Read the entire file.
   252  			content, err := ioutil.ReadAll(uio.Reader(r))
   253  			if uErr, ok := err.(*URLError); ok && uErr.Err != tt.err {
   254  				t.Errorf("ReadAll() = %v, want %v", uErr.Err, tt.err)
   255  			} else if !ok && err != tt.err {
   256  				t.Errorf("ReadAll() = %v, want %v", err, tt.err)
   257  			}
   258  			if got, want := string(content), tt.want; got != want {
   259  				t.Errorf("ReadAll() = %v, want %v", got, want)
   260  			}
   261  
   262  			// Check number of calls after reading the file.
   263  			numCalled = ms.NumCalled(tt.url)
   264  			if numCalled != tt.wantFetchCount {
   265  				t.Errorf("number times Fetch() called = %v, want %v",
   266  					ms.NumCalled(tt.url), tt.wantFetchCount)
   267  			}
   268  		})
   269  	}
   270  }