github.com/mvdan/u-root-coreutils@v0.0.0-20230122170626-c2eef2898555/pkg/boot/menu/menu_test.go (about) 1 // Copyright 2020 the u-root Authors. All rights reserved 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package menu 6 7 import ( 8 "errors" 9 "fmt" 10 "os" 11 "sync" 12 "testing" 13 "time" 14 15 "github.com/creack/pty" 16 "github.com/mvdan/u-root-coreutils/pkg/testutil" 17 ) 18 19 var inputDelay = 500 * time.Millisecond 20 21 func TestMain(m *testing.M) { 22 SetInitialTimeout(inputDelay * 2) 23 subsequentTimeout = inputDelay * 2 24 25 os.Exit(m.Run()) 26 } 27 28 type testEntry struct { 29 mu sync.Mutex 30 label string 31 cmdline string 32 isDefault bool 33 load error 34 loadCalled bool 35 } 36 37 func (d *testEntry) Label() string { 38 d.mu.Lock() 39 defer d.mu.Unlock() 40 return d.label 41 } 42 43 func (d *testEntry) String() string { 44 return d.Label() 45 } 46 47 func (d *testEntry) Edit(f func(string) string) { 48 d.mu.Lock() 49 defer d.mu.Unlock() 50 d.cmdline = f(d.cmdline) 51 } 52 53 func (d *testEntry) Load() error { 54 d.mu.Lock() 55 defer d.mu.Unlock() 56 d.loadCalled = true 57 return d.load 58 } 59 60 func (d *testEntry) Exec() error { 61 d.mu.Lock() 62 defer d.mu.Unlock() 63 return nil 64 } 65 66 func (d *testEntry) LoadCalled() bool { 67 d.mu.Lock() 68 defer d.mu.Unlock() 69 return d.loadCalled 70 } 71 72 func (d *testEntry) IsDefault() bool { 73 d.mu.Lock() 74 defer d.mu.Unlock() 75 return d.isDefault 76 } 77 78 type testEntryStringer struct { 79 testEntry 80 } 81 82 func (d *testEntryStringer) String() string { 83 return d.Label() + " string" 84 } 85 86 func TestExtendedLabel(t *testing.T) { 87 for _, tt := range []struct { 88 name string 89 entry Entry 90 want string 91 }{ 92 { 93 name: "without stringer", 94 entry: &testEntry{label: "label"}, 95 want: "label", 96 }, 97 { 98 name: "with stringer", 99 entry: &testEntryStringer{testEntry{label: "label"}}, 100 want: "label string", 101 }, 102 } { 103 t.Run(tt.name, func(t *testing.T) { 104 got := ExtendedLabel(tt.entry) 105 if got != tt.want { 106 t.Errorf("ExtendedLabel(%v) = %q; want %q", tt.entry, got, tt.want) 107 } 108 }) 109 } 110 } 111 112 var _ = MenuTerminal(&mockTerm{}) 113 114 type mockTerm struct { 115 inputSequence []ReadLine 116 readLineCnt int 117 } 118 119 func (m *mockTerm) Write(p []byte) (n int, err error) { 120 return len(p), nil 121 } 122 123 func (m *mockTerm) Close() error { return nil } 124 func (m *mockTerm) SetPrompt(s string) {} 125 func (m *mockTerm) SetEntryCallback(func()) {} 126 func (m *mockTerm) SetTimeout(time.Duration) error { return nil } 127 func (m *mockTerm) ReadLine() (string, error) { 128 defer func() { m.readLineCnt++ }() 129 if m.inputSequence == nil || m.readLineCnt >= len(m.inputSequence) { 130 // mimic timeout 131 return "", os.ErrDeadlineExceeded 132 } 133 return m.inputSequence[m.readLineCnt].string, m.inputSequence[m.readLineCnt].error 134 } 135 136 type ReadLine struct { 137 string 138 error 139 } 140 141 func TestChoose(t *testing.T) { 142 // This test takes too long to run for the VM test and doesn't use 143 // anything root-specific. 144 testutil.SkipIfInVMTest(t) 145 146 for _, tt := range []struct { 147 name string 148 userEntry []ReadLine 149 wantedEntry int 150 editingAllowed bool 151 expectedCmds []string 152 }{ 153 { 154 name: "just_hit_enter", 155 // user just hits enter. 156 userEntry: []ReadLine{{"", nil}}, 157 wantedEntry: -1, // expect nil 158 }, 159 { 160 name: "hit_nothing", 161 userEntry: []ReadLine{}, 162 wantedEntry: -1, // expect nil 163 }, 164 { 165 name: "hit_1", 166 userEntry: []ReadLine{{"1", nil}}, 167 wantedEntry: 1, 168 }, 169 { 170 name: "hit_3", 171 userEntry: []ReadLine{{"3", nil}}, 172 wantedEntry: 3, 173 }, 174 { 175 name: "out_of_bounds", 176 userEntry: []ReadLine{{"4", nil}}, 177 wantedEntry: -1, // expect nil 178 }, 179 { 180 name: "not_a_number", 181 userEntry: []ReadLine{{"abc", nil}}, 182 wantedEntry: -1, // expect nil 183 }, 184 { 185 name: "editing_allowed_override", 186 userEntry: getEditSequence(false, "1", "after"), 187 editingAllowed: true, 188 expectedCmds: []string{"after", "before", "before"}, 189 wantedEntry: -1, // expect nil 190 }, 191 { 192 name: "editing_allowed_append", 193 userEntry: getEditSequence(true, "2", "after"), 194 editingAllowed: true, 195 expectedCmds: []string{"before", "before after", "before"}, 196 wantedEntry: -1, // expect nil 197 }, 198 { 199 name: "select_after_override", 200 userEntry: append(getEditSequence(false, "1", "after"), 201 ReadLine{"1", nil}), 202 editingAllowed: true, 203 expectedCmds: []string{"after", "before", "before"}, 204 wantedEntry: 1, 205 }, 206 { 207 name: "editing_not_allowed", 208 userEntry: getEditSequence(true, "1", "after"), 209 editingAllowed: false, 210 expectedCmds: []string{"before", "before", "before"}, 211 wantedEntry: 1, // Edit attempt is parsed as a boot choice 212 }, 213 { 214 name: "edit_fail_reading_1", 215 userEntry: errorOn(1, getEditSequence(false, "1", "after")), 216 editingAllowed: true, 217 expectedCmds: []string{"before", "before", "before"}, 218 wantedEntry: -1, // expect nil 219 }, 220 { 221 name: "edit_fail_reading_2", 222 userEntry: errorOn(2, getEditSequence(false, "1", "after")), 223 editingAllowed: true, 224 expectedCmds: []string{"before", "before", "before"}, 225 wantedEntry: -1, // expect nil 226 }, 227 { 228 name: "edit_fail_reading_3", 229 userEntry: errorOn(3, getEditSequence(false, "1", "after")), 230 editingAllowed: true, 231 expectedCmds: []string{"before", "before", "before"}, 232 wantedEntry: -1, // expect nil 233 }, 234 } { 235 t.Run(tt.name, func(t *testing.T) { 236 entries := []*testEntry{ 237 {label: "1", cmdline: "before"}, 238 {label: "2", cmdline: "before"}, 239 {label: "3", cmdline: "before"}, 240 } 241 var menu []Entry 242 for _, e := range entries { 243 menu = append(menu, e) 244 } 245 246 chosen := make(chan Entry) 247 go func() { 248 m := &mockTerm{ 249 inputSequence: tt.userEntry, 250 } 251 chosen <- Choose(m, tt.editingAllowed, menu...) 252 }() 253 254 chosenWant := Entry(nil) 255 if tt.wantedEntry > 0 { 256 chosenWant = menu[tt.wantedEntry-1] // 1 based index 257 } 258 if got := <-chosen; got != chosenWant { 259 t.Errorf("Choose(%#v, %#v) = %#v, wantedEntry %#v", tt.userEntry, entries, got, tt.wantedEntry) 260 } 261 // Check for editing 262 for i, entry := range entries { 263 if i < len(tt.expectedCmds) && entry.cmdline != tt.expectedCmds[i] { 264 t.Errorf("Entry %s got cmdline %s, wanted %s", entry.Label(), entry.cmdline, tt.expectedCmds[i]) 265 } 266 } 267 }) 268 } 269 } 270 271 func errorOn(index int, arr []ReadLine) []ReadLine { 272 arr[index].error = errors.New("Expected test error") 273 return arr 274 } 275 276 func contains(s []string, t string) bool { 277 for _, u := range s { 278 if u == t { 279 return true 280 } 281 } 282 return false 283 } 284 285 func TestShowMenuAndLoadFromFile(t *testing.T) { 286 // This test takes too long to run for the VM test and doesn't use 287 // anything root-specific. 288 testutil.SkipIfInVMTest(t) 289 290 tests := []struct { 291 name string 292 entries []*testEntry 293 userEntry []byte 294 295 // calledLabels are the entries for which Do was called. 296 calledLabels []string 297 }{ 298 { 299 name: "default_entry", 300 entries: []*testEntry{ 301 {label: "1", isDefault: true, load: nil}, 302 {label: "2", isDefault: true, load: nil}, 303 }, 304 // user just hits enter. 305 userEntry: []byte("\r\n"), 306 calledLabels: []string{"1"}, 307 }, 308 { 309 name: "non_default_entry_default", 310 entries: []*testEntry{ 311 {label: "1", isDefault: false, load: nil}, 312 {label: "2", isDefault: true, load: nil}, 313 {label: "3", isDefault: true, load: nil}, 314 }, 315 // user just hits enter. 316 userEntry: []byte("\r\n"), 317 calledLabels: []string{"2"}, 318 }, 319 { 320 name: "non_default_entry_chosen_but_broken", 321 entries: []*testEntry{ 322 {label: "1", isDefault: false, load: fmt.Errorf("borked")}, 323 {label: "2", isDefault: true, load: nil}, 324 {label: "3", isDefault: true, load: nil}, 325 }, 326 userEntry: []byte("1\r\n"), 327 calledLabels: []string{"1", "2"}, 328 }, 329 { 330 name: "last_entry_works", 331 entries: []*testEntry{ 332 {label: "1", isDefault: true, load: fmt.Errorf("foo")}, 333 {label: "2", isDefault: true, load: fmt.Errorf("bar")}, 334 {label: "3", isDefault: true, load: nil}, 335 }, 336 // user just hits enter. 337 userEntry: []byte("\r\n"), 338 calledLabels: []string{"1", "2", "3"}, 339 }, 340 { 341 name: "indecisive_entry", 342 entries: []*testEntry{ 343 {label: "1", isDefault: true, load: nil}, 344 {label: "2", isDefault: true, load: nil}, 345 {label: "3", isDefault: true, load: nil}, 346 }, 347 // \x08 is the backspace character 348 userEntry: []byte("1\x082\r\n"), 349 calledLabels: []string{"2"}, 350 }, 351 { 352 name: "timeout_gets_first_default", 353 entries: []*testEntry{ 354 {label: "1", isDefault: true, load: nil}, 355 {label: "2", isDefault: true, load: nil}, 356 {label: "3", isDefault: true, load: nil}, 357 }, 358 // No input 359 userEntry: []byte{}, 360 calledLabels: []string{"1"}, 361 }, 362 } 363 for _, tt := range tests { 364 t.Run(tt.name, func(t *testing.T) { 365 master, slave, err := pty.Open() 366 if err != nil { 367 t.Fatalf("%v", err) 368 } 369 defer master.Close() 370 defer slave.Close() 371 372 var entries []Entry 373 for _, e := range tt.entries { 374 entries = append(entries, e) 375 } 376 377 timer := time.NewTimer(initialTimeout * 4) 378 entry := make(chan Entry) 379 go func() { 380 entry <- showMenuAndLoadFromFile(slave, true, entries...) 381 }() 382 383 if tt.userEntry != nil && len(tt.userEntry) > 0 { 384 // We have to wait until Choose has actually started trying to read, as 385 // ttys are asynchronous. 386 // 387 // Know a better way? Halp. 388 time.Sleep(inputDelay) 389 if _, err := master.Write(tt.userEntry); err != nil { 390 t.Fatalf("failed to write new-line: %v", err) 391 } 392 } 393 394 select { 395 case <-timer.C: 396 t.Errorf("Test %s timed out after %v", tt.name, initialTimeout) 397 case got := <-entry: 398 if want := tt.calledLabels[len(tt.calledLabels)-1]; got.Label() != want { 399 t.Errorf("got label %s wantedEntry label %s", got.Label(), want) 400 } 401 402 for _, entry := range tt.entries { 403 wantCalled := contains(tt.calledLabels, entry.label) 404 if wantCalled != entry.LoadCalled() { 405 t.Errorf("Entry %s gotCalled %t, wantCalled %t", entry.Label(), entry.LoadCalled(), wantCalled) 406 } 407 } 408 } 409 }) 410 } 411 } 412 413 func getEditSequence(append bool, bootnum string, cmdline string) []ReadLine { 414 var editOpt string 415 if append { 416 editOpt = "a" 417 } else { 418 editOpt = "o" 419 } 420 421 return []ReadLine{ 422 {"e", nil}, 423 {bootnum, nil}, 424 {editOpt, nil}, 425 {cmdline, nil}, 426 } 427 }