github.com/mvdan/u-root-coreutils@v0.0.0-20230122170626-c2eef2898555/pkg/boot/menu/menu_test.go (about)

     1  // Copyright 2020 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package menu
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"os"
    11  	"sync"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/creack/pty"
    16  	"github.com/mvdan/u-root-coreutils/pkg/testutil"
    17  )
    18  
    19  var inputDelay = 500 * time.Millisecond
    20  
    21  func TestMain(m *testing.M) {
    22  	SetInitialTimeout(inputDelay * 2)
    23  	subsequentTimeout = inputDelay * 2
    24  
    25  	os.Exit(m.Run())
    26  }
    27  
    28  type testEntry struct {
    29  	mu         sync.Mutex
    30  	label      string
    31  	cmdline    string
    32  	isDefault  bool
    33  	load       error
    34  	loadCalled bool
    35  }
    36  
    37  func (d *testEntry) Label() string {
    38  	d.mu.Lock()
    39  	defer d.mu.Unlock()
    40  	return d.label
    41  }
    42  
    43  func (d *testEntry) String() string {
    44  	return d.Label()
    45  }
    46  
    47  func (d *testEntry) Edit(f func(string) string) {
    48  	d.mu.Lock()
    49  	defer d.mu.Unlock()
    50  	d.cmdline = f(d.cmdline)
    51  }
    52  
    53  func (d *testEntry) Load() error {
    54  	d.mu.Lock()
    55  	defer d.mu.Unlock()
    56  	d.loadCalled = true
    57  	return d.load
    58  }
    59  
    60  func (d *testEntry) Exec() error {
    61  	d.mu.Lock()
    62  	defer d.mu.Unlock()
    63  	return nil
    64  }
    65  
    66  func (d *testEntry) LoadCalled() bool {
    67  	d.mu.Lock()
    68  	defer d.mu.Unlock()
    69  	return d.loadCalled
    70  }
    71  
    72  func (d *testEntry) IsDefault() bool {
    73  	d.mu.Lock()
    74  	defer d.mu.Unlock()
    75  	return d.isDefault
    76  }
    77  
    78  type testEntryStringer struct {
    79  	testEntry
    80  }
    81  
    82  func (d *testEntryStringer) String() string {
    83  	return d.Label() + " string"
    84  }
    85  
    86  func TestExtendedLabel(t *testing.T) {
    87  	for _, tt := range []struct {
    88  		name  string
    89  		entry Entry
    90  		want  string
    91  	}{
    92  		{
    93  			name:  "without stringer",
    94  			entry: &testEntry{label: "label"},
    95  			want:  "label",
    96  		},
    97  		{
    98  			name:  "with stringer",
    99  			entry: &testEntryStringer{testEntry{label: "label"}},
   100  			want:  "label string",
   101  		},
   102  	} {
   103  		t.Run(tt.name, func(t *testing.T) {
   104  			got := ExtendedLabel(tt.entry)
   105  			if got != tt.want {
   106  				t.Errorf("ExtendedLabel(%v) = %q; want %q", tt.entry, got, tt.want)
   107  			}
   108  		})
   109  	}
   110  }
   111  
   112  var _ = MenuTerminal(&mockTerm{})
   113  
   114  type mockTerm struct {
   115  	inputSequence []ReadLine
   116  	readLineCnt   int
   117  }
   118  
   119  func (m *mockTerm) Write(p []byte) (n int, err error) {
   120  	return len(p), nil
   121  }
   122  
   123  func (m *mockTerm) Close() error                   { return nil }
   124  func (m *mockTerm) SetPrompt(s string)             {}
   125  func (m *mockTerm) SetEntryCallback(func())        {}
   126  func (m *mockTerm) SetTimeout(time.Duration) error { return nil }
   127  func (m *mockTerm) ReadLine() (string, error) {
   128  	defer func() { m.readLineCnt++ }()
   129  	if m.inputSequence == nil || m.readLineCnt >= len(m.inputSequence) {
   130  		// mimic timeout
   131  		return "", os.ErrDeadlineExceeded
   132  	}
   133  	return m.inputSequence[m.readLineCnt].string, m.inputSequence[m.readLineCnt].error
   134  }
   135  
   136  type ReadLine struct {
   137  	string
   138  	error
   139  }
   140  
   141  func TestChoose(t *testing.T) {
   142  	// This test takes too long to run for the VM test and doesn't use
   143  	// anything root-specific.
   144  	testutil.SkipIfInVMTest(t)
   145  
   146  	for _, tt := range []struct {
   147  		name           string
   148  		userEntry      []ReadLine
   149  		wantedEntry    int
   150  		editingAllowed bool
   151  		expectedCmds   []string
   152  	}{
   153  		{
   154  			name: "just_hit_enter",
   155  			// user just hits enter.
   156  			userEntry:   []ReadLine{{"", nil}},
   157  			wantedEntry: -1, // expect nil
   158  		},
   159  		{
   160  			name:        "hit_nothing",
   161  			userEntry:   []ReadLine{},
   162  			wantedEntry: -1, // expect nil
   163  		},
   164  		{
   165  			name:        "hit_1",
   166  			userEntry:   []ReadLine{{"1", nil}},
   167  			wantedEntry: 1,
   168  		},
   169  		{
   170  			name:        "hit_3",
   171  			userEntry:   []ReadLine{{"3", nil}},
   172  			wantedEntry: 3,
   173  		},
   174  		{
   175  			name:        "out_of_bounds",
   176  			userEntry:   []ReadLine{{"4", nil}},
   177  			wantedEntry: -1, // expect nil
   178  		},
   179  		{
   180  			name:        "not_a_number",
   181  			userEntry:   []ReadLine{{"abc", nil}},
   182  			wantedEntry: -1, // expect nil
   183  		},
   184  		{
   185  			name:           "editing_allowed_override",
   186  			userEntry:      getEditSequence(false, "1", "after"),
   187  			editingAllowed: true,
   188  			expectedCmds:   []string{"after", "before", "before"},
   189  			wantedEntry:    -1, // expect nil
   190  		},
   191  		{
   192  			name:           "editing_allowed_append",
   193  			userEntry:      getEditSequence(true, "2", "after"),
   194  			editingAllowed: true,
   195  			expectedCmds:   []string{"before", "before after", "before"},
   196  			wantedEntry:    -1, // expect nil
   197  		},
   198  		{
   199  			name: "select_after_override",
   200  			userEntry: append(getEditSequence(false, "1", "after"),
   201  				ReadLine{"1", nil}),
   202  			editingAllowed: true,
   203  			expectedCmds:   []string{"after", "before", "before"},
   204  			wantedEntry:    1,
   205  		},
   206  		{
   207  			name:           "editing_not_allowed",
   208  			userEntry:      getEditSequence(true, "1", "after"),
   209  			editingAllowed: false,
   210  			expectedCmds:   []string{"before", "before", "before"},
   211  			wantedEntry:    1, // Edit attempt is parsed as a boot choice
   212  		},
   213  		{
   214  			name:           "edit_fail_reading_1",
   215  			userEntry:      errorOn(1, getEditSequence(false, "1", "after")),
   216  			editingAllowed: true,
   217  			expectedCmds:   []string{"before", "before", "before"},
   218  			wantedEntry:    -1, // expect nil
   219  		},
   220  		{
   221  			name:           "edit_fail_reading_2",
   222  			userEntry:      errorOn(2, getEditSequence(false, "1", "after")),
   223  			editingAllowed: true,
   224  			expectedCmds:   []string{"before", "before", "before"},
   225  			wantedEntry:    -1, // expect nil
   226  		},
   227  		{
   228  			name:           "edit_fail_reading_3",
   229  			userEntry:      errorOn(3, getEditSequence(false, "1", "after")),
   230  			editingAllowed: true,
   231  			expectedCmds:   []string{"before", "before", "before"},
   232  			wantedEntry:    -1, // expect nil
   233  		},
   234  	} {
   235  		t.Run(tt.name, func(t *testing.T) {
   236  			entries := []*testEntry{
   237  				{label: "1", cmdline: "before"},
   238  				{label: "2", cmdline: "before"},
   239  				{label: "3", cmdline: "before"},
   240  			}
   241  			var menu []Entry
   242  			for _, e := range entries {
   243  				menu = append(menu, e)
   244  			}
   245  
   246  			chosen := make(chan Entry)
   247  			go func() {
   248  				m := &mockTerm{
   249  					inputSequence: tt.userEntry,
   250  				}
   251  				chosen <- Choose(m, tt.editingAllowed, menu...)
   252  			}()
   253  
   254  			chosenWant := Entry(nil)
   255  			if tt.wantedEntry > 0 {
   256  				chosenWant = menu[tt.wantedEntry-1] // 1 based index
   257  			}
   258  			if got := <-chosen; got != chosenWant {
   259  				t.Errorf("Choose(%#v, %#v) = %#v, wantedEntry %#v", tt.userEntry, entries, got, tt.wantedEntry)
   260  			}
   261  			// Check for editing
   262  			for i, entry := range entries {
   263  				if i < len(tt.expectedCmds) && entry.cmdline != tt.expectedCmds[i] {
   264  					t.Errorf("Entry %s got cmdline %s, wanted %s", entry.Label(), entry.cmdline, tt.expectedCmds[i])
   265  				}
   266  			}
   267  		})
   268  	}
   269  }
   270  
   271  func errorOn(index int, arr []ReadLine) []ReadLine {
   272  	arr[index].error = errors.New("Expected test error")
   273  	return arr
   274  }
   275  
   276  func contains(s []string, t string) bool {
   277  	for _, u := range s {
   278  		if u == t {
   279  			return true
   280  		}
   281  	}
   282  	return false
   283  }
   284  
   285  func TestShowMenuAndLoadFromFile(t *testing.T) {
   286  	// This test takes too long to run for the VM test and doesn't use
   287  	// anything root-specific.
   288  	testutil.SkipIfInVMTest(t)
   289  
   290  	tests := []struct {
   291  		name      string
   292  		entries   []*testEntry
   293  		userEntry []byte
   294  
   295  		// calledLabels are the entries for which Do was called.
   296  		calledLabels []string
   297  	}{
   298  		{
   299  			name: "default_entry",
   300  			entries: []*testEntry{
   301  				{label: "1", isDefault: true, load: nil},
   302  				{label: "2", isDefault: true, load: nil},
   303  			},
   304  			// user just hits enter.
   305  			userEntry:    []byte("\r\n"),
   306  			calledLabels: []string{"1"},
   307  		},
   308  		{
   309  			name: "non_default_entry_default",
   310  			entries: []*testEntry{
   311  				{label: "1", isDefault: false, load: nil},
   312  				{label: "2", isDefault: true, load: nil},
   313  				{label: "3", isDefault: true, load: nil},
   314  			},
   315  			// user just hits enter.
   316  			userEntry:    []byte("\r\n"),
   317  			calledLabels: []string{"2"},
   318  		},
   319  		{
   320  			name: "non_default_entry_chosen_but_broken",
   321  			entries: []*testEntry{
   322  				{label: "1", isDefault: false, load: fmt.Errorf("borked")},
   323  				{label: "2", isDefault: true, load: nil},
   324  				{label: "3", isDefault: true, load: nil},
   325  			},
   326  			userEntry:    []byte("1\r\n"),
   327  			calledLabels: []string{"1", "2"},
   328  		},
   329  		{
   330  			name: "last_entry_works",
   331  			entries: []*testEntry{
   332  				{label: "1", isDefault: true, load: fmt.Errorf("foo")},
   333  				{label: "2", isDefault: true, load: fmt.Errorf("bar")},
   334  				{label: "3", isDefault: true, load: nil},
   335  			},
   336  			// user just hits enter.
   337  			userEntry:    []byte("\r\n"),
   338  			calledLabels: []string{"1", "2", "3"},
   339  		},
   340  		{
   341  			name: "indecisive_entry",
   342  			entries: []*testEntry{
   343  				{label: "1", isDefault: true, load: nil},
   344  				{label: "2", isDefault: true, load: nil},
   345  				{label: "3", isDefault: true, load: nil},
   346  			},
   347  			// \x08 is the backspace character
   348  			userEntry:    []byte("1\x082\r\n"),
   349  			calledLabels: []string{"2"},
   350  		},
   351  		{
   352  			name: "timeout_gets_first_default",
   353  			entries: []*testEntry{
   354  				{label: "1", isDefault: true, load: nil},
   355  				{label: "2", isDefault: true, load: nil},
   356  				{label: "3", isDefault: true, load: nil},
   357  			},
   358  			// No input
   359  			userEntry:    []byte{},
   360  			calledLabels: []string{"1"},
   361  		},
   362  	}
   363  	for _, tt := range tests {
   364  		t.Run(tt.name, func(t *testing.T) {
   365  			master, slave, err := pty.Open()
   366  			if err != nil {
   367  				t.Fatalf("%v", err)
   368  			}
   369  			defer master.Close()
   370  			defer slave.Close()
   371  
   372  			var entries []Entry
   373  			for _, e := range tt.entries {
   374  				entries = append(entries, e)
   375  			}
   376  
   377  			timer := time.NewTimer(initialTimeout * 4)
   378  			entry := make(chan Entry)
   379  			go func() {
   380  				entry <- showMenuAndLoadFromFile(slave, true, entries...)
   381  			}()
   382  
   383  			if tt.userEntry != nil && len(tt.userEntry) > 0 {
   384  				// We have to wait until Choose has actually started trying to read, as
   385  				// ttys are asynchronous.
   386  				//
   387  				// Know a better way? Halp.
   388  				time.Sleep(inputDelay)
   389  				if _, err := master.Write(tt.userEntry); err != nil {
   390  					t.Fatalf("failed to write new-line: %v", err)
   391  				}
   392  			}
   393  
   394  			select {
   395  			case <-timer.C:
   396  				t.Errorf("Test %s timed out after %v", tt.name, initialTimeout)
   397  			case got := <-entry:
   398  				if want := tt.calledLabels[len(tt.calledLabels)-1]; got.Label() != want {
   399  					t.Errorf("got label %s wantedEntry label %s", got.Label(), want)
   400  				}
   401  
   402  				for _, entry := range tt.entries {
   403  					wantCalled := contains(tt.calledLabels, entry.label)
   404  					if wantCalled != entry.LoadCalled() {
   405  						t.Errorf("Entry %s gotCalled %t, wantCalled %t", entry.Label(), entry.LoadCalled(), wantCalled)
   406  					}
   407  				}
   408  			}
   409  		})
   410  	}
   411  }
   412  
   413  func getEditSequence(append bool, bootnum string, cmdline string) []ReadLine {
   414  	var editOpt string
   415  	if append {
   416  		editOpt = "a"
   417  	} else {
   418  		editOpt = "o"
   419  	}
   420  
   421  	return []ReadLine{
   422  		{"e", nil},
   423  		{bootnum, nil},
   424  		{editOpt, nil},
   425  		{cmdline, nil},
   426  	}
   427  }