github.com/ratrocket/u-root@v0.0.0-20180201221235-1cf9f48ee2cf/pkg/pxe/schemes_test.go (about) 1 package pxe 2 3 import ( 4 "bytes" 5 "errors" 6 "fmt" 7 "io" 8 "io/ioutil" 9 "net/url" 10 "path" 11 "reflect" 12 "testing" 13 ) 14 15 type MockScheme struct { 16 // scheme is the scheme name. 17 scheme string 18 19 // hosts is a map of host -> relative filename to host -> file contents. 20 hosts map[string]map[string]string 21 22 // numCalled is a map of URI string -> number of times GetFile has been 23 // called on that URI. 24 numCalled map[string]uint 25 } 26 27 func NewMockScheme(scheme string) *MockScheme { 28 return &MockScheme{ 29 scheme: scheme, 30 hosts: make(map[string]map[string]string), 31 numCalled: make(map[string]uint), 32 } 33 } 34 35 func (m *MockScheme) Add(host string, p string, content string) { 36 _, ok := m.hosts[host] 37 if !ok { 38 m.hosts[host] = make(map[string]string) 39 } 40 41 m.hosts[host][path.Clean(p)] = content 42 } 43 44 func (m *MockScheme) NumCalled(u *url.URL) uint { 45 uri := u.String() 46 if c, ok := m.numCalled[uri]; ok { 47 return c 48 } 49 return 0 50 } 51 52 var ( 53 errWrongScheme = errors.New("wrong scheme") 54 errNoSuchHost = errors.New("no such host exists") 55 errNoSuchFile = errors.New("no such file exists on this host") 56 ) 57 58 func (m *MockScheme) GetFile(u *url.URL) (io.Reader, error) { 59 uri := u.String() 60 if _, ok := m.numCalled[uri]; ok { 61 m.numCalled[uri]++ 62 } else { 63 m.numCalled[uri] = 1 64 } 65 66 if u.Scheme != m.scheme { 67 return nil, errWrongScheme 68 } 69 70 files, ok := m.hosts[u.Host] 71 if !ok { 72 return nil, errNoSuchHost 73 } 74 75 content, ok := files[path.Clean(u.Path)] 76 if !ok { 77 return nil, errNoSuchFile 78 } 79 return bytes.NewBufferString(content), nil 80 } 81 82 func TestCachedFileSchemeGetFile(t *testing.T) { 83 for i, tt := range []struct { 84 fs func() *MockScheme 85 uri *url.URL 86 err error 87 want string 88 }{ 89 { 90 fs: func() *MockScheme { 91 s := NewMockScheme("fooftp") 92 s.Add("192.168.0.1", "/default", "haha") 93 return s 94 }, 95 uri: &url.URL{ 96 Scheme: "fooftp", 97 Host: "192.168.0.1", 98 Path: "/default", 99 }, 100 want: "haha", 101 }, 102 { 103 fs: func() *MockScheme { 104 return NewMockScheme("fooftp") 105 }, 106 uri: &url.URL{ 107 Scheme: "fooftp", 108 }, 109 err: errNoSuchHost, 110 }, 111 } { 112 t.Run(fmt.Sprintf("Test [%02d]", i), func(t *testing.T) { 113 ms := tt.fs() 114 fs := NewCachedFileScheme(ms) 115 r, err := fs.GetFile(tt.uri) 116 if err != tt.err { 117 t.Errorf("GetFile(%s) = %v, want %v", tt.uri, err, tt.err) 118 return 119 } else if err == nil { 120 content, err := ioutil.ReadAll(r) 121 if err != nil { 122 t.Errorf("ReadAll = %v, want nil", err) 123 } 124 if got := string(content); got != tt.want { 125 t.Errorf("Read(%s) got %v, want %v", tt.uri, got, tt.want) 126 } 127 } 128 129 r2, err2 := fs.GetFile(tt.uri) 130 if err2 != tt.err { 131 t.Errorf("GetFile2(%s) = %v, want %v", tt.uri, err2, tt.err) 132 return 133 } else if err2 == nil { 134 content2, err := ioutil.ReadAll(r2) 135 if err != nil { 136 t.Errorf("ReadAll2 = %v, want nil", err) 137 } 138 if got := string(content2); got != tt.want { 139 t.Errorf("Read2(%s) got %v, want %v", tt.uri, got, tt.want) 140 } 141 } 142 143 if got := ms.NumCalled(tt.uri); got != 1 { 144 t.Errorf("num called(%s) = %d, want 1", tt.uri, got) 145 } 146 }) 147 } 148 } 149 150 func TestGetFile(t *testing.T) { 151 for i, tt := range []struct { 152 scheme func() *MockScheme 153 wd *url.URL 154 uri string 155 err error 156 want string 157 }{ 158 { 159 scheme: func() *MockScheme { 160 s := NewMockScheme("fooftp") 161 s.Add("192.168.0.1", "/foo/pxelinux.cfg/default", "haha") 162 return s 163 }, 164 want: "haha", 165 uri: "default", 166 wd: &url.URL{ 167 Scheme: "fooftp", 168 Host: "192.168.0.1", 169 Path: "/foo/pxelinux.cfg", 170 }, 171 }, 172 { 173 scheme: func() *MockScheme { 174 s := NewMockScheme("fooftp") 175 return s 176 }, 177 uri: "nosuch://scheme/foo", 178 err: ErrNoSuchScheme, 179 }, 180 { 181 scheme: func() *MockScheme { 182 s := NewMockScheme("fooftp") 183 return s 184 }, 185 uri: "fooftp://someotherplace", 186 err: errNoSuchHost, 187 }, 188 { 189 scheme: func() *MockScheme { 190 s := NewMockScheme("fooftp") 191 s.Add("somehost", "somefile", "somecontent") 192 return s 193 }, 194 uri: "fooftp://somehost/someotherfile", 195 err: errNoSuchFile, 196 }, 197 } { 198 t.Run(fmt.Sprintf("Test #%02d", i), func(t *testing.T) { 199 fs := tt.scheme() 200 s := make(Schemes) 201 s.Register(fs.scheme, fs) 202 203 // Test both GetFile and LazyGetFile. 204 for _, f := range []func(uri string, wd *url.URL) (io.Reader, error){ 205 s.GetFile, 206 s.LazyGetFile, 207 } { 208 r, err := f(tt.uri, tt.wd) 209 if got, want := err, tt.err; got != want { 210 t.Errorf("GetFile() = %v, want %v", got, want) 211 } 212 if err != nil { 213 return 214 } 215 content, err := ioutil.ReadAll(r) 216 if err != nil { 217 t.Errorf("bytes.Buffer read returned an error? %v", err) 218 } 219 if got, want := string(content), tt.want; got != want { 220 t.Errorf("GetFile() = %v, want %v", got, want) 221 } 222 } 223 }) 224 } 225 } 226 227 func TestParseURI(t *testing.T) { 228 for i, tt := range []struct { 229 uri string 230 wd *url.URL 231 err bool 232 want *url.URL 233 }{ 234 { 235 uri: "default", 236 wd: &url.URL{ 237 Scheme: "tftp", 238 Host: "192.168.1.1", 239 Path: "/foobar/pxelinux.cfg", 240 }, 241 want: &url.URL{ 242 Scheme: "tftp", 243 Host: "192.168.1.1", 244 Path: "/foobar/pxelinux.cfg/default", 245 }, 246 }, 247 { 248 uri: "http://192.168.2.1/configs/your-machine.cfg", 249 wd: &url.URL{ 250 Scheme: "tftp", 251 Host: "192.168.1.1", 252 Path: "/foobar/pxelinux.cfg", 253 }, 254 want: &url.URL{ 255 Scheme: "http", 256 Host: "192.168.2.1", 257 Path: "/configs/your-machine.cfg", 258 }, 259 }, 260 } { 261 t.Run(fmt.Sprintf("Test #%02d", i), func(t *testing.T) { 262 got, err := parseURI(tt.uri, tt.wd) 263 if (err != nil) != tt.err { 264 t.Errorf("Wanted error (%v), but got %v", tt.err, err) 265 } 266 if !reflect.DeepEqual(got, tt.want) { 267 t.Errorf("parseURI() = %#v, want %#v", got, tt.want) 268 } 269 }) 270 } 271 }