github.com/u-root/u-root@v7.0.1-0.20200915234505-ad7babab0a8e+incompatible/pkg/boot/menu/menu_test.go (about)

     1  // Copyright 2020 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package menu
     6  
     7  import (
     8  	"fmt"
     9  	"os"
    10  	"sync"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/google/goterm/term"
    15  	"github.com/u-root/u-root/pkg/testutil"
    16  )
    17  
    18  func TestMain(m *testing.M) {
    19  	initialTimeout = 2 * time.Second
    20  	subsequentTimeout = 6 * time.Second
    21  
    22  	os.Exit(m.Run())
    23  }
    24  
    // testEntry is a fake Entry used by the menu tests. Load returns a
// scripted error and records that it was called. Every accessor takes mu,
// so the entry can safely be shared with the goroutine that runs Choose or
// ShowMenuAndLoad.
type testEntry struct {
	mu sync.Mutex

	label     string
	isDefault bool

	// load is the error returned by Load.
	load error
	// loadCalled records whether Load has been invoked.
	loadCalled bool
}

// Label returns the entry's label.
func (e *testEntry) Label() string {
	e.mu.Lock()
	label := e.label
	e.mu.Unlock()
	return label
}

// String returns the entry's label, so entries print readably.
func (e *testEntry) String() string { return e.Label() }

// Load marks the entry as loaded and returns the scripted load error.
func (e *testEntry) Load() error {
	e.mu.Lock()
	e.loadCalled = true
	err := e.load
	e.mu.Unlock()
	return err
}

// Exec does nothing and always succeeds.
func (e *testEntry) Exec() error {
	e.mu.Lock()
	defer e.mu.Unlock()
	return nil
}

// LoadCalled reports whether Load has been called on this entry.
func (e *testEntry) LoadCalled() bool {
	e.mu.Lock()
	called := e.loadCalled
	e.mu.Unlock()
	return called
}

// IsDefault reports whether the entry was configured as a default entry.
func (e *testEntry) IsDefault() bool {
	e.mu.Lock()
	isDefault := e.isDefault
	e.mu.Unlock()
	return isDefault
}
    67  
    68  func TestChoose(t *testing.T) {
    69  	// This test takes too long to run for the VM test and doesn't use
    70  	// anything root-specific.
    71  	testutil.SkipIfInVMTest(t)
    72  
    73  	entry1 := &testEntry{label: "1"}
    74  	entry2 := &testEntry{label: "2"}
    75  	entry3 := &testEntry{label: "3"}
    76  
    77  	for _, tt := range []struct {
    78  		name      string
    79  		entries   []Entry
    80  		userEntry []byte
    81  		want      Entry
    82  	}{
    83  		{
    84  			name:    "just_hit_enter",
    85  			entries: []Entry{entry1, entry2, entry3},
    86  			// user just hits enter.
    87  			userEntry: []byte("\r\n"),
    88  			want:      nil,
    89  		},
    90  		{
    91  			name:      "hit_nothing",
    92  			entries:   []Entry{entry1, entry2, entry3},
    93  			userEntry: nil,
    94  			want:      nil,
    95  		},
    96  		{
    97  			name:      "hit_1",
    98  			entries:   []Entry{entry1, entry2, entry3},
    99  			userEntry: []byte("1\r\n"),
   100  			want:      entry1,
   101  		},
   102  		{
   103  			name:      "hit_3",
   104  			entries:   []Entry{entry1, entry2, entry3},
   105  			userEntry: []byte("3\r\n"),
   106  			want:      entry3,
   107  		},
   108  		{
   109  			name:    "tentative_hit_1",
   110  			entries: []Entry{entry1, entry2, entry3},
   111  			// \x08 is the backspace character.
   112  			userEntry: []byte("2\x081\r\n"),
   113  			want:      entry1,
   114  		},
   115  		{
   116  			name:      "out_of_bounds",
   117  			entries:   []Entry{entry1, entry2, entry3},
   118  			userEntry: []byte("4\r\n"),
   119  			want:      nil,
   120  		},
   121  		{
   122  			name:      "not_a_number",
   123  			entries:   []Entry{entry1, entry2, entry3},
   124  			userEntry: []byte("abc\r\n"),
   125  			want:      nil,
   126  		},
   127  	} {
   128  		t.Run(tt.name, func(t *testing.T) {
   129  			pty, err := term.OpenPTY()
   130  			if err != nil {
   131  				t.Fatalf("%v", err)
   132  			}
   133  			defer pty.Close()
   134  
   135  			chosen := make(chan Entry)
   136  			go func() {
   137  				chosen <- Choose(pty.Slave, tt.entries...)
   138  			}()
   139  
   140  			// Well this sucks.
   141  			//
   142  			// We have to wait until Choose has actually started trying to read, as
   143  			// ttys are asynchronous.
   144  			//
   145  			// Know a better way? Halp.
   146  			time.Sleep(1 * time.Second)
   147  
   148  			if tt.userEntry != nil {
   149  				if _, err := pty.Master.Write(tt.userEntry); err != nil {
   150  					t.Fatalf("failed to write new-line: %v", err)
   151  				}
   152  			}
   153  
   154  			if got := <-chosen; got != tt.want {
   155  				t.Errorf("Choose(%#v, %#v) = %#v, want %#v", tt.userEntry, tt.entries, got, tt.want)
   156  			}
   157  		})
   158  	}
   159  }
   160  
   // contains reports whether the string t is an element of s. A nil or
// empty s contains nothing.
func contains(s []string, t string) bool {
	for i := 0; i < len(s); i++ {
		if s[i] != t {
			continue
		}
		return true
	}
	return false
}
   169  
   170  func TestShowMenuAndLoad(t *testing.T) {
   171  	// This test takes too long to run for the VM test and doesn't use
   172  	// anything root-specific.
   173  	testutil.SkipIfInVMTest(t)
   174  
   175  	tests := []struct {
   176  		name      string
   177  		entries   []*testEntry
   178  		userEntry []byte
   179  
   180  		// calledLabels are the entries for which Do was called.
   181  		calledLabels []string
   182  	}{
   183  		{
   184  			name: "default_entry",
   185  			entries: []*testEntry{
   186  				{label: "1", isDefault: true, load: nil},
   187  				{label: "2", isDefault: true, load: nil},
   188  			},
   189  			// user just hits enter.
   190  			userEntry:    []byte("\r\n"),
   191  			calledLabels: []string{"1"},
   192  		},
   193  		{
   194  			name: "non_default_entry_default",
   195  			entries: []*testEntry{
   196  				{label: "1", isDefault: false, load: nil},
   197  				{label: "2", isDefault: true, load: nil},
   198  				{label: "3", isDefault: true, load: nil},
   199  			},
   200  			// user just hits enter.
   201  			userEntry:    []byte("\r\n"),
   202  			calledLabels: []string{"2"},
   203  		},
   204  		{
   205  			name: "non_default_entry_chosen_but_broken",
   206  			entries: []*testEntry{
   207  				{label: "1", isDefault: false, load: fmt.Errorf("borked")},
   208  				{label: "2", isDefault: true, load: nil},
   209  				{label: "3", isDefault: true, load: nil},
   210  			},
   211  			userEntry:    []byte("1\r\n"),
   212  			calledLabels: []string{"1", "2"},
   213  		},
   214  		{
   215  			name: "last_entry_works",
   216  			entries: []*testEntry{
   217  				{label: "1", isDefault: true, load: fmt.Errorf("foo")},
   218  				{label: "2", isDefault: true, load: fmt.Errorf("bar")},
   219  				{label: "3", isDefault: true, load: nil},
   220  			},
   221  			// user just hits enter.
   222  			userEntry:    []byte("\r\n"),
   223  			calledLabels: []string{"1", "2", "3"},
   224  		},
   225  	}
   226  	for _, tt := range tests {
   227  		t.Run(tt.name, func(t *testing.T) {
   228  			pty, err := term.OpenPTY()
   229  			if err != nil {
   230  				t.Fatalf("%v", err)
   231  			}
   232  			defer pty.Close()
   233  
   234  			var entries []Entry
   235  			for _, e := range tt.entries {
   236  				entries = append(entries, e)
   237  			}
   238  
   239  			entry := make(chan Entry)
   240  			go func() {
   241  				entry <- ShowMenuAndLoad(pty.Slave, entries...)
   242  			}()
   243  
   244  			// Well this sucks.
   245  			//
   246  			// We have to wait until Choose has actually started trying to read, as
   247  			// ttys are asynchronous.
   248  			//
   249  			// Know a better way? Halp.
   250  			time.Sleep(1 * time.Second)
   251  
   252  			if tt.userEntry != nil {
   253  				if _, err := pty.Master.Write(tt.userEntry); err != nil {
   254  					t.Fatalf("failed to write new-line: %v", err)
   255  				}
   256  			}
   257  
   258  			got := <-entry
   259  			if want := tt.calledLabels[len(tt.calledLabels)-1]; got.Label() != want {
   260  				t.Errorf("got label %s want label %s", got.Label(), want)
   261  			}
   262  
   263  			for _, entry := range tt.entries {
   264  				wantCalled := contains(tt.calledLabels, entry.label)
   265  				if wantCalled != entry.LoadCalled() {
   266  					t.Errorf("Entry %s gotCalled %t, wantCalled %t", entry.Label(), entry.LoadCalled(), wantCalled)
   267  				}
   268  			}
   269  		})
   270  	}
   271  }