github.com/u-root/u-root@v7.0.1-0.20200915234505-ad7babab0a8e+incompatible/pkg/boot/menu/menu_test.go (about) 1 // Copyright 2020 the u-root Authors. All rights reserved 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package menu 6 7 import ( 8 "fmt" 9 "os" 10 "sync" 11 "testing" 12 "time" 13 14 "github.com/google/goterm/term" 15 "github.com/u-root/u-root/pkg/testutil" 16 ) 17 18 func TestMain(m *testing.M) { 19 initialTimeout = 2 * time.Second 20 subsequentTimeout = 6 * time.Second 21 22 os.Exit(m.Run()) 23 } 24 25 type testEntry struct { 26 mu sync.Mutex 27 label string 28 isDefault bool 29 load error 30 loadCalled bool 31 } 32 33 func (d *testEntry) Label() string { 34 d.mu.Lock() 35 defer d.mu.Unlock() 36 return d.label 37 } 38 39 func (d *testEntry) String() string { 40 return d.Label() 41 } 42 43 func (d *testEntry) Load() error { 44 d.mu.Lock() 45 defer d.mu.Unlock() 46 d.loadCalled = true 47 return d.load 48 } 49 50 func (d *testEntry) Exec() error { 51 d.mu.Lock() 52 defer d.mu.Unlock() 53 return nil 54 } 55 56 func (d *testEntry) LoadCalled() bool { 57 d.mu.Lock() 58 defer d.mu.Unlock() 59 return d.loadCalled 60 } 61 62 func (d *testEntry) IsDefault() bool { 63 d.mu.Lock() 64 defer d.mu.Unlock() 65 return d.isDefault 66 } 67 68 func TestChoose(t *testing.T) { 69 // This test takes too long to run for the VM test and doesn't use 70 // anything root-specific. 71 testutil.SkipIfInVMTest(t) 72 73 entry1 := &testEntry{label: "1"} 74 entry2 := &testEntry{label: "2"} 75 entry3 := &testEntry{label: "3"} 76 77 for _, tt := range []struct { 78 name string 79 entries []Entry 80 userEntry []byte 81 want Entry 82 }{ 83 { 84 name: "just_hit_enter", 85 entries: []Entry{entry1, entry2, entry3}, 86 // user just hits enter. 87 userEntry: []byte("\r\n"), 88 want: nil, 89 }, 90 { 91 name: "hit_nothing", 92 entries: []Entry{entry1, entry2, entry3}, 93 userEntry: nil, 94 want: nil, 95 }, 96 { 97 name: "hit_1", 98 entries: []Entry{entry1, entry2, entry3}, 99 userEntry: []byte("1\r\n"), 100 want: entry1, 101 }, 102 { 103 name: "hit_3", 104 entries: []Entry{entry1, entry2, entry3}, 105 userEntry: []byte("3\r\n"), 106 want: entry3, 107 }, 108 { 109 name: "tentative_hit_1", 110 entries: []Entry{entry1, entry2, entry3}, 111 // \x08 is the backspace character. 112 userEntry: []byte("2\x081\r\n"), 113 want: entry1, 114 }, 115 { 116 name: "out_of_bounds", 117 entries: []Entry{entry1, entry2, entry3}, 118 userEntry: []byte("4\r\n"), 119 want: nil, 120 }, 121 { 122 name: "not_a_number", 123 entries: []Entry{entry1, entry2, entry3}, 124 userEntry: []byte("abc\r\n"), 125 want: nil, 126 }, 127 } { 128 t.Run(tt.name, func(t *testing.T) { 129 pty, err := term.OpenPTY() 130 if err != nil { 131 t.Fatalf("%v", err) 132 } 133 defer pty.Close() 134 135 chosen := make(chan Entry) 136 go func() { 137 chosen <- Choose(pty.Slave, tt.entries...) 138 }() 139 140 // Well this sucks. 141 // 142 // We have to wait until Choose has actually started trying to read, as 143 // ttys are asynchronous. 144 // 145 // Know a better way? Halp. 146 time.Sleep(1 * time.Second) 147 148 if tt.userEntry != nil { 149 if _, err := pty.Master.Write(tt.userEntry); err != nil { 150 t.Fatalf("failed to write new-line: %v", err) 151 } 152 } 153 154 if got := <-chosen; got != tt.want { 155 t.Errorf("Choose(%#v, %#v) = %#v, want %#v", tt.userEntry, tt.entries, got, tt.want) 156 } 157 }) 158 } 159 } 160 161 func contains(s []string, t string) bool { 162 for _, u := range s { 163 if u == t { 164 return true 165 } 166 } 167 return false 168 } 169 170 func TestShowMenuAndLoad(t *testing.T) { 171 // This test takes too long to run for the VM test and doesn't use 172 // anything root-specific. 173 testutil.SkipIfInVMTest(t) 174 175 tests := []struct { 176 name string 177 entries []*testEntry 178 userEntry []byte 179 180 // calledLabels are the entries for which Do was called. 181 calledLabels []string 182 }{ 183 { 184 name: "default_entry", 185 entries: []*testEntry{ 186 {label: "1", isDefault: true, load: nil}, 187 {label: "2", isDefault: true, load: nil}, 188 }, 189 // user just hits enter. 190 userEntry: []byte("\r\n"), 191 calledLabels: []string{"1"}, 192 }, 193 { 194 name: "non_default_entry_default", 195 entries: []*testEntry{ 196 {label: "1", isDefault: false, load: nil}, 197 {label: "2", isDefault: true, load: nil}, 198 {label: "3", isDefault: true, load: nil}, 199 }, 200 // user just hits enter. 201 userEntry: []byte("\r\n"), 202 calledLabels: []string{"2"}, 203 }, 204 { 205 name: "non_default_entry_chosen_but_broken", 206 entries: []*testEntry{ 207 {label: "1", isDefault: false, load: fmt.Errorf("borked")}, 208 {label: "2", isDefault: true, load: nil}, 209 {label: "3", isDefault: true, load: nil}, 210 }, 211 userEntry: []byte("1\r\n"), 212 calledLabels: []string{"1", "2"}, 213 }, 214 { 215 name: "last_entry_works", 216 entries: []*testEntry{ 217 {label: "1", isDefault: true, load: fmt.Errorf("foo")}, 218 {label: "2", isDefault: true, load: fmt.Errorf("bar")}, 219 {label: "3", isDefault: true, load: nil}, 220 }, 221 // user just hits enter. 222 userEntry: []byte("\r\n"), 223 calledLabels: []string{"1", "2", "3"}, 224 }, 225 } 226 for _, tt := range tests { 227 t.Run(tt.name, func(t *testing.T) { 228 pty, err := term.OpenPTY() 229 if err != nil { 230 t.Fatalf("%v", err) 231 } 232 defer pty.Close() 233 234 var entries []Entry 235 for _, e := range tt.entries { 236 entries = append(entries, e) 237 } 238 239 entry := make(chan Entry) 240 go func() { 241 entry <- ShowMenuAndLoad(pty.Slave, entries...) 242 }() 243 244 // Well this sucks. 245 // 246 // We have to wait until Choose has actually started trying to read, as 247 // ttys are asynchronous. 248 // 249 // Know a better way? Halp. 250 time.Sleep(1 * time.Second) 251 252 if tt.userEntry != nil { 253 if _, err := pty.Master.Write(tt.userEntry); err != nil { 254 t.Fatalf("failed to write new-line: %v", err) 255 } 256 } 257 258 got := <-entry 259 if want := tt.calledLabels[len(tt.calledLabels)-1]; got.Label() != want { 260 t.Errorf("got label %s want label %s", got.Label(), want) 261 } 262 263 for _, entry := range tt.entries { 264 wantCalled := contains(tt.calledLabels, entry.label) 265 if wantCalled != entry.LoadCalled() { 266 t.Errorf("Entry %s gotCalled %t, wantCalled %t", entry.Label(), entry.LoadCalled(), wantCalled) 267 } 268 } 269 }) 270 } 271 }