github.com/craftyguy/u-root@v1.0.0/pkg/pxe/schemes_test.go (about) 1 // Copyright 2017-2018 the u-root Authors. All rights reserved 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package pxe 6 7 import ( 8 "errors" 9 "fmt" 10 "io" 11 "io/ioutil" 12 "net/url" 13 "path" 14 "reflect" 15 "strings" 16 "testing" 17 18 "github.com/u-root/u-root/pkg/uio" 19 ) 20 21 type MockScheme struct { 22 // scheme is the scheme name. 23 scheme string 24 25 // hosts is a map of host -> relative filename to host -> file contents. 26 hosts map[string]map[string]string 27 28 // numCalled is a map of URL string -> number of times GetFile has been 29 // called on that URL. 30 numCalled map[string]uint 31 } 32 33 func NewMockScheme(scheme string) *MockScheme { 34 return &MockScheme{ 35 scheme: scheme, 36 hosts: make(map[string]map[string]string), 37 numCalled: make(map[string]uint), 38 } 39 } 40 41 func (m *MockScheme) Add(host string, p string, content string) { 42 _, ok := m.hosts[host] 43 if !ok { 44 m.hosts[host] = make(map[string]string) 45 } 46 47 m.hosts[host][path.Clean(p)] = content 48 } 49 50 func (m *MockScheme) NumCalled(u *url.URL) uint { 51 url := u.String() 52 if c, ok := m.numCalled[url]; ok { 53 return c 54 } 55 return 0 56 } 57 58 var ( 59 errWrongScheme = errors.New("wrong scheme") 60 errNoSuchHost = errors.New("no such host exists") 61 errNoSuchFile = errors.New("no such file exists on this host") 62 ) 63 64 func (m *MockScheme) GetFile(u *url.URL) (io.ReaderAt, error) { 65 url := u.String() 66 if _, ok := m.numCalled[url]; ok { 67 m.numCalled[url]++ 68 } else { 69 m.numCalled[url] = 1 70 } 71 72 if u.Scheme != m.scheme { 73 return nil, errWrongScheme 74 } 75 76 files, ok := m.hosts[u.Host] 77 if !ok { 78 return nil, errNoSuchHost 79 } 80 81 content, ok := files[path.Clean(u.Path)] 82 if !ok { 83 return nil, errNoSuchFile 84 } 85 return strings.NewReader(content), nil 86 } 87 88 func TestGetFile(t *testing.T) { 89 for i, tt := range []struct { 90 scheme func() *MockScheme 91 url *url.URL 92 err error 93 want string 94 }{ 95 { 96 scheme: func() *MockScheme { 97 s := NewMockScheme("fooftp") 98 s.Add("192.168.0.1", "/foo/pxelinux.cfg/default", "haha") 99 return s 100 }, 101 want: "haha", 102 url: &url.URL{ 103 Scheme: "fooftp", 104 Host: "192.168.0.1", 105 Path: "/foo/pxelinux.cfg/default", 106 }, 107 }, 108 { 109 scheme: func() *MockScheme { 110 s := NewMockScheme("fooftp") 111 return s 112 }, 113 url: &url.URL{ 114 Scheme: "nosuch", 115 }, 116 err: ErrNoSuchScheme, 117 }, 118 { 119 scheme: func() *MockScheme { 120 s := NewMockScheme("fooftp") 121 return s 122 }, 123 url: &url.URL{ 124 Scheme: "fooftp", 125 Host: "someotherplace", 126 }, 127 err: errNoSuchHost, 128 }, 129 { 130 scheme: func() *MockScheme { 131 s := NewMockScheme("fooftp") 132 s.Add("somehost", "somefile", "somecontent") 133 return s 134 }, 135 url: &url.URL{ 136 Scheme: "fooftp", 137 Host: "somehost", 138 Path: "/someotherfile", 139 }, 140 err: errNoSuchFile, 141 }, 142 } { 143 t.Run(fmt.Sprintf("Test #%02d", i), func(t *testing.T) { 144 fs := tt.scheme() 145 s := make(Schemes) 146 s.Register(fs.scheme, fs) 147 148 // Test both GetFile and LazyGetFile. 149 for _, f := range []func(url *url.URL) (io.ReaderAt, error){ 150 s.GetFile, 151 s.LazyGetFile, 152 } { 153 r, err := f(tt.url) 154 if uErr, ok := err.(*URLError); ok && uErr.Err != tt.err { 155 t.Errorf("GetFile() = %v, want %v", uErr.Err, tt.err) 156 } else if !ok && err != tt.err { 157 t.Errorf("GetFile() = %v, want %v", err, tt.err) 158 } 159 if err != nil { 160 return 161 } 162 content, err := ioutil.ReadAll(uio.Reader(r)) 163 if err != nil { 164 t.Errorf("bytes.Buffer read returned an error? %v", err) 165 } 166 if got, want := string(content), tt.want; got != want { 167 t.Errorf("GetFile() = %v, want %v", got, want) 168 } 169 } 170 }) 171 } 172 } 173 174 func TestParseURL(t *testing.T) { 175 for i, tt := range []struct { 176 url string 177 wd *url.URL 178 err bool 179 want *url.URL 180 }{ 181 { 182 url: "default", 183 wd: &url.URL{ 184 Scheme: "tftp", 185 Host: "192.168.1.1", 186 Path: "/foobar/pxelinux.cfg", 187 }, 188 want: &url.URL{ 189 Scheme: "tftp", 190 Host: "192.168.1.1", 191 Path: "/foobar/pxelinux.cfg/default", 192 }, 193 }, 194 { 195 url: "http://192.168.2.1/configs/your-machine.cfg", 196 wd: &url.URL{ 197 Scheme: "tftp", 198 Host: "192.168.1.1", 199 Path: "/foobar/pxelinux.cfg", 200 }, 201 want: &url.URL{ 202 Scheme: "http", 203 Host: "192.168.2.1", 204 Path: "/configs/your-machine.cfg", 205 }, 206 }, 207 } { 208 t.Run(fmt.Sprintf("Test #%02d", i), func(t *testing.T) { 209 got, err := parseURL(tt.url, tt.wd) 210 if (err != nil) != tt.err { 211 t.Errorf("Wanted error (%v), but got %v", tt.err, err) 212 } 213 if !reflect.DeepEqual(got, tt.want) { 214 t.Errorf("parseURL() = %#v, want %#v", got, tt.want) 215 } 216 }) 217 } 218 }