github.com/ratrocket/u-root@v0.0.0-20180201221235-1cf9f48ee2cf/pkg/pxe/schemes_test.go (about)

     1  package pxe
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"net/url"
    10  	"path"
    11  	"reflect"
    12  	"testing"
    13  )
    14  
    15  type MockScheme struct {
    16  	// scheme is the scheme name.
    17  	scheme string
    18  
    19  	// hosts is a map of host -> relative filename to host -> file contents.
    20  	hosts map[string]map[string]string
    21  
    22  	// numCalled is a map of URI string -> number of times GetFile has been
    23  	// called on that URI.
    24  	numCalled map[string]uint
    25  }
    26  
    27  func NewMockScheme(scheme string) *MockScheme {
    28  	return &MockScheme{
    29  		scheme:    scheme,
    30  		hosts:     make(map[string]map[string]string),
    31  		numCalled: make(map[string]uint),
    32  	}
    33  }
    34  
    35  func (m *MockScheme) Add(host string, p string, content string) {
    36  	_, ok := m.hosts[host]
    37  	if !ok {
    38  		m.hosts[host] = make(map[string]string)
    39  	}
    40  
    41  	m.hosts[host][path.Clean(p)] = content
    42  }
    43  
    44  func (m *MockScheme) NumCalled(u *url.URL) uint {
    45  	uri := u.String()
    46  	if c, ok := m.numCalled[uri]; ok {
    47  		return c
    48  	}
    49  	return 0
    50  }
    51  
    52  var (
    53  	errWrongScheme = errors.New("wrong scheme")
    54  	errNoSuchHost  = errors.New("no such host exists")
    55  	errNoSuchFile  = errors.New("no such file exists on this host")
    56  )
    57  
    58  func (m *MockScheme) GetFile(u *url.URL) (io.Reader, error) {
    59  	uri := u.String()
    60  	if _, ok := m.numCalled[uri]; ok {
    61  		m.numCalled[uri]++
    62  	} else {
    63  		m.numCalled[uri] = 1
    64  	}
    65  
    66  	if u.Scheme != m.scheme {
    67  		return nil, errWrongScheme
    68  	}
    69  
    70  	files, ok := m.hosts[u.Host]
    71  	if !ok {
    72  		return nil, errNoSuchHost
    73  	}
    74  
    75  	content, ok := files[path.Clean(u.Path)]
    76  	if !ok {
    77  		return nil, errNoSuchFile
    78  	}
    79  	return bytes.NewBufferString(content), nil
    80  }
    81  
    82  func TestCachedFileSchemeGetFile(t *testing.T) {
    83  	for i, tt := range []struct {
    84  		fs   func() *MockScheme
    85  		uri  *url.URL
    86  		err  error
    87  		want string
    88  	}{
    89  		{
    90  			fs: func() *MockScheme {
    91  				s := NewMockScheme("fooftp")
    92  				s.Add("192.168.0.1", "/default", "haha")
    93  				return s
    94  			},
    95  			uri: &url.URL{
    96  				Scheme: "fooftp",
    97  				Host:   "192.168.0.1",
    98  				Path:   "/default",
    99  			},
   100  			want: "haha",
   101  		},
   102  		{
   103  			fs: func() *MockScheme {
   104  				return NewMockScheme("fooftp")
   105  			},
   106  			uri: &url.URL{
   107  				Scheme: "fooftp",
   108  			},
   109  			err: errNoSuchHost,
   110  		},
   111  	} {
   112  		t.Run(fmt.Sprintf("Test [%02d]", i), func(t *testing.T) {
   113  			ms := tt.fs()
   114  			fs := NewCachedFileScheme(ms)
   115  			r, err := fs.GetFile(tt.uri)
   116  			if err != tt.err {
   117  				t.Errorf("GetFile(%s) = %v, want %v", tt.uri, err, tt.err)
   118  				return
   119  			} else if err == nil {
   120  				content, err := ioutil.ReadAll(r)
   121  				if err != nil {
   122  					t.Errorf("ReadAll = %v, want nil", err)
   123  				}
   124  				if got := string(content); got != tt.want {
   125  					t.Errorf("Read(%s) got %v, want %v", tt.uri, got, tt.want)
   126  				}
   127  			}
   128  
   129  			r2, err2 := fs.GetFile(tt.uri)
   130  			if err2 != tt.err {
   131  				t.Errorf("GetFile2(%s) = %v, want %v", tt.uri, err2, tt.err)
   132  				return
   133  			} else if err2 == nil {
   134  				content2, err := ioutil.ReadAll(r2)
   135  				if err != nil {
   136  					t.Errorf("ReadAll2 = %v, want nil", err)
   137  				}
   138  				if got := string(content2); got != tt.want {
   139  					t.Errorf("Read2(%s) got %v, want %v", tt.uri, got, tt.want)
   140  				}
   141  			}
   142  
   143  			if got := ms.NumCalled(tt.uri); got != 1 {
   144  				t.Errorf("num called(%s) = %d, want 1", tt.uri, got)
   145  			}
   146  		})
   147  	}
   148  }
   149  
   150  func TestGetFile(t *testing.T) {
   151  	for i, tt := range []struct {
   152  		scheme func() *MockScheme
   153  		wd     *url.URL
   154  		uri    string
   155  		err    error
   156  		want   string
   157  	}{
   158  		{
   159  			scheme: func() *MockScheme {
   160  				s := NewMockScheme("fooftp")
   161  				s.Add("192.168.0.1", "/foo/pxelinux.cfg/default", "haha")
   162  				return s
   163  			},
   164  			want: "haha",
   165  			uri:  "default",
   166  			wd: &url.URL{
   167  				Scheme: "fooftp",
   168  				Host:   "192.168.0.1",
   169  				Path:   "/foo/pxelinux.cfg",
   170  			},
   171  		},
   172  		{
   173  			scheme: func() *MockScheme {
   174  				s := NewMockScheme("fooftp")
   175  				return s
   176  			},
   177  			uri: "nosuch://scheme/foo",
   178  			err: ErrNoSuchScheme,
   179  		},
   180  		{
   181  			scheme: func() *MockScheme {
   182  				s := NewMockScheme("fooftp")
   183  				return s
   184  			},
   185  			uri: "fooftp://someotherplace",
   186  			err: errNoSuchHost,
   187  		},
   188  		{
   189  			scheme: func() *MockScheme {
   190  				s := NewMockScheme("fooftp")
   191  				s.Add("somehost", "somefile", "somecontent")
   192  				return s
   193  			},
   194  			uri: "fooftp://somehost/someotherfile",
   195  			err: errNoSuchFile,
   196  		},
   197  	} {
   198  		t.Run(fmt.Sprintf("Test #%02d", i), func(t *testing.T) {
   199  			fs := tt.scheme()
   200  			s := make(Schemes)
   201  			s.Register(fs.scheme, fs)
   202  
   203  			// Test both GetFile and LazyGetFile.
   204  			for _, f := range []func(uri string, wd *url.URL) (io.Reader, error){
   205  				s.GetFile,
   206  				s.LazyGetFile,
   207  			} {
   208  				r, err := f(tt.uri, tt.wd)
   209  				if got, want := err, tt.err; got != want {
   210  					t.Errorf("GetFile() = %v, want %v", got, want)
   211  				}
   212  				if err != nil {
   213  					return
   214  				}
   215  				content, err := ioutil.ReadAll(r)
   216  				if err != nil {
   217  					t.Errorf("bytes.Buffer read returned an error? %v", err)
   218  				}
   219  				if got, want := string(content), tt.want; got != want {
   220  					t.Errorf("GetFile() = %v, want %v", got, want)
   221  				}
   222  			}
   223  		})
   224  	}
   225  }
   226  
   227  func TestParseURI(t *testing.T) {
   228  	for i, tt := range []struct {
   229  		uri  string
   230  		wd   *url.URL
   231  		err  bool
   232  		want *url.URL
   233  	}{
   234  		{
   235  			uri: "default",
   236  			wd: &url.URL{
   237  				Scheme: "tftp",
   238  				Host:   "192.168.1.1",
   239  				Path:   "/foobar/pxelinux.cfg",
   240  			},
   241  			want: &url.URL{
   242  				Scheme: "tftp",
   243  				Host:   "192.168.1.1",
   244  				Path:   "/foobar/pxelinux.cfg/default",
   245  			},
   246  		},
   247  		{
   248  			uri: "http://192.168.2.1/configs/your-machine.cfg",
   249  			wd: &url.URL{
   250  				Scheme: "tftp",
   251  				Host:   "192.168.1.1",
   252  				Path:   "/foobar/pxelinux.cfg",
   253  			},
   254  			want: &url.URL{
   255  				Scheme: "http",
   256  				Host:   "192.168.2.1",
   257  				Path:   "/configs/your-machine.cfg",
   258  			},
   259  		},
   260  	} {
   261  		t.Run(fmt.Sprintf("Test #%02d", i), func(t *testing.T) {
   262  			got, err := parseURI(tt.uri, tt.wd)
   263  			if (err != nil) != tt.err {
   264  				t.Errorf("Wanted error (%v), but got %v", tt.err, err)
   265  			}
   266  			if !reflect.DeepEqual(got, tt.want) {
   267  				t.Errorf("parseURI() = %#v, want %#v", got, tt.want)
   268  			}
   269  		})
   270  	}
   271  }