github.com/ericwq/aprilsh@v0.0.0-20240517091432-958bc568daa0/frontend/client/client_test.go (about)

     1  // Copyright 2022~2024 wangqi. All rights reserved.
     2  // Use of this source code is governed by a MIT-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"errors"
     9  	"io"
    10  	"os"
    11  	"strings"
    12  	"sync"
    13  	"testing"
    14  	"time"
    15  
    16  	"github.com/creack/pty"
    17  	"github.com/ericwq/aprilsh/frontend"
    18  )
    19  
    20  func TestPrintColors(t *testing.T) {
    21  	tc := []struct {
    22  		label  string
    23  		term   string
    24  		expect []string
    25  	}{
    26  		{"lookup terminfo failed", "NotExist", []string{"Dynamic load terminfo failed."}},
    27  		{"TERM is empty", "", []string{"The TERM is empty string."}},
    28  		{"TERM doesn't exit", "-remove", []string{"The TERM doesn't exist."}},
    29  		{"normal found", "xterm-256color", []string{"xterm-256color", "256"}},
    30  		// {"dynamic found", "xfce", []string{"xfce 8 (dynamic)"}},
    31  		{"dynamic not found", "xxx", []string{"Dynamic load terminfo failed."}},
    32  	}
    33  
    34  	for _, v := range tc {
    35  		t.Run(v.label, func(t *testing.T) {
    36  			// intercept stdout
    37  			saveStdout := os.Stdout
    38  			r, w, _ := os.Pipe()
    39  			os.Stdout = w
    40  			// save original TERM
    41  			term := os.Getenv("TERM")
    42  
    43  			// set TERM according to test case
    44  			if v.term == "-remove" {
    45  				os.Unsetenv("TERM")
    46  			} else {
    47  				os.Setenv("TERM", v.term)
    48  			}
    49  
    50  			printColors()
    51  
    52  			// restore stdout
    53  			w.Close()
    54  			b, _ := io.ReadAll(r)
    55  			os.Stdout = saveStdout
    56  			r.Close()
    57  
    58  			// validate the result
    59  			result := string(b)
    60  			found := 0
    61  			for i := range v.expect {
    62  				if strings.Contains(result, v.expect[i]) {
    63  					found++
    64  				}
    65  			}
    66  			if found != len(v.expect) {
    67  				t.Errorf("#test %s expect %q, got %q\n", v.label, v.expect, result)
    68  			}
    69  
    70  			// restore original TERM
    71  			os.Setenv("TERM", term)
    72  		})
    73  	}
    74  }
    75  
    76  func TestMainRun_Parameters(t *testing.T) {
    77  	tc := []struct {
    78  		label  string
    79  		args   []string
    80  		term   string
    81  		expect []string
    82  	}{
    83  		{
    84  			"no parameters",
    85  			[]string{frontend.CommandClientName},
    86  			"xterm-256color",
    87  			[]string{"destination (user@host[:port]) is mandatory."},
    88  		},
    89  		{
    90  			"just version",
    91  			[]string{frontend.CommandClientName, "-version"},
    92  			"xterm-256color",
    93  			[]string{
    94  				frontend.CommandClientName, frontend.AprilshPackageName,
    95  				"Copyright (c) 2022~2024 wangqi <ericwq057@qq.com>", "remote shell support intermittent or mobile network.",
    96  			},
    97  		},
    98  		{
    99  			"just help",
   100  			[]string{frontend.CommandClientName, "-h"},
   101  			"xterm-256color",
   102  			[]string{
   103  				"Usage:", frontend.CommandClientName, "Options:", "-c", "--colors",
   104  				"print the number of terminal color",
   105  			},
   106  		},
   107  		{
   108  			"just colors",
   109  			[]string{frontend.CommandClientName, "-c", "-v"},
   110  			"xterm-256color",
   111  			[]string{"xterm-256color", "256"},
   112  		},
   113  		{
   114  			"invalid target parameter",
   115  			[]string{frontend.CommandClientName, "invalid", "target", "parameter"},
   116  			"xterm-256color",
   117  			[]string{"only one destination (user@host[:port]) is allowed."},
   118  		},
   119  		{
   120  			"destination no second part",
   121  			[]string{frontend.CommandClientName, "malform@"},
   122  			"xterm-256color",
   123  			[]string{"destination should be in the form of user@host[:port]"},
   124  		},
   125  		{
   126  			"destination no first part",
   127  			[]string{frontend.CommandClientName, "@malform"},
   128  			"xterm-256color",
   129  			[]string{"destination should be in the form of user@host[:port]"},
   130  		},
   131  		{
   132  			"infvalid port number",
   133  			[]string{frontend.CommandClientName, "-p", "7s"},
   134  			"xterm-256color",
   135  			[]string{"invalid value \"7s\" for flag -p: parse error"},
   136  		},
   137  	}
   138  
   139  	for _, v := range tc {
   140  		t.Run(v.label, func(t *testing.T) {
   141  			// intercept stdout
   142  			saveStdout := os.Stdout
   143  			r, w, _ := os.Pipe()
   144  			os.Stdout = w
   145  
   146  			// prepare data
   147  			os.Args = v.args
   148  			os.Setenv("TERM", v.term)
   149  			// test main
   150  			main()
   151  
   152  			// restore stdout
   153  			w.Close()
   154  			out, _ := io.ReadAll(r)
   155  			os.Stdout = saveStdout
   156  			r.Close()
   157  
   158  			// validate the result
   159  			result := string(out)
   160  			found := 0
   161  			for i := range v.expect {
   162  				if strings.Contains(result, v.expect[i]) {
   163  					// fmt.Printf("found %s\n", expect[i])
   164  					found++
   165  				}
   166  			}
   167  			if found != len(v.expect) {
   168  				t.Errorf("#test expect %s, got \n%s\n", v.expect, result)
   169  			}
   170  		})
   171  	}
   172  }
   173  
   174  func TestBuildConfig(t *testing.T) {
   175  	targetMsg := "destination should be in the form of user@host[:port]"
   176  	modeMsg := _PREDICTION_DISPLAY + " unknown prediction mode."
   177  	tc := []struct {
   178  		label       string
   179  		target      string
   180  		predictMode string
   181  		expect      string
   182  		ok          bool
   183  	}{
   184  		{"valid target, empty mode", "usr@localhost", "", "", true},
   185  		{"valid target, lack of mode", "gig@factory", "mode", modeMsg, false},
   186  		{"valid target, valid mode", "vfab@factory", "aLwaYs", "", true},
   187  		{"invalid target", "factory", "", targetMsg, false},
   188  		{"invalid @target", "@factory", "", targetMsg, false},
   189  		{"invalid target@", "factory@", "", targetMsg, false},
   190  	}
   191  
   192  	for _, v := range tc {
   193  		t.Run(v.label, func(t *testing.T) {
   194  			var conf Config
   195  			conf.destination = []string{v.target}
   196  
   197  			// prepare parse result
   198  			var host string
   199  			var user string
   200  			idx := strings.Index(v.target, "@")
   201  			if idx > 0 && idx < len(v.target)-1 {
   202  				host = v.target[idx+1:]
   203  				user = v.target[:idx]
   204  			}
   205  
   206  			os.Setenv(_PREDICTION_DISPLAY, v.predictMode)
   207  
   208  			got, ok := conf.buildConfig()
   209  			if got != v.expect {
   210  				t.Errorf("#test buildConfig() %s expect %q, got %s\n", v.label, v.expect, got)
   211  			}
   212  			if conf.user != user || conf.host != host {
   213  				t.Errorf("#test buildConfig() %q config.user expect %s, got %s\n", v.label, user, conf.user)
   214  				t.Errorf("#test buildConfig() %q config.host expect %s, got %s\n", v.label, host, conf.host)
   215  			}
   216  			if conf.predictMode != strings.ToLower(v.predictMode) {
   217  				t.Errorf("#test buildConfig() conf.predictMode expect %q, got %q\n", v.predictMode, conf.predictMode)
   218  			}
   219  			if ok != v.ok {
   220  				t.Errorf("#test buildConfig() expect %t, got %t\n", v.ok, ok)
   221  			}
   222  		})
   223  	}
   224  }
   225  
   226  func TestBuildConfig2(t *testing.T) {
   227  	tc := []struct {
   228  		label     string
   229  		conf      *Config
   230  		expectStr string
   231  		ok        bool
   232  	}{
   233  		{"destination without port", &Config{destination: []string{"usr@host"}}, "", true},
   234  		{"destination with port", &Config{destination: []string{"usr@host:23"}}, "", true},
   235  		{"destination with wrong port",
   236  			&Config{destination: []string{"usr@host:a23"}}, "please check destination, illegal port number.", false},
   237  	}
   238  	for _, v := range tc {
   239  		t.Run(v.label, func(t *testing.T) {
   240  			got, ok := v.conf.buildConfig()
   241  			if ok != v.ok || got != v.expectStr {
   242  				t.Errorf("%q expect (%s,%t) got (%s,%t)\n", v.label, v.expectStr, v.ok, got, ok)
   243  			}
   244  		})
   245  	}
   246  }
   247  
   248  // func TestFetchKey(t *testing.T) {
   249  // 	tc := []struct {
   250  // 		label string
   251  // 		conf  *Config
   252  // 		pwd   string
   253  // 		msg   string
   254  // 	}{
   255  // 		{"wrong host", &Config{user: "ide", host: "wrong", port: 60000}, "password", "dial tcp"},
   256  // 	}
   257  // 	for _, v := range tc {
   258  // 		t.Run(v.label, func(t *testing.T) {
   259  // 			v.conf.pwd = v.pwd
   260  // 			got := v.conf.fetchKey()
   261  // 			if !strings.Contains(got.Error(), v.msg) {
   262  // 				t.Errorf("#test %q expect %q contains %q.\n", v.label, got, v.msg)
   263  // 			}
   264  // 		})
   265  // 	}
   266  // }
   267  
   268  func TestGetPassword(t *testing.T) {
   269  
   270  	tc := []struct {
   271  		label  string
   272  		conf   *Config
   273  		pwd    string //input
   274  		expect string
   275  	}{
   276  		{"normal get password", &Config{}, "password\n", "password"},
   277  		{"just CR", &Config{}, "\n", ""},
   278  	}
   279  	for _, v := range tc {
   280  		t.Run(v.label, func(t *testing.T) {
   281  			// intercept stdout
   282  			saveStdout := os.Stdout
   283  			r, w, _ := os.Pipe()
   284  			os.Stdout = w
   285  
   286  			// get password require pts file.
   287  			ptmx, pts, err := pty.Open()
   288  			if err != nil {
   289  				err = errors.New("invalid parameter")
   290  			}
   291  
   292  			// prepare input data
   293  			ptmx.WriteString(v.pwd)
   294  
   295  			got, err := getPassword("password", pts)
   296  
   297  			ptmx.Close()
   298  			pts.Close()
   299  
   300  			// restore stdout
   301  			w.Close()
   302  			out, _ := io.ReadAll(r)
   303  			os.Stdout = saveStdout
   304  			r.Close()
   305  
   306  			// validate the result.
   307  			if err != nil {
   308  				t.Errorf("#test %q report %s\n", v.label, err)
   309  			}
   310  			if got != v.expect {
   311  				t.Errorf("#test %q expect %q, got %q. out=%s\n", v.label, v.expect, got, out)
   312  			}
   313  
   314  		})
   315  	}
   316  }
   317  
   318  func TestGetPasswordFail(t *testing.T) {
   319  	// conf := &Config{}
   320  
   321  	// intercept stdout
   322  	saveStdout := os.Stdout
   323  	r, w, _ := os.Pipe()
   324  	os.Stdout = w
   325  
   326  	got, err := getPassword("password", r)
   327  
   328  	// restore stdout
   329  	w.Close()
   330  	out, _ := io.ReadAll(r)
   331  	os.Stdout = saveStdout
   332  	r.Close()
   333  
   334  	// validate, for non-tty input, getPassword return err: inappropriate ioctl for device
   335  	if err == nil {
   336  		t.Errorf("#test getPassword fail expt %q, got=%q, err=%s, out=%s\n", "", got, err, out)
   337  	}
   338  }
   339  
   340  func TestGetPasswordFail2(t *testing.T) {
   341  	// store stdout/in, open pts pair
   342  	ptmx, pts, err := pty.Open()
   343  	if err != nil {
   344  		t.Errorf("failed to open pts, %s\n", err)
   345  		return
   346  	}
   347  	saveStdout := os.Stdout
   348  	saveStdin := os.Stdin
   349  	os.Stdout = pts
   350  	os.Stdin = pts
   351  
   352  	expect := "hello world"
   353  
   354  	// provide the input
   355  	var wg sync.WaitGroup
   356  	wg.Add(1)
   357  	go func() {
   358  		defer wg.Done()
   359  		// make sure we provide input after the getPassword()
   360  		timer := time.NewTimer(time.Duration(2) * time.Millisecond)
   361  		<-timer.C
   362  		ptmx.WriteString(expect + "\n") // \n  is important for getPassword()
   363  	}()
   364  
   365  	// waiting for the input
   366  	wg.Add(1)
   367  	var got string
   368  	var err2 error
   369  	go func() {
   370  		defer wg.Done()
   371  		got, err2 = getPassword("password", pts)
   372  	}()
   373  	wg.Wait()
   374  
   375  	// close pts paire and restore stdou/stdin
   376  	ptmx.Close()
   377  	pts.Close()
   378  	os.Stdout = saveStdout
   379  	os.Stdin = saveStdin
   380  
   381  	// validate, for non-tty input, getPassword return err: inappropriate ioctl for device
   382  	if err2 != nil || got != expect {
   383  		t.Errorf("#test getPassword fail expt %q, got=%q, err=%s\n", expect, got, err)
   384  	}
   385  }
   386  
   387  func TestSshAgentFail(t *testing.T) {
   388  	tc := []struct {
   389  		label  string
   390  		env    bool
   391  		expect string
   392  	}{
   393  		{"lack of SSH_AUTH_SOCK", false, "Failed to connect ssh agent."},
   394  	}
   395  	for _, v := range tc {
   396  		t.Run(v.label, func(t *testing.T) {
   397  			old := os.Getenv("SSH_AUTH_SOCK")
   398  			defer os.Setenv("SSH_AUTH_SOCK", old)
   399  
   400  			// intercept stdout
   401  			saveStdout := os.Stdout
   402  			r, w, _ := os.Pipe()
   403  			os.Stdout = w
   404  
   405  			// clear SSH_AUTH_SOCK
   406  			if !v.env {
   407  				os.Unsetenv("SSH_AUTH_SOCK")
   408  			}
   409  			// run the test
   410  			sshAgent()
   411  
   412  			// restore stdout
   413  			w.Close()
   414  			out, _ := io.ReadAll(r)
   415  			os.Stdout = saveStdout
   416  			r.Close()
   417  
   418  			got := string(out)
   419  			if !strings.HasPrefix(got, v.expect) {
   420  				t.Errorf("%q expect %q got %q\n", v.label, v.expect, got)
   421  			}
   422  		})
   423  	}
   424  }
   425  
   426  func TestErrors(t *testing.T) {
   427  	tc := []struct {
   428  		label  string
   429  		error  error
   430  		expect string
   431  	}{
   432  		{"hostkeyChangeError", &hostkeyChangeError{hostname: "some.where"},
   433  			"REMOTE HOST IDENTIFICATION HAS CHANGED"},
   434  		{"responseErr without error", &responseError{}, "<nil>"},
   435  		{"responseErr error", &responseError{Msg: "hello", Err: errors.New("world")}, "hello, world"},
   436  	}
   437  	for _, v := range tc {
   438  		t.Run(v.label, func(t *testing.T) {
   439  
   440  			got := v.error.Error()
   441  			if !strings.Contains(got, v.expect) {
   442  				t.Errorf("%q expect %q got %q\n", v.label, v.expect, got)
   443  			}
   444  
   445  		})
   446  	}
   447  }
   448  
   449  func TestPublicKeyFileFail(t *testing.T) {
   450  	tc := []struct {
   451  		label  string
   452  		file   string
   453  		expect string
   454  	}{
   455  		{"file doesn't exist", "/do/es/not/exist", "Unable to read private key"},
   456  		{"is not private key", "/etc/hosts", "Unable to parse private key"},
   457  	}
   458  	for _, v := range tc {
   459  		t.Run(v.label, func(t *testing.T) {
   460  
   461  			// intercept stdout
   462  			saveStdout := os.Stdout
   463  			r, w, _ := os.Pipe()
   464  			os.Stdout = w
   465  
   466  			// run the test
   467  			publicKeyFile(v.file)
   468  
   469  			// restore stdout
   470  			w.Close()
   471  			out, _ := io.ReadAll(r)
   472  			os.Stdout = saveStdout
   473  			r.Close()
   474  
   475  			// validate the output
   476  			got := string(out)
   477  			if !strings.Contains(got, v.expect) {
   478  				t.Errorf("%q expect %q got %q\n", v.label, v.expect, got)
   479  			}
   480  		})
   481  	}
   482  }