github.com/andrewsun2898/u-root@v6.0.1-0.20200616011413-4b2895c1b815+incompatible/pkg/boot/menu/menu_test.go (about)

     1  // Copyright 2020 the u-root Authors. All rights reserved
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package menu
     6  
     7  import (
     8  	"fmt"
     9  	"sync"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/google/goterm/term"
    14  	"github.com/u-root/u-root/pkg/testutil"
    15  )
    16  
    17  type dummyEntry struct {
    18  	mu        sync.Mutex
    19  	label     string
    20  	isDefault bool
    21  	do        error
    22  	called    bool
    23  }
    24  
    25  func (d *dummyEntry) Label() string {
    26  	d.mu.Lock()
    27  	defer d.mu.Unlock()
    28  	return d.label
    29  }
    30  
    31  func (d *dummyEntry) String() string {
    32  	return d.Label()
    33  }
    34  
    35  func (d *dummyEntry) Do() error {
    36  	d.mu.Lock()
    37  	defer d.mu.Unlock()
    38  	d.called = true
    39  	return d.do
    40  }
    41  
    42  func (d *dummyEntry) Called() bool {
    43  	d.mu.Lock()
    44  	defer d.mu.Unlock()
    45  	return d.called
    46  }
    47  
    48  func (d *dummyEntry) IsDefault() bool {
    49  	d.mu.Lock()
    50  	defer d.mu.Unlock()
    51  	return d.isDefault
    52  }
    53  
    54  func TestChoose(t *testing.T) {
    55  	// This test takes too long to run for the VM test and doesn't use
    56  	// anything root-specific.
    57  	testutil.SkipIfInVMTest(t)
    58  
    59  	entry1 := &dummyEntry{label: "1"}
    60  	entry2 := &dummyEntry{label: "2"}
    61  	entry3 := &dummyEntry{label: "3"}
    62  
    63  	for _, tt := range []struct {
    64  		name      string
    65  		entries   []Entry
    66  		userEntry []byte
    67  		want      Entry
    68  	}{
    69  		{
    70  			name:    "just_hit_enter",
    71  			entries: []Entry{entry1, entry2, entry3},
    72  			// user just hits enter.
    73  			userEntry: []byte("\r\n"),
    74  			want:      nil,
    75  		},
    76  		{
    77  			name:      "hit_nothing",
    78  			entries:   []Entry{entry1, entry2, entry3},
    79  			userEntry: nil,
    80  			want:      nil,
    81  		},
    82  		{
    83  			name:      "hit_1",
    84  			entries:   []Entry{entry1, entry2, entry3},
    85  			userEntry: []byte("1\r\n"),
    86  			want:      entry1,
    87  		},
    88  		{
    89  			name:      "hit_3",
    90  			entries:   []Entry{entry1, entry2, entry3},
    91  			userEntry: []byte("3\r\n"),
    92  			want:      entry3,
    93  		},
    94  		{
    95  			name:    "tentative_hit_1",
    96  			entries: []Entry{entry1, entry2, entry3},
    97  			// \x08 is the backspace character.
    98  			userEntry: []byte("2\x081\r\n"),
    99  			want:      entry1,
   100  		},
   101  		{
   102  			name:      "out_of_bounds",
   103  			entries:   []Entry{entry1, entry2, entry3},
   104  			userEntry: []byte("4\r\n"),
   105  			want:      nil,
   106  		},
   107  		{
   108  			name:      "not_a_number",
   109  			entries:   []Entry{entry1, entry2, entry3},
   110  			userEntry: []byte("abc\r\n"),
   111  			want:      nil,
   112  		},
   113  	} {
   114  		t.Run(tt.name, func(t *testing.T) {
   115  			t.Parallel()
   116  			pty, err := term.OpenPTY()
   117  			if err != nil {
   118  				t.Fatalf("%v", err)
   119  			}
   120  			defer pty.Close()
   121  
   122  			chosen := make(chan Entry)
   123  			go func() {
   124  				chosen <- Choose(pty.Slave, tt.entries...)
   125  			}()
   126  
   127  			// Well this sucks.
   128  			//
   129  			// We have to wait until Choose has actually started trying to read, as
   130  			// ttys are asynchronous.
   131  			//
   132  			// Know a better way? Halp.
   133  			time.Sleep(1 * time.Second)
   134  
   135  			if tt.userEntry != nil {
   136  				if _, err := pty.Master.Write(tt.userEntry); err != nil {
   137  					t.Fatalf("failed to write new-line: %v", err)
   138  				}
   139  			}
   140  
   141  			if got := <-chosen; got != tt.want {
   142  				t.Errorf("Choose(%#v, %#v) = %#v, want %#v", tt.userEntry, tt.entries, got, tt.want)
   143  			}
   144  		})
   145  	}
   146  }
   147  
   148  func contains(s []string, t string) bool {
   149  	for _, u := range s {
   150  		if u == t {
   151  			return true
   152  		}
   153  	}
   154  	return false
   155  }
   156  
   157  func TestShowMenuAndBoot(t *testing.T) {
   158  	// This test takes too long to run for the VM test and doesn't use
   159  	// anything root-specific.
   160  	testutil.SkipIfInVMTest(t)
   161  
   162  	tests := []struct {
   163  		name      string
   164  		entries   []*dummyEntry
   165  		userEntry []byte
   166  
   167  		// calledLabels are the entries for which Do was called.
   168  		calledLabels []string
   169  	}{
   170  		{
   171  			name: "default_entry",
   172  			entries: []*dummyEntry{
   173  				{label: "1", isDefault: true, do: errStopTestOnly},
   174  				{label: "2", isDefault: true, do: nil},
   175  			},
   176  			// user just hits enter.
   177  			userEntry:    []byte("\r\n"),
   178  			calledLabels: []string{"1"},
   179  		},
   180  		{
   181  			name: "non_default_entry_default",
   182  			entries: []*dummyEntry{
   183  				{label: "1", isDefault: false, do: errStopTestOnly},
   184  				{label: "2", isDefault: true, do: errStopTestOnly},
   185  				{label: "3", isDefault: true, do: nil},
   186  			},
   187  			// user just hits enter.
   188  			userEntry:    []byte("\r\n"),
   189  			calledLabels: []string{"2"},
   190  		},
   191  		{
   192  			name: "non_default_entry_chosen_but_broken",
   193  			entries: []*dummyEntry{
   194  				{label: "1", isDefault: false, do: fmt.Errorf("borked")},
   195  				{label: "2", isDefault: true, do: errStopTestOnly},
   196  				{label: "3", isDefault: true, do: nil},
   197  			},
   198  			userEntry:    []byte("1\r\n"),
   199  			calledLabels: []string{"1", "2"},
   200  		},
   201  		{
   202  			name: "last_entry_works",
   203  			entries: []*dummyEntry{
   204  				{label: "1", isDefault: true, do: nil},
   205  				{label: "2", isDefault: true, do: nil},
   206  				{label: "3", isDefault: true, do: errStopTestOnly},
   207  			},
   208  			// user just hits enter.
   209  			userEntry:    []byte("\r\n"),
   210  			calledLabels: []string{"1", "2", "3"},
   211  		},
   212  	}
   213  	for _, tt := range tests {
   214  		t.Run(tt.name, func(t *testing.T) {
   215  			t.Parallel()
   216  			pty, err := term.OpenPTY()
   217  			if err != nil {
   218  				t.Fatalf("%v", err)
   219  			}
   220  			defer pty.Close()
   221  
   222  			var entries []Entry
   223  			for _, e := range tt.entries {
   224  				entries = append(entries, e)
   225  			}
   226  
   227  			var wg sync.WaitGroup
   228  			wg.Add(1)
   229  			go func() {
   230  				ShowMenuAndBoot(pty.Slave, entries...)
   231  				wg.Done()
   232  			}()
   233  
   234  			// Well this sucks.
   235  			//
   236  			// We have to wait until Choose has actually started trying to read, as
   237  			// ttys are asynchronous.
   238  			//
   239  			// Know a better way? Halp.
   240  			time.Sleep(1 * time.Second)
   241  
   242  			if tt.userEntry != nil {
   243  				if _, err := pty.Master.Write(tt.userEntry); err != nil {
   244  					t.Fatalf("failed to write new-line: %v", err)
   245  				}
   246  			}
   247  
   248  			wg.Wait()
   249  
   250  			for _, entry := range tt.entries {
   251  				wantCalled := contains(tt.calledLabels, entry.label)
   252  				if wantCalled != entry.Called() {
   253  					t.Errorf("Entry %s gotCalled %t, wantCalled %t", entry.Label(), entry.Called(), wantCalled)
   254  				}
   255  			}
   256  		})
   257  	}
   258  }