github.com/ericwq/aprilsh@v0.0.0-20240517091432-958bc568daa0/frontend/server/server_test.go (about)

     1  // Copyright 2022~2024 wangqi. All rights reserved.
     2  // Use of this source code is governed by a MIT-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"bytes"
     9  	"errors"
    10  	"flag"
    11  	"fmt"
    12  	"io"
    13  	"net"
    14  	"os"
    15  	"os/exec"
    16  	"reflect"
    17  	"runtime"
    18  	"strconv"
    19  	"strings"
    20  	"sync"
    21  	"syscall"
    22  	"testing"
    23  	"time"
    24  
    25  	"log/slog"
    26  
    27  	"github.com/creack/pty"
    28  	"github.com/ericwq/aprilsh/frontend"
    29  	"github.com/ericwq/aprilsh/network"
    30  	"github.com/ericwq/aprilsh/statesync"
    31  	"github.com/ericwq/aprilsh/util"
    32  	"golang.org/x/sys/unix"
    33  )
    34  
    35  func TestPrintMotd(t *testing.T) {
    36  	// darwin doesn't has the following motd files, so we add /etc/hosts for testing.
    37  	files := []string{"/run/motd.dynamic", "/var/run/motd.dynamic", "/etc/motd", "/etc/hosts"}
    38  
    39  	var output bytes.Buffer
    40  
    41  	found := false
    42  	for i := range files {
    43  		output.Reset()
    44  		if printMotd(&output, files[i]) {
    45  			if output.Len() > 0 { // we got and print the file content
    46  				found = true
    47  				break
    48  			}
    49  		}
    50  	}
    51  
    52  	// validate the result
    53  	if !found {
    54  		t.Errorf("#test expect found %s, found nothing\n", files)
    55  	}
    56  
    57  	output.Reset()
    58  
    59  	// creat a .hide file and write long token into it
    60  	fName := ".hide"
    61  	hide, _ := os.Create(fName)
    62  	for i := 0; i < 1025; i++ {
    63  		data := bytes.Repeat([]byte{'s'}, 64)
    64  		hide.Write(data)
    65  	}
    66  	hide.Close()
    67  
    68  	if printMotd(&output, fName) {
    69  		t.Errorf("#test printMotd should return false, instead it return true.")
    70  	}
    71  
    72  	os.Remove(fName)
    73  }
    74  
    75  func TestPrintVersion(t *testing.T) {
    76  	// intercept stdout
    77  	saveStdout := os.Stdout
    78  	r, w, _ := os.Pipe()
    79  	os.Stdout = w
    80  	// initLog()
    81  
    82  	expect := []string{frontend.CommandServerName, "version", "git commit", "wangqi <ericwq057@qq.com>"}
    83  
    84  	printVersion()
    85  
    86  	// restore stdout
    87  	w.Close()
    88  	b, _ := io.ReadAll(r)
    89  	os.Stdout = saveStdout
    90  	r.Close()
    91  
    92  	// validate the result
    93  	result := string(b)
    94  	found := 0
    95  	for i := range expect {
    96  		if strings.Contains(result, expect[i]) {
    97  			found++
    98  		}
    99  	}
   100  	if found != len(expect) {
   101  		t.Errorf("#test printVersion expect %q, got %q\n", expect, result)
   102  	}
   103  }
   104  
// cmdOptions is the option summary that the usage/help tests in this file
// expect to find in the printed output.
var cmdOptions = "[-s] [-v[v]] [-i LOCALADDR] [-p PORT[:PORT2]] [-l NAME=VALUE] [-- command...]"
   106  
   107  func TestPrintUsage(t *testing.T) {
   108  	tc := []struct {
   109  		label  string
   110  		hints  string
   111  		expect []string
   112  	}{
   113  		{"no hint", "", []string{"Usage:", frontend.CommandServerName, cmdOptions}},
   114  		{"some hints", "some hints", []string{"Usage:", frontend.CommandServerName, "some hints", cmdOptions}},
   115  	}
   116  
   117  	for _, v := range tc {
   118  		t.Run(v.label, func(t *testing.T) {
   119  
   120  			out := captureOutputRun(func() {
   121  				frontend.PrintUsage(v.hints, usage)
   122  			})
   123  
   124  			// validate the result
   125  			result := string(out)
   126  			found := 0
   127  			for i := range v.expect {
   128  				if strings.Contains(result, v.expect[i]) {
   129  					found++
   130  				}
   131  			}
   132  			if found != len(v.expect) {
   133  				t.Errorf("#test printUsage expect %s, got %s\n", v.expect, result)
   134  			}
   135  		})
   136  	}
   137  }
   138  
   139  func TestChdirHomedir(t *testing.T) {
   140  	// save the current dir
   141  	oldPwd := os.Getenv("PWD")
   142  
   143  	// use the HOME
   144  	got := ""
   145  	if !chdirHomedir("") {
   146  		got = os.Getenv("PWD")
   147  		t.Errorf("#test chdirHomedir expect change to home directory, got %s\n", got)
   148  	}
   149  
   150  	// validate the PWD
   151  	got = os.Getenv("PWD")
   152  	// fmt.Printf("#test chdirHomedir home=%q\n", got)
   153  	if got == oldPwd {
   154  		t.Errorf("#test chdirHomedir home dir %q, is different from old dir %q\n", got, oldPwd)
   155  	}
   156  
   157  	// unset HOME
   158  	os.Unsetenv("HOME")
   159  	// validate the false
   160  	if chdirHomedir("") {
   161  		t.Errorf("#test chdirHomedir return false.\n")
   162  	}
   163  
   164  	// use the parameter as HOME
   165  	if chdirHomedir("/does/not/exist") {
   166  		t.Errorf("#test chdirHomedir should return false\n")
   167  	}
   168  
   169  	// restore the current dir and PWD
   170  	os.Chdir(oldPwd)
   171  	os.Setenv("PWD", oldPwd)
   172  }
   173  
   174  func TestGetHomeDir(t *testing.T) {
   175  	tc := []struct {
   176  		label  string
   177  		env    string
   178  		expect string
   179  	}{
   180  		{"normal case", "/home/aprish", "/home/aprish"},
   181  		{"no HOME case", "", ""}, // for unix anc macOS, no HOME means getHomeDir() return ""
   182  	}
   183  
   184  	for _, v := range tc {
   185  		oldHome := os.Getenv("HOME")
   186  		if v.env == "" { // unset HOME env
   187  			os.Unsetenv("HOME")
   188  		} else {
   189  			os.Setenv("HOME", v.env)
   190  		}
   191  		got := getHomeDir()
   192  
   193  		if got != v.expect {
   194  			t.Errorf("%s getHomeDir() expect %q got %q\n", v.label, v.expect, got)
   195  		}
   196  		os.Setenv("HOME", oldHome)
   197  	}
   198  }
   199  
   200  func TestMotdHushed(t *testing.T) {
   201  	label := "#test motdHushed "
   202  	if motdHushed() != false {
   203  		t.Errorf("%s should report false, got %t\n", label, motdHushed())
   204  	}
   205  
   206  	cmd := exec.Command("touch", ".hushlogin")
   207  	if err := cmd.Run(); err != nil {
   208  		t.Errorf("%s create .hushlogin failed, %s\n", label, err)
   209  	}
   210  	if motdHushed() != true {
   211  		t.Errorf("%s should report true, got %t\n", label, motdHushed())
   212  	}
   213  
   214  	cmd = exec.Command("rm", ".hushlogin")
   215  	if err := cmd.Run(); err != nil {
   216  		t.Errorf("%s delete .hushlogin failed, %s\n", label, err)
   217  	}
   218  }
   219  
   220  func TestMainHelp(t *testing.T) {
   221  	testHelpFunc := func() {
   222  		// prepare data
   223  		os.Args = []string{frontend.CommandServerName, "--help"}
   224  		// test help
   225  		main()
   226  	}
   227  
   228  	out := captureOutputRun(testHelpFunc)
   229  
   230  	// validate result
   231  	expect := []string{"Usage:", frontend.CommandServerName, cmdOptions}
   232  
   233  	// validate the result
   234  	result := string(out)
   235  	found := 0
   236  	for i := range expect {
   237  		if strings.Contains(result, expect[i]) {
   238  			found++
   239  		}
   240  	}
   241  	if found != len(expect) {
   242  		t.Errorf("#test printUsage expect %q, got %q\n", expect, result)
   243  	}
   244  }
   245  
   246  // capture the stdout and run the
   247  func captureOutputRun(f func()) []byte {
   248  	// save the stdout,stderr and create replaced pipe
   249  	stderr := os.Stderr
   250  	stdout := os.Stdout
   251  	r, w, _ := os.Pipe()
   252  	// replace stdout,stderr with pipe writer
   253  	// alll the output to stdout,stderr is captured
   254  	os.Stderr = w
   255  	os.Stdout = w
   256  
   257  	util.Logger.CreateLogger(w, true, slog.LevelDebug)
   258  
   259  	// os.Args is a "global variable", so keep the state from before the test, and restore it after.
   260  	oldArgs := os.Args
   261  	defer func() { os.Args = oldArgs }()
   262  
   263  	f()
   264  
   265  	// close pipe writer
   266  	w.Close()
   267  	// get the output
   268  	out, _ := io.ReadAll(r)
   269  	os.Stderr = stderr
   270  	os.Stdout = stdout
   271  	r.Close()
   272  
   273  	return out
   274  }
   275  
   276  func TestMainVersion(t *testing.T) {
   277  
   278  	testHelpFunc := func() {
   279  		// prepare data
   280  		os.Args = []string{frontend.CommandServerName, "--version"}
   281  		// test
   282  		main()
   283  
   284  	}
   285  
   286  	out := captureOutputRun(testHelpFunc)
   287  
   288  	// validate result
   289  	expect := []string{frontend.CommandServerName, "go version", "git commit", "wangqi <ericwq057@qq.com>",
   290  		"remote shell support intermittent or mobile network."}
   291  	result := string(out)
   292  	found := 0
   293  	for i := range expect {
   294  		if strings.Contains(result, expect[i]) {
   295  			found++
   296  		}
   297  	}
   298  	if found != len(expect) {
   299  		t.Errorf("#test printVersion expect %q, got %q\n", expect, result)
   300  	}
   301  }
   302  
   303  func TestMainParseFlagsError(t *testing.T) {
   304  	testFunc := func() {
   305  		// prepare data
   306  		os.Args = []string{frontend.CommandServerName, "--foo"}
   307  		// test
   308  		main()
   309  	}
   310  
   311  	out := captureOutputRun(testFunc)
   312  
   313  	// validate result
   314  	expect := []string{"flag provided but not defined: -foo"}
   315  	found := 0
   316  	for i := range expect {
   317  		if strings.Contains(string(out), expect[i]) {
   318  			found++
   319  		}
   320  	}
   321  	if found != len(expect) {
   322  		t.Errorf("#test parserError expect %q, got \n%s\n", expect, out)
   323  	}
   324  }
   325  
   326  func TestParseFlagsUsage(t *testing.T) {
   327  	usageArgs := []string{"-help", "-h", "--help"}
   328  
   329  	for _, arg := range usageArgs {
   330  		t.Run(arg, func(t *testing.T) {
   331  			conf, output, err := parseFlags("prog", []string{arg})
   332  			if err != flag.ErrHelp {
   333  				t.Errorf("err got %v, want ErrHelp", err)
   334  			}
   335  			if conf != nil {
   336  				t.Errorf("conf got %v, want nil", conf)
   337  			}
   338  			if strings.Index(output, "Usage of") < 0 {
   339  				t.Errorf("output can't find \"Usage of\": %q", output)
   340  			}
   341  		})
   342  	}
   343  }
   344  
   345  func TestMainRun(t *testing.T) {
   346  	tc := []struct {
   347  		label  string
   348  		args   []string
   349  		expect []string
   350  	}{
   351  		{"run main and killed by signal",
   352  			[]string{frontend.CommandServerName, "-locale",
   353  				"LC_ALL=en_US.UTF-8", "-p", "6100", "--", "/bin/sh", "-sh"},
   354  			[]string{frontend.CommandServerName, "start listening on", "gitTag",
   355  				/* "got signal: SIGHUP",  */ "got signal: SIGTERM or SIGINT",
   356  				"stop listening", "6100"}},
   357  		{"main killed by -a", // auto stop after 1 second
   358  			[]string{frontend.CommandServerName, "-verbose", "-auto", "1", "-locale",
   359  				"LC_ALL=en_US.UTF-8", "-p", "6200", "--", "/bin/sh", "-sh"},
   360  			[]string{frontend.CommandServerName, "start listening on", "gitTag",
   361  				"stop listening", "6200"}},
   362  		{"main killed by -a, write to syslog", // auto stop after 1 second
   363  			[]string{frontend.CommandServerName, "-auto", "1", "-locale",
   364  				"LC_ALL=en_US.UTF-8", "-p", "6300", "--", "/bin/sh", "-sh"},
   365  			[]string{}}, // log write to syslog, we can't get anything
   366  	}
   367  
   368  	for _, v := range tc {
   369  
   370  		if strings.Contains(v.label, "by signal") {
   371  			// shutdown after 15ms
   372  			time.AfterFunc(time.Duration(15)*time.Millisecond, func() {
   373  				util.Logger.Debug("#test kill process by signal")
   374  				syscall.Kill(os.Getpid(), syscall.SIGTERM)
   375  				// syscall.Kill(os.Getpid(), syscall.SIGHUP)
   376  			})
   377  		}
   378  
   379  		testFunc := func() {
   380  			os.Args = v.args
   381  			main()
   382  		}
   383  
   384  		out := captureOutputRun(testFunc)
   385  
   386  		// validate the result from printWelcome
   387  		result := string(out)
   388  		found := 0
   389  		for i := range v.expect {
   390  			if strings.Contains(result, v.expect[i]) {
   391  				// fmt.Printf("found %s\n", expect[i])
   392  				found++
   393  			}
   394  		}
   395  		if found != len(v.expect) {
   396  			t.Errorf("#test expect %q, got %s\n", v.expect, result)
   397  		}
   398  		// fmt.Printf("###\n%s\n###\n", string(out))
   399  	}
   400  }
   401  
   402  func testMainBuildConfigFail(t *testing.T) {
   403  	testFunc := func() {
   404  		// prepare parameter
   405  		os.Args = []string{frontend.CommandServerName, "-locale", "LC_ALL=en_US.UTF-8",
   406  			"-p", "6100", "--", "/bin/sh", "-sh"}
   407  		// test
   408  		main()
   409  	}
   410  
   411  	// prepare for buildConfig fail
   412  	// buildConfigTest = true
   413  	out := captureOutputRun(testFunc)
   414  
   415  	// restore the condition
   416  	// buildConfigTest = false
   417  
   418  	// validate the result
   419  	expect := []string{"needs a UTF-8 native locale to run"}
   420  	result := string(out)
   421  	found := 0
   422  	for i := range expect {
   423  		if strings.Contains(result, expect[i]) {
   424  			found++
   425  		}
   426  	}
   427  	if found != len(expect) {
   428  		t.Errorf("#test buildConfig() expect %q, got %s\n", expect, result)
   429  	}
   430  }
   431  
   432  func TestParseFlagsCorrect(t *testing.T) {
   433  	tc := []struct {
   434  		args []string
   435  		conf Config
   436  	}{
   437  		{
   438  			[]string{"-locale", "ALL=en_US.UTF-8", "-l", "LANG=UTF-8"},
   439  			Config{
   440  				version: false, server: false, verbose: 0, desiredIP: "", desiredPort: "8100",
   441  				locales:     localeFlag{"ALL": "en_US.UTF-8", "LANG": "UTF-8"},
   442  				commandPath: "", commandArgv: []string{}, withMotd: false,
   443  			},
   444  		},
   445  		{
   446  			[]string{"--", "/bin/sh", "-sh"},
   447  			Config{
   448  				version: false, server: false, verbose: 0, desiredIP: "", desiredPort: "8100",
   449  				locales:     localeFlag{},
   450  				commandPath: "", commandArgv: []string{"/bin/sh", "-sh"}, withMotd: false,
   451  			},
   452  		},
   453  		{
   454  			[]string{"--", ""},
   455  			Config{
   456  				version: false, server: false, verbose: 0, desiredIP: "", desiredPort: "8100",
   457  				locales:     localeFlag{},
   458  				commandPath: "", commandArgv: []string{""}, withMotd: false,
   459  			},
   460  		},
   461  	}
   462  
   463  	for _, v := range tc {
   464  		t.Run(strings.Join(v.args, " "), func(t *testing.T) {
   465  			conf, output, err := parseFlags("prog", v.args)
   466  			if err != nil {
   467  				t.Errorf("err got %v, want nil", err)
   468  			}
   469  			if output != "" {
   470  				t.Errorf("output got %q, want empty", output)
   471  			}
   472  			if !reflect.DeepEqual(*conf, v.conf) {
   473  				t.Logf("#test parseFlags got commandArgv=%+v\n", conf.commandArgv)
   474  				t.Errorf("conf got \n%+v, want \n%+v", *conf, v.conf)
   475  			}
   476  		})
   477  	}
   478  }
   479  
   480  func TestGetShell(t *testing.T) {
   481  	tc := []struct {
   482  		label  string
   483  		expect string
   484  	}{
   485  		{"get unix shell from cmd", "fill later"},
   486  	}
   487  
   488  	var err error
   489  	tc[0].expect, err = util.GetShell()
   490  	if err != nil {
   491  		t.Errorf("#test getShell() reports %q\n", err)
   492  	}
   493  
   494  	for _, v := range tc {
   495  		if got, _ := util.GetShell(); got != v.expect {
   496  			if got != v.expect {
   497  				t.Errorf("#test getShell() %s expect %q, got %q\n", v.label, v.expect, got)
   498  			}
   499  		}
   500  	}
   501  }
   502  
   503  func TestParseFlagsError(t *testing.T) {
   504  	tests := []struct {
   505  		args   []string
   506  		errstr string
   507  	}{
   508  		{[]string{"-foo"}, "flag provided but not defined"},
   509  		// {[]string{"-color", "joe"}, "invalid value"},
   510  		{[]string{"-locale", "a=b=c"}, "malform locale parameter"},
   511  	}
   512  
   513  	for _, tt := range tests {
   514  		t.Run(strings.Join(tt.args, " "), func(t *testing.T) {
   515  			conf, output, err := parseFlags("prog", tt.args)
   516  			if conf != nil {
   517  				t.Errorf("conf got %v, want nil", conf)
   518  			}
   519  			if strings.Index(err.Error(), tt.errstr) < 0 {
   520  				t.Errorf("err got %q, want to find %q", err.Error(), tt.errstr)
   521  			}
   522  			if strings.Index(output, "Usage of prog") < 0 {
   523  				t.Errorf("output got %q", output)
   524  			}
   525  		})
   526  	}
   527  }
   528  
   529  // func TestMainParameters(t *testing.T) {
   530  // 	// flag is a global variable, reset it before test
   531  // 	flag.CommandLine = flag.NewFlagSet("TestMainParameters", flag.ExitOnError)
   532  // 	testParaFunc := func() {
   533  // 		// prepare data
   534  // 		os.Args = []string{COMMAND_NAME, "--", "/bin/sh","-sh"} //"-l LC_ALL=en_US.UTF-8", "--"}
   535  // 		// test
   536  // 		main()
   537  // 	}
   538  //
   539  // 	out := captureStdoutRun(testParaFunc)
   540  //
   541  // 	// validate result
   542  // 	expect := []string{"main", "commandPath=", "commandArgv=", "withMotd=", "locales=", "color="}
   543  // 	result := string(out)
   544  // 	found := 0
   545  // 	for i := range expect {
   546  // 		if strings.Contains(result, expect[i]) {
   547  // 			found++
   548  // 		}
   549  // 	}
   550  // 	if found != len(expect) {
   551  // 		t.Errorf("#test main() expect %s, got %s\n", expect, result)
   552  // 	}
   553  // }
   554  
   555  func TestMainServerPortrangeError(t *testing.T) {
   556  	testFunc := func() {
   557  		os.Args = []string{frontend.CommandServerName, "-s", "-p=3a"}
   558  		os.Setenv("SSH_CONNECTION", "172.17.0.1 58774 172.17.0.2 22")
   559  		main()
   560  	}
   561  
   562  	out := captureOutputRun(testFunc)
   563  	// validate port range check
   564  	expect := "Bad UDP port"
   565  	got := string(out)
   566  	if !strings.Contains(got, expect) {
   567  		t.Errorf("#test --port should contains %q, got %s\n", expect, got)
   568  	}
   569  }
   570  
   571  func TestGetSSHip(t *testing.T) {
   572  	tc := []struct {
   573  		label  string
   574  		env    string
   575  		expect string
   576  		ok     bool
   577  	}{
   578  		{"no env variable", "", "Warning: SSH_CONNECTION not found; binding to any interface.", false},
   579  		{"ipv4 address", "172.17.0.1 58774 172.17.0.2 22", "172.17.0.2", true},
   580  		{"malform variable", " 1 2 3 4",
   581  			"Warning: Could not parse SSH_CONNECTION; binding to any interface.", false},
   582  		{"ipv6 address", "fe80::14d5:1215:f8c9:11fa%en0 42000 fe80::aede:48ff:fe00:1122%en5 22",
   583  			"fe80::aede:48ff:fe00:1122%en5", true},
   584  		{"ipv4 mapped address", "::FFFF:172.17.0.1 42200 ::FFFF:129.144.52.38 22", "129.144.52.38", true},
   585  	}
   586  
   587  	for _, v := range tc {
   588  
   589  		os.Setenv("SSH_CONNECTION", v.env)
   590  		got, ok := getSSHip()
   591  		if got != v.expect || ok != v.ok {
   592  			t.Errorf("%q expect %q, got %q, ok=%t\n", v.label, v.expect, got, ok)
   593  		}
   594  	}
   595  }
   596  
   597  func TestGetShellNameFrom(t *testing.T) {
   598  	tc := []struct {
   599  		label     string
   600  		shellPath string
   601  		shellName string
   602  	}{
   603  		{"normal", "/bin/sh", "-sh"},
   604  		{"no slash sign", "noslash", "-noslash"},
   605  	}
   606  
   607  	for _, v := range tc {
   608  		got := getShellNameFrom(v.shellPath)
   609  		if got != v.shellName {
   610  			t.Errorf("%q expect %q, got %q\n", v.label, v.shellName, got)
   611  		}
   612  	}
   613  }
   614  
   615  func TestGetTimeFrom(t *testing.T) {
   616  	tc := []struct {
   617  		lable      string
   618  		key, value string
   619  		expect     int64
   620  	}{
   621  		{"positive int64", "ENV1", "123", 123},
   622  		{"malform int64", "ENV2", "123a", 0},
   623  		{"negative int64", "ENV3", "-123", 0},
   624  	}
   625  
   626  	// save the stdout and create replaced pipe
   627  	rescueStdout := os.Stdout
   628  	r, w, _ := os.Pipe()
   629  	os.Stdout = w
   630  	// initLog()
   631  
   632  	oldArgs := os.Args
   633  	defer func() { os.Args = oldArgs }()
   634  
   635  	for _, v := range tc {
   636  		os.Setenv(v.key, v.value)
   637  
   638  		got := getTimeFrom(v.key, 0)
   639  		if got != v.expect {
   640  			t.Errorf("%s expct %d, got %d\n", v.lable, v.expect, got)
   641  		}
   642  	}
   643  
   644  	// read and restore the stdout
   645  	w.Close()
   646  	io.ReadAll(r)
   647  	os.Stdout = rescueStdout
   648  }
   649  
   650  /*
   651  	func testPTY() error {
   652  		// Create arbitrary command.
   653  		c := exec.Command("bash")
   654  
   655  		// Start the command with a pty.
   656  		ptmx, err := pty.Start(c)
   657  		if err != nil {
   658  			return err
   659  		}
   660  		// Make sure to close the pty at the end.
   661  		defer func() { _ = ptmx.Close() }() // Best effort.
   662  
   663  		// Handle pty size.
   664  		ch := make(chan os.Signal, 1)
   665  		signal.Notify(ch, syscall.SIGWINCH)
   666  		go func() {
   667  			for range ch {
   668  				if err := pty.InheritSize(os.Stdin, ptmx); err != nil {
   669  					log.Printf("error resizing pty: %s", err)
   670  				}
   671  			}
   672  		}()
   673  		ch <- syscall.SIGWINCH                        // Initial resize.
   674  		defer func() { signal.Stop(ch); close(ch) }() // Cleanup signals when done.
   675  
   676  		// Set stdin in raw mode.
   677  		oldState, err := term.MakeRaw(int(os.Stdin.Fd()))
   678  		if err != nil {
   679  			panic(err)
   680  		}
   681  		defer func() { _ = term.Restore(int(os.Stdin.Fd()), oldState) }() // Best effort.
   682  
   683  		// Copy stdin to the pty and the pty to stdout.
   684  		// NOTE: The goroutine will keep reading until the next keystroke before returning.
   685  		go func() { _, _ = io.Copy(ptmx, os.Stdin) }()
   686  		_, _ = io.Copy(os.Stdout, ptmx)
   687  
   688  		return nil
   689  	}
   690  */
   691  
// TestMainSrvStart starts a main server from the given Config, lets a mock
// client send the open message to the configured port and checks that the
// response contains the expected prefix. A background goroutine sends true on
// srv.downChan after v.shutdown milliseconds; the test then waits on srv.wait().
// The test is skipped on riscv64.
func TestMainSrvStart(t *testing.T) {
	tc := []struct {
		label    string
		pause    int    // pause between client send and read
		resp     string // response client read
		shutdown int    // pause before shutdown message
		conf     Config
	}{
		{
			"start normally", 100, frontend.AprilshMsgOpen + "7101,", 150,
			Config{
				version: false, server: true, verbose: 0, desiredIP: "", desiredPort: "7100",
				locales:     localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"},
				commandPath: "/bin/sh", commandArgv: []string{"/bin/sh"}, withMotd: false,
				addSource: false,
			},
		},
	}

	if runtime.GOARCH == "riscv64" {
		t.Skip("riscv64 timer is not as accurate as other platform, skip this test.")
	}
	// the test starts a child process, which is /usr/bin/apshd;
	// that means you need to compile /usr/bin/apshd before running the test
	for _, v := range tc {
		t.Run(v.label, func(t *testing.T) {
			// init log: write debug level messages to stderr
			// util.Logger.CreateLogger(io.Discard, true, slog.LevelDebug)
			util.Logger.CreateLogger(os.Stderr, true, slog.LevelDebug)

			srv := newMainSrv(&v.conf)

			// send shutdown message after some time
			timer1 := time.NewTimer(time.Duration(v.shutdown) * time.Millisecond)
			go func() {
				<-timer1.C
				// fmt.Printf("#test start PID:%d\n", os.Getpid())
				// all the go tests run in the same process, so the server is
				// stopped through downChan instead of a signal
				// syscall.Kill(os.Getpid(), syscall.SIGHUP)
				// syscall.Kill(os.Getpid(), syscall.SIGTERM)
				srv.downChan <- true
				// stop the worker correctly, because mockRunWorker2 failed to
				// do it on purpose.
				// srv.exChan <- fmt.Sprintf("%d", srv.maxPort)
			}()

			srv.start(&v.conf)

			// mock client operation: send the open message, pause v.pause
			// milliseconds, then read the response
			// fmt.Printf("#test mark=%d\n", 100)
			resp := mockClient(v.conf.desiredPort, v.pause, frontend.AprilshMsgOpen)
			// fmt.Printf("#test mark=%s\n", resp)
			if !strings.Contains(resp, v.resp) {
				t.Errorf("#test run expect %q got %q\n", v.resp, resp)
			}

			srv.wait()
			// e, err := os.Executable()
			// fmt.Fprintf(os.Stderr, "Executable=%s, err=%s\n", e, err)
			// fmt.Fprintf(os.Stderr, "Args[0]   =%s\n", os.Args[0])
			// fmt.Fprintf(os.Stderr, "CWD       =%s\n", os.Args[0])
		})
	}
}
   756  
   757  func TestStartFail(t *testing.T) {
   758  	tc := []struct {
   759  		label  string
   760  		pause  int    // pause between client send and read
   761  		resp   string // response client read
   762  		finish int    // pause before shutdown message
   763  		conf   Config
   764  	}{
   765  		{
   766  			"illegal port", 20, "", 150,
   767  			Config{
   768  				version: false, server: true, verbose: 0, desiredIP: "", desiredPort: "7000a",
   769  				locales:     localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"},
   770  				commandPath: "/bin/sh", commandArgv: []string{"/bin/sh"}, withMotd: false,
   771  			},
   772  		},
   773  	}
   774  
   775  	for _, v := range tc {
   776  		t.Run(v.label, func(t *testing.T) {
   777  			// intercept logW
   778  			var b strings.Builder
   779  			util.Logger.CreateLogger(&b, true, slog.LevelDebug)
   780  
   781  			// srv := newMainSrv(&v.conf, mockRunWorker)
   782  			m := newMainSrv(&v.conf)
   783  
   784  			// defer func() {
   785  			// 	logW = log.New(os.Stdout, "WARN: ", log.Ldate|log.Ltime|log.Lshortfile)
   786  			// }()
   787  
   788  			// start mainserver
   789  			m.start(&v.conf)
   790  			// fmt.Println("#test start fail!")
   791  
   792  			// validate result: result contains WARN and COMMAND_NAME
   793  			expect := []string{"WARN", "listen failed"}
   794  			result := b.String()
   795  			found := 0
   796  			for i := range expect {
   797  				if strings.Contains(result, expect[i]) {
   798  					found++
   799  				}
   800  			}
   801  			if found != 2 {
   802  				t.Errorf("#test start() expect %q, got %q\n", expect, result)
   803  			}
   804  		})
   805  	}
   806  }
   807  
   808  // the mock runWorker send the key, pause some time and close the
   809  // worker by send finish message
   810  func mockRunWorker(conf *Config, exChan chan string, whChan chan workhorse) error {
   811  	// send the mock key
   812  	// fmt.Println("#mockRunWorker send mock key to run().")
   813  	exChan <- "This is the mock key"
   814  
   815  	// pause some time
   816  	time.Sleep(time.Duration(2) * time.Millisecond)
   817  
   818  	whChan <- workhorse{}
   819  
   820  	// notify the server
   821  	// fmt.Println("#mockRunWorker finish run().")
   822  	exChan <- conf.desiredPort
   823  	return nil
   824  }
   825  
   826  // the mock runWorker send the key, pause some time and try to close the
   827  // worker by send wrong finish message: port+"x"
   828  func mockRunWorker2(conf *Config, exChan chan string, whChan chan workhorse) error {
   829  	// send the mock key
   830  	exChan <- "mock key from mockRunWorker2"
   831  
   832  	// pause some time
   833  	time.Sleep(time.Duration(2) * time.Millisecond)
   834  
   835  	// fail to stop the worker on purpose
   836  	exChan <- conf.desiredPort + "x"
   837  
   838  	whChan <- workhorse{}
   839  
   840  	return nil
   841  }
   842  
   843  // mock client connect to the port, send handshake message, pause some time
   844  // return the response message.
   845  func mockClient(port string, pause int, action string, ex ...string) string {
   846  	server_addr, _ := net.ResolveUDPAddr("udp", "localhost:"+port)
   847  	local_addr, _ := net.ResolveUDPAddr("udp", "localhost:0")
   848  	conn, _ := net.DialUDP("udp", local_addr, server_addr)
   849  
   850  	defer conn.Close()
   851  
   852  	// send handshake message based on action & port
   853  	var txbuf []byte
   854  	switch action {
   855  	case frontend.AprilshMsgOpen:
   856  		switch len(ex) {
   857  		case 0:
   858  			txbuf = []byte(frontend.AprilshMsgOpen + "xterm," + getCurrentUser() + "@localhost")
   859  		case 1:
   860  			// the request missing the ','
   861  			txbuf = []byte(fmt.Sprintf("%s%s", frontend.AprilshMsgOpen, ex[0]))
   862  		}
   863  	case frontend.AprishMsgClose:
   864  		p, _ := strconv.Atoi(port)
   865  		switch len(ex) {
   866  		case 0:
   867  			txbuf = []byte(fmt.Sprintf("%s%d", frontend.AprishMsgClose, p+1))
   868  		case 1:
   869  			p2, err := strconv.Atoi(ex[0])
   870  			if err == nil {
   871  				txbuf = []byte(fmt.Sprintf("%s%d", frontend.AprishMsgClose, p2)) // 1 digital parameter: wrong port
   872  			} else {
   873  				txbuf = []byte(fmt.Sprintf("%s%s", frontend.AprishMsgClose, ex[0])) // 1 str parameter: malform port
   874  			}
   875  		case 2:
   876  			txbuf = []byte(fmt.Sprintf("%s%d", "unknow header:", p+1)) // 2 parameters: unknow header
   877  		}
   878  	}
   879  
   880  	_, err := conn.Write(txbuf)
   881  	// fmt.Printf("#mockClient send %q to server: %v from %v\n", txbuf, server_addr, conn.LocalAddr())
   882  	if err != nil {
   883  		fmt.Printf("#mockClient send %s, error %s\n", string(txbuf), err)
   884  	}
   885  
   886  	// pause some time
   887  	time.Sleep(time.Duration(pause) * time.Millisecond)
   888  
   889  	// read the response
   890  	rxbuf := make([]byte, 512)
   891  	n, _, err := conn.ReadFromUDP(rxbuf)
   892  
   893  	// fmt.Printf("#mockClient read %q from server: %v\n", rxbuf[0:n], server_addr)
   894  	return string(rxbuf[0:n])
   895  }
   896  
   897  func TestPrintWelcome(t *testing.T) {
   898  	// open pts master and slave first.
   899  	pty, tty, err := pty.Open()
   900  	if err != nil {
   901  		t.Errorf("#test printWelcome Open %s\n", err)
   902  	}
   903  
   904  	// clean pts fd
   905  	defer func() {
   906  		if err != nil {
   907  			pty.Close()
   908  			tty.Close()
   909  		}
   910  	}()
   911  
   912  	// pty master doesn't support IUTF8
   913  	flag, err := util.CheckIUTF8(int(pty.Fd()))
   914  	if flag {
   915  		t.Errorf("#test printWelcome master got %t, expect %t\n", flag, false)
   916  	}
   917  
   918  	expect := []string{"Warning: termios IUTF8 flag not defined."}
   919  
   920  	tc := []struct {
   921  		label string
   922  		tty   *os.File
   923  	}{
   924  		{"tty doesn't support IUTF8 flag", pty},
   925  		{"tty failed with checkIUTF8", os.Stdin},
   926  	}
   927  
   928  	for _, v := range tc {
   929  		// intercept stdout
   930  		saveStdout := os.Stdout
   931  		r, w, _ := os.Pipe()
   932  		os.Stdout = w
   933  		util.Logger.CreateLogger(w, true, slog.LevelDebug)
   934  
   935  		// printWelcome(os.Getpid(), 6000, v.tty)
   936  		printWelcome(v.tty)
   937  
   938  		// restore stdout
   939  		w.Close()
   940  		b, _ := io.ReadAll(r)
   941  		os.Stdout = saveStdout
   942  		r.Close()
   943  
   944  		// validate the result
   945  		result := string(b)
   946  		found := 0
   947  		for i := range expect {
   948  			if strings.Contains(result, expect[i]) {
   949  				found++
   950  			}
   951  		}
   952  		if found != len(expect) {
   953  			t.Errorf("#test printWelcome expect %q, got %s\n", expect, result)
   954  		}
   955  	}
   956  }
   957  
   958  func TestListenFail(t *testing.T) {
   959  	tc := []struct {
   960  		label  string
   961  		port   string
   962  		repeat bool // if true, will listen twice.
   963  	}{
   964  		{"illegal port number", "22a", false},
   965  		{"port already in use", "60001", true}, // 60001 is the docker port on macOS
   966  	}
   967  	for _, v := range tc {
   968  		conf := &Config{desiredPort: v.port}
   969  		// s := newMainSrv(conf, mockRunWorker)
   970  		s := newMainSrv(conf)
   971  
   972  		var e error
   973  		e = s.listen(conf)
   974  		// fmt.Printf("#test %q got 1st error: %q\n", v.label, e)
   975  		if v.repeat {
   976  			e = s.listen(conf)
   977  			// fmt.Printf("#test %q got 2nd error: %q\n", v.label, e)
   978  		}
   979  
   980  		// check the error does happens
   981  		if e == nil {
   982  			t.Errorf("#test %q expect error return, got nil\n", v.label)
   983  		}
   984  
   985  		// close the listen port
   986  		if v.repeat {
   987  			s.exChan <- conf.desiredPort
   988  		}
   989  	}
   990  }
   991  
   992  // func testRunFail(t *testing.T) {
   993  // 	tc := []struct {
   994  // 		label  string
   995  // 		pause  int    // pause between client send and read
   996  // 		resp   string // response client read
   997  // 		finish int    // pause before shutdown message
   998  // 		conf   Config
   999  // 	}{
  1000  // 		{
  1001  // 			"worker failed with wrong port number", 100, frontend.AprilshMsgOpen + "7101,mock key from mockRunWorker2\n", 30,
  1002  // 			Config{
  1003  // 				version: false, server: true, verbose: 1, desiredIP: "", desiredPort: "7100",
  1004  // 				locales:     localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"},
  1005  // 				commandPath: "/bin/sh", commandArgv: []string{"/bin/sh"}, withMotd: false,
  1006  // 				addSource: false,
  1007  // 			},
  1008  // 		},
  1009  // 	}
  1010  //
  1011  // 	for _, v := range tc {
  1012  // 		t.Run(v.label, func(t *testing.T) {
  1013  // 			// intercept stdout
  1014  // 			saveStdout := os.Stdout
  1015  // 			r, w, _ := os.Pipe()
  1016  // 			os.Stdout = w
  1017  // 			// initLog()
  1018  //
  1019  // 			// util.Logger.CreateLogger(w, true, slog.LevelDebug)
  1020  // 			util.Logger.CreateLogger(os.Stderr, true, slog.LevelDebug)
  1021  //
  1022  // 			// srv := newMainSrv(&v.conf, mockRunWorker2)
  1023  // 			srv := newMainSrv(&v.conf)
  1024  //
  1025  // 			// send shutdown message after some time
  1026  // 			timer1 := time.NewTimer(time.Duration(v.finish) * time.Millisecond)
  1027  // 			go func() {
  1028  // 				<-timer1.C
  1029  // 				// prepare to shudown the mainSrv
  1030  // 				// syscall.Kill(syscall.Getpid(), syscall.SIGTERM)
  1031  // 				srv.downChan <- true
  1032  // 				// stop the worker correctly, because mockRunWorker2 failed to
  1033  // 				// do it on purpose.
  1034  // 				port, _ := strconv.Atoi(v.conf.desiredPort)
  1035  // 				srv.exChan <- fmt.Sprintf("%d", port+1)
  1036  // 				util.Logger.Debug("send port to exChan", "port", port+1)
  1037  // 			}()
  1038  // 			// fmt.Println("#test start timer for shutdown")
  1039  //
  1040  // 			srv.start(&v.conf)
  1041  //
  1042  // 			// mock client operation
  1043  // 			resp := mockClient(v.conf.desiredPort, v.pause, frontend.AprilshMsgOpen)
  1044  //
  1045  // 			// validate the result.
  1046  // 			if resp != v.resp {
  1047  // 				t.Errorf("#test run expect %q got %q\n", v.resp, resp)
  1048  // 			}
  1049  //
  1050  // 			srv.wait()
  1051  //
  1052  // 			// restore stdout
  1053  // 			w.Close()
  1054  // 			io.ReadAll(r)
  1055  // 			os.Stdout = saveStdout
  1056  // 			r.Close()
  1057  // 		})
  1058  // 	}
  1059  //
  1060  // 	// test case for run() without connection
  1061  //
  1062  // 	srv2 := &mainSrv{}
  1063  // 	srv2.run(&Config{})
  1064  // }
  1065  
  1066  func TestRunFail2(t *testing.T) {
  1067  	tc := []struct {
  1068  		label  string
  1069  		pause  int    // pause between client send and read
  1070  		resp   string // response client read
  1071  		finish int    // pause before shutdown message
  1072  		conf   Config
  1073  	}{
  1074  		{
  1075  			"read udp error", 20, "7101,This is the mock key", 150,
  1076  			Config{
  1077  				version: false, server: true, verbose: 0, desiredIP: "", desiredPort: "7100",
  1078  				locales:     localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"},
  1079  				commandPath: "/bin/sh", commandArgv: []string{"/bin/sh"}, withMotd: false,
  1080  			},
  1081  		},
  1082  	}
  1083  
  1084  	for _, v := range tc {
  1085  		t.Run(v.label, func(t *testing.T) {
  1086  			// intercept stdout
  1087  			saveStdout := os.Stdout
  1088  			r, w, _ := os.Pipe()
  1089  			os.Stdout = w
  1090  			// initLog()
  1091  			util.Logger.CreateLogger(w, true, slog.LevelDebug)
  1092  
  1093  			// srv := newMainSrv(&v.conf, mockRunWorker)
  1094  			srv := newMainSrv(&v.conf)
  1095  
  1096  			// send shutdown message after some time
  1097  			timer1 := time.NewTimer(time.Duration(v.finish) * time.Millisecond)
  1098  			go func() {
  1099  				<-timer1.C
  1100  				srv.downChan <- true
  1101  				// syscall.Kill(syscall.Getpid(), syscall.SIGTERM)
  1102  			}()
  1103  			// fmt.Println("#test start timer for shutdown")
  1104  
  1105  			srv.start(&v.conf)
  1106  
  1107  			// close the connection, this will cause read error: use of closed network connection.
  1108  			srv.conn.Close()
  1109  
  1110  			srv.wait()
  1111  
  1112  			// restore stdout
  1113  			w.Close()
  1114  			io.ReadAll(r)
  1115  			os.Stdout = saveStdout
  1116  			r.Close()
  1117  		})
  1118  	}
  1119  }
  1120  
  1121  func TestMaxPortLimit(t *testing.T) {
  1122  	tc := []struct {
  1123  		label        string
  1124  		maxPortLimit int
  1125  		pause        int    // pause between client send and read
  1126  		resp         string // response client read
  1127  		shutdownTime int    // pause before shutdown message
  1128  		conf         Config
  1129  	}{
  1130  		{
  1131  			"run() over max port", 0, 20, "over max port limit", 150,
  1132  			Config{
  1133  				version: false, server: true, verbose: 0, desiredIP: "", desiredPort: "7700",
  1134  				locales:     localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"},
  1135  				commandPath: "/bin/sh", commandArgv: []string{"/bin/sh"}, withMotd: false,
  1136  			},
  1137  		},
  1138  	}
  1139  
  1140  	for _, v := range tc {
  1141  		t.Run(v.label, func(t *testing.T) {
  1142  			// intercept stdout
  1143  
  1144  			util.Logger.CreateLogger(io.Discard, true, slog.LevelDebug)
  1145  
  1146  			// init mainSrv and workers
  1147  			// m := newMainSrv(&v.conf, runWorker)
  1148  			m := newMainSrv(&v.conf)
  1149  
  1150  			// save maxPortLimit
  1151  			old := maxPortLimit
  1152  			maxPortLimit = v.maxPortLimit
  1153  
  1154  			// send shutdown message after some time
  1155  			timer1 := time.NewTimer(time.Duration(v.shutdownTime) * time.Millisecond)
  1156  			go func() {
  1157  				<-timer1.C
  1158  				m.downChan <- true
  1159  			}()
  1160  
  1161  			// start mainserver
  1162  			m.start(&v.conf)
  1163  
  1164  			// mock client operation
  1165  			resp := mockClient(v.conf.desiredPort, v.pause, frontend.AprilshMsgOpen)
  1166  
  1167  			m.wait()
  1168  
  1169  			if !strings.Contains(resp, v.resp) {
  1170  				t.Errorf("%q expect response %q, got %q\n ", v.label, v.resp, resp)
  1171  			}
  1172  
  1173  			// restore maxPortLimit
  1174  			maxPortLimit = old
  1175  		})
  1176  	}
  1177  }
  1178  
  1179  func TestMalformRequest(t *testing.T) {
  1180  	tc := []struct {
  1181  		label        string
  1182  		pause        int    // pause between client send and read
  1183  		resp         string // response client read
  1184  		shutdownTime int    // pause before shutdown message
  1185  		conf         Config
  1186  	}{
  1187  		{
  1188  			"run() malform request", 20, "malform request", 150,
  1189  			Config{
  1190  				version: false, server: true, verbose: 0, desiredIP: "", desiredPort: "7700",
  1191  				locales:     localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"},
  1192  				commandPath: "/bin/sh", commandArgv: []string{"/bin/sh"}, withMotd: false,
  1193  			},
  1194  		},
  1195  	}
  1196  
  1197  	for _, v := range tc {
  1198  		// intercept stdout
  1199  
  1200  		util.Logger.CreateLogger(io.Discard, true, slog.LevelDebug)
  1201  
  1202  		// init mainSrv and workers
  1203  		// m := newMainSrv(&v.conf, runWorker)
  1204  		m := newMainSrv(&v.conf)
  1205  
  1206  		// send shutdown message after some time
  1207  		timer1 := time.NewTimer(time.Duration(v.shutdownTime) * time.Millisecond)
  1208  		go func() {
  1209  			<-timer1.C
  1210  			syscall.Kill(os.Getpid(), syscall.SIGHUP) // add SIGHUP test condition
  1211  			time.Sleep(time.Duration(v.shutdownTime+5) * time.Millisecond)
  1212  			m.downChan <- true
  1213  		}()
  1214  
  1215  		// start mainserver
  1216  		m.start(&v.conf)
  1217  
  1218  		// mock client operation
  1219  		resp := mockClient(v.conf.desiredPort, v.pause, frontend.AprilshMsgOpen, "extraParam")
  1220  
  1221  		m.wait()
  1222  
  1223  		if !strings.Contains(resp, v.resp) {
  1224  			t.Errorf("%q expect response %q, got %q\n ", v.label, v.resp, resp)
  1225  		}
  1226  	}
  1227  }
  1228  
  1229  func mockServe(ptmx *os.File, pts *os.File, pw *io.PipeWriter, terminal *statesync.Complete, // x chan bool,
  1230  	network *network.Transport[*statesync.Complete, *statesync.UserStream],
  1231  	networkTimeout int64, networkSignaledTimeout int64, user string) error {
  1232  	time.Sleep(10 * time.Millisecond)
  1233  	// x <- true
  1234  	return nil
  1235  }
  1236  
  1237  // the mock runWorker send empty key, pause some time and close the worker
  1238  func failRunWorker(conf *Config, exChan chan string, whChan chan *workhorse) error {
  1239  	// send the empty key
  1240  	// fmt.Println("#mockRunWorker send mock key to run().")
  1241  	exChan <- ""
  1242  
  1243  	// pause some time
  1244  	time.Sleep(time.Duration(2) * time.Millisecond)
  1245  
  1246  	// notify this worker is done
  1247  	defer func() {
  1248  		exChan <- conf.desiredPort
  1249  	}()
  1250  
  1251  	whChan <- &workhorse{}
  1252  	return errors.New("failed worker.")
  1253  }
  1254  
  1255  func TestRunWorkerKillSignal(t *testing.T) {
  1256  	tc := []struct {
  1257  		label  string
  1258  		pause  int    // pause between client send and read
  1259  		resp   string // response client read
  1260  		finish int    // pause before shutdown message
  1261  		conf   Config
  1262  	}{
  1263  		{
  1264  			"runWorker stopped by signal kill", 10, frontend.AprilshMsgOpen + "7101,", 150,
  1265  			Config{
  1266  				version: false, server: true, flowControl: _FC_SKIP_PIPE_LOCK, desiredIP: "", desiredPort: "7100",
  1267  				locales:     localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"},
  1268  				commandPath: "/bin/sh", commandArgv: []string{"-sh"}, withMotd: true,
  1269  			},
  1270  		},
  1271  	}
  1272  
  1273  	for _, v := range tc {
  1274  		t.Run(v.label, func(t *testing.T) {
  1275  
  1276  			// intercept stdout
  1277  			saveStdout := os.Stdout
  1278  			r, w, _ := os.Pipe()
  1279  			os.Stdout = w
  1280  
  1281  			util.Logger.CreateLogger(w, true, slog.LevelDebug)
  1282  			// util.Logger.CreateLogger(os.Stderr, true, slog.LevelDebug)
  1283  
  1284  			// set serve func and runWorker func
  1285  			v.conf.serve = mockServe
  1286  			// srv := newMainSrv(&v.conf, runWorker)
  1287  			srv := newMainSrv(&v.conf)
  1288  
  1289  			/// set commandPath and commandArgv based on environment
  1290  			v.conf.commandPath = os.Getenv("SHELL")
  1291  			v.conf.commandArgv = []string{getShellNameFrom(v.conf.commandPath)}
  1292  
  1293  			// send kill signal after some time (finish ms)
  1294  			timer1 := time.NewTimer(time.Duration(v.finish) * time.Millisecond)
  1295  			go func() {
  1296  				<-timer1.C
  1297  				srv.downChan <- true
  1298  			}()
  1299  
  1300  			srv.start(&v.conf)
  1301  
  1302  			// mock client operation
  1303  			resp := mockClient(v.conf.desiredPort, v.pause, frontend.AprilshMsgOpen)
  1304  			if !strings.HasPrefix(resp, v.resp) {
  1305  				t.Errorf("#test run expect %q got %q\n", v.resp, resp)
  1306  			}
  1307  
  1308  			srv.wait()
  1309  
  1310  			// restore stdout
  1311  			w.Close()
  1312  			io.ReadAll(r)
  1313  			os.Stdout = saveStdout
  1314  			r.Close()
  1315  		})
  1316  	}
  1317  }
  1318  
  1319  // func testRunWorkerFail(t *testing.T) {
  1320  // 	tc := []struct {
  1321  // 		label string
  1322  // 		conf  Config
  1323  // 	}{
  1324  // 		{
  1325  // 			"openPTS fail", Config{
  1326  // 				version: false, server: true, flowControl: _FC_OPEN_PTS_FAIL, desiredIP: "", desiredPort: "7100",
  1327  // 				locales: localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"}, term: "kitty",
  1328  // 				commandPath: "/bin/xxxsh", commandArgv: []string{"-sh"}, withMotd: false,
  1329  // 			},
  1330  // 		},
  1331  // 		{
  1332  // 			"startShell fail", Config{
  1333  // 				version: false, server: true, flowControl: _FC_SKIP_START_SHELL, desiredIP: "", desiredPort: "7200",
  1334  // 				locales: localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"}, term: "kitty",
  1335  // 				commandPath: "/bin/xxxsh", commandArgv: []string{"-sh"}, withMotd: false,
  1336  // 			},
  1337  // 		},
  1338  // 		// {
  1339  // 		// 	"shell.Wait fail", Config{
  1340  // 		// 		version: false, server: true, verbose: _VERBOSE_SKIP_READ_PIPE, desiredIP: "", desiredPort: "7300",
  1341  // 		// 		locales: localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"}, term: "kitty",
  1342  // 		// 		commandPath: "echo", commandArgv: []string{"2"}, withMotd: false,
  1343  // 		// 	},
  1344  // 		// },
  1345  // 	}
  1346  //
  1347  // 	exChan := make(chan string, 1)
  1348  // 	whChan := make(chan workhorse, 1)
  1349  //
  1350  // 	for _, v := range tc {
  1351  // 		t.Run(v.label, func(t *testing.T) {
  1352  //
  1353  // 			// intercept log output
  1354  // 			util.Logger.CreateLogger(io.Discard, true, slog.LevelDebug)
  1355  //
  1356  // 			var wg sync.WaitGroup
  1357  // 			var hasWorkhorse bool
  1358  // 			v.conf.serve = mockServe
  1359  // 			if strings.Contains(v.label, "shell.Wait fail") {
  1360  // 				v.conf.commandPath, _ = exec.LookPath(v.conf.commandPath)
  1361  // 				hasWorkhorse = true // last one has effective work horse.
  1362  // 			}
  1363  //
  1364  // 			wg.Add(1)
  1365  // 			go func() {
  1366  // 				defer wg.Done()
  1367  // 				<-exChan       // get the key
  1368  // 				wh := <-whChan // get the workhorse
  1369  // 				if hasWorkhorse {
  1370  // 					if wh.child == nil {
  1371  // 						t.Errorf("#test runWorker fail should return empty workhorse\n")
  1372  // 					}
  1373  // 					wh.child.Kill()
  1374  // 				} else if strings.Contains(v.label, "openPTS fail") {
  1375  // 					if wh.child != nil {
  1376  // 						t.Errorf("#test runWorker fail should return empty workhorse\n")
  1377  // 					}
  1378  // 					msg := <-exChan // get the done message
  1379  // 					if msg != v.conf.desiredPort {
  1380  // 						t.Errorf("#test runWorker fail should return %s, got %s\n", v.conf.desiredPort, msg)
  1381  // 					}
  1382  // 				} else if strings.Contains(v.label, "startShell fail") {
  1383  // 					if wh.child != nil {
  1384  // 						t.Errorf("#test runWorker fail should return empty workhorse\n")
  1385  // 					}
  1386  // 					msg := <-exChan // get the done message
  1387  // 					if msg != v.conf.desiredPort+":shutdown" {
  1388  // 						t.Errorf("#test runWorker fail should return %s, got %s\n", v.conf.desiredPort, msg)
  1389  // 					}
  1390  // 				}
  1391  // 			}()
  1392  //
  1393  // 			// TODO disable it for the time being
  1394  // 			// if hasWorkhorse {
  1395  // 			// 	if err := runWorker(&v.conf, exChan, whChan); err != nil {
  1396  // 			// 		t.Errorf("#test runWorker should not report error.\n")
  1397  // 			// 	}
  1398  // 			// } else {
  1399  // 			// 	if err := runWorker(&v.conf, exChan, whChan); err == nil {
  1400  // 			// 		t.Errorf("#test runWorker should report error.\n")
  1401  // 			// 	}
  1402  // 			// }
  1403  //
  1404  // 			wg.Wait()
  1405  // 		})
  1406  // 	}
  1407  // }
  1408  
  1409  func TestRunCloseFail(t *testing.T) {
  1410  	tc := []struct {
  1411  		label  string
  1412  		pause  int      // pause between client send and read
  1413  		resp1  string   // response of start action
  1414  		resp2  string   // response of stop action
  1415  		exp    []string // ex parameter
  1416  		finish int      // pause before shutdown message
  1417  		conf   Config
  1418  	}{
  1419  		// {
  1420  		// 	"runWorker stopped by " + frontend.AprishMsgClose, 20, frontend.AprilshMsgOpen + "7111,", frontend.AprishMsgClose + "done",
  1421  		// 	[]string{},
  1422  		// 	150,
  1423  		// 	Config{
  1424  		// 		version: false, server: true, flowControl: _FC_SKIP_PIPE_LOCK, desiredIP: "", desiredPort: "7110",
  1425  		// 		locales:     localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"},
  1426  		// 		commandPath: "/bin/sh", commandArgv: []string{"-sh"}, withMotd: true,
  1427  		// 	},
  1428  		// },
  1429  		// {
  1430  		// 	"runWorker stop port not exist", 5, frontend.AprilshMsgOpen + "7121,", frontend.AprishMsgClose + "port does not exist",
  1431  		// 	[]string{"7100"},
  1432  		// 	150,
  1433  		// 	Config{
  1434  		// 		version: false, server: true, flowControl: _FC_SKIP_PIPE_LOCK, desiredIP: "", desiredPort: "7120",
  1435  		// 		locales:     localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"},
  1436  		// 		commandPath: "/bin/sh", commandArgv: []string{"-sh"}, withMotd: true,
  1437  		// 	},
  1438  		// },
  1439  		// {
  1440  		// 	"runWorker stop wrong port number", 5, frontend.AprilshMsgOpen + "7131,", frontend.AprishMsgClose + "wrong port number",
  1441  		// 	[]string{"7121x"},
  1442  		// 	150,
  1443  		// 	Config{
  1444  		// 		version: false, server: true, flowControl: _FC_SKIP_PIPE_LOCK, desiredIP: "", desiredPort: "7130",
  1445  		// 		locales:     localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"},
  1446  		// 		commandPath: "/bin/sh", commandArgv: []string{"-sh"}, withMotd: true,
  1447  		// 	},
  1448  		// },
  1449  		{
  1450  			"runWorker stop unknow request", 25, frontend.AprilshMsgOpen + "7141,", frontend.AprishMsgClose + "unknow request",
  1451  			[]string{"two", "params"},
  1452  			150,
  1453  			Config{
  1454  				version: false, server: true, flowControl: _FC_SKIP_PIPE_LOCK, desiredIP: "", desiredPort: "7140",
  1455  				locales:     localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"},
  1456  				commandPath: "/bin/sh", commandArgv: []string{"-sh"}, withMotd: true,
  1457  			},
  1458  		},
  1459  	}
  1460  
  1461  	if runtime.GOARCH == "s390x" {
  1462  		t.Skip("for s390x, skip this test.")
  1463  	}
  1464  	for _, v := range tc {
  1465  		t.Run(v.label, func(t *testing.T) {
  1466  
  1467  			util.Logger.CreateLogger(io.Discard, true, slog.LevelDebug)
  1468  			// util.Logger.CreateLogger(os.Stderr, true, slog.LevelDebug)
  1469  
  1470  			// set serve func and runWorker func
  1471  			v.conf.serve = mockServe
  1472  			// srv := newMainSrv(&v.conf, runWorker)
  1473  			srv := newMainSrv(&v.conf)
  1474  
  1475  			/// set commandPath and commandArgv based on environment
  1476  			v.conf.commandPath = os.Getenv("SHELL")
  1477  			v.conf.commandArgv = []string{getShellNameFrom(v.conf.commandPath)}
  1478  
  1479  			// send shutdown message after some time (finish ms)
  1480  			timer1 := time.NewTimer(time.Duration(v.finish) * time.Millisecond)
  1481  			go func() {
  1482  				<-timer1.C
  1483  				srv.downChan <- true
  1484  			}()
  1485  
  1486  			srv.start(&v.conf)
  1487  
  1488  			// start a new connection
  1489  			resp1 := mockClient(v.conf.desiredPort, v.pause, frontend.AprilshMsgOpen)
  1490  			if !strings.HasPrefix(resp1, v.resp1) {
  1491  				t.Errorf("#test run expect %q got %q\n", v.resp1, resp1)
  1492  			}
  1493  			// fmt.Printf("#test got response resp1=%s\n", resp1)
  1494  
  1495  			time.Sleep(10 * time.Millisecond)
  1496  
  1497  			// stop the new connection
  1498  			resp2 := mockClient(v.conf.desiredPort, v.pause, frontend.AprishMsgClose, v.exp...)
  1499  			if !strings.HasPrefix(resp2, v.resp2) {
  1500  				t.Errorf("#test run expect %q got %q\n", v.resp1, resp2)
  1501  			}
  1502  
  1503  			// fmt.Printf("#test got response resp2=%s\n", resp2)
  1504  			// stop the connection
  1505  			if len(v.exp) > 0 {
  1506  				expect := frontend.AprishMsgClose + "done"
  1507  				resp2 := mockClient(v.conf.desiredPort, v.pause, frontend.AprishMsgClose)
  1508  				if !strings.HasPrefix(resp2, expect) {
  1509  					t.Errorf("#test run stop the connection expect %q got %q\n", v.resp1, resp2)
  1510  				}
  1511  			}
  1512  
  1513  			// fmt.Printf("#test got stop response resp2=%s\n", resp2)
  1514  			srv.wait()
  1515  		})
  1516  	}
  1517  }
  1518  
  1519  func TestRunWith2Clients(t *testing.T) {
  1520  	tc := []struct {
  1521  		label  string
  1522  		pause  int      // pause between client send and read
  1523  		resp1  string   // response of start action
  1524  		resp2  string   // response of stop action
  1525  		resp3  string   // response of additinoal open request
  1526  		exp    []string // ex parameter
  1527  		finish int      // pause before shutdown message
  1528  		conf   Config
  1529  	}{
  1530  		{
  1531  			"open aprilsh with duplicate request", 20, frontend.AprilshMsgOpen + "7101,", frontend.AprishMsgClose + "done",
  1532  			frontend.AprilshMsgOpen + "7102", []string{}, 150,
  1533  			Config{
  1534  				version: false, server: true, flowControl: _FC_SKIP_PIPE_LOCK, desiredIP: "", desiredPort: "7100",
  1535  				locales:     localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"},
  1536  				commandPath: "/bin/sh", commandArgv: []string{"-sh"}, withMotd: true,
  1537  			},
  1538  		},
  1539  	}
  1540  
  1541  	for _, v := range tc {
  1542  		t.Run(v.label, func(t *testing.T) {
  1543  
  1544  			// intercept stdout
  1545  			util.Logger.CreateLogger(io.Discard, true, slog.LevelDebug)
  1546  
  1547  			// set serve func and runWorker func
  1548  			v.conf.serve = mockServe
  1549  			// srv := newMainSrv(&v.conf, runWorker)
  1550  			srv := newMainSrv(&v.conf)
  1551  
  1552  			/// set commandPath and commandArgv based on environment
  1553  			v.conf.commandPath = os.Getenv("SHELL")
  1554  			v.conf.commandArgv = []string{getShellNameFrom(v.conf.commandPath)}
  1555  
  1556  			srv.start(&v.conf)
  1557  
  1558  			// start a new connection
  1559  			resp1 := mockClient(v.conf.desiredPort, v.pause, frontend.AprilshMsgOpen)
  1560  			if !strings.HasPrefix(resp1, v.resp1) {
  1561  				t.Errorf("#test first client start expect %q got %q\n", v.resp1, resp1)
  1562  			}
  1563  			// fmt.Printf("#test got 1 response %q\n", resp1)
  1564  
  1565  			// start a new connection
  1566  			resp3 := mockClient(v.conf.desiredPort, v.pause, frontend.AprilshMsgOpen)
  1567  			if !strings.HasPrefix(resp3, v.resp3) {
  1568  				t.Errorf("#test second client start expect %q got %q\n", v.resp3, resp3)
  1569  			}
  1570  			// fmt.Printf("#test got 3 response %q\n", resp3)
  1571  
  1572  			// stop the new connection
  1573  			resp2 := mockClient(v.conf.desiredPort, v.pause, frontend.AprishMsgClose, v.exp...)
  1574  			if !strings.HasPrefix(resp2, v.resp2) {
  1575  				t.Errorf("#test firt client stop expect %q got %q\n", v.resp1, resp2)
  1576  			}
  1577  			// fmt.Printf("#test got 2 response %q\n", resp2)
  1578  
  1579  			// send shutdown message after some time (finish ms)
  1580  			timer1 := time.NewTimer(time.Duration(v.finish) * time.Millisecond)
  1581  			go func() {
  1582  				<-timer1.C
  1583  				srv.downChan <- true
  1584  			}()
  1585  
  1586  			srv.wait()
  1587  		})
  1588  	}
  1589  }
  1590  
  1591  func TestStartShellError(t *testing.T) {
  1592  	tc := []struct {
  1593  		label    string
  1594  		errStr   string
  1595  		pts      *os.File
  1596  		pr       *io.PipeReader
  1597  		utmpHost string
  1598  		conf     Config
  1599  	}{
  1600  		{"first error return", "fail to start shell", os.Stdout, nil, "",
  1601  			Config{flowControl: _FC_SKIP_START_SHELL},
  1602  		},
  1603  		{"IUTF8 error return", strENOTTY, os.Stdin, nil, "",
  1604  			Config{},
  1605  		}, // os.Stdin doesn't support IUTF8 flag, startShell should failed
  1606  	}
  1607  
  1608  	util.Logger.CreateLogger(io.Discard, true, slog.LevelDebug)
  1609  
  1610  	for _, v := range tc {
  1611  		t.Run(v.label, func(t *testing.T) {
  1612  			// open pty master and slave
  1613  			ptmx, pts, _ := pty.Open()
  1614  			if v.pts == nil {
  1615  				v.pts = pts
  1616  			}
  1617  
  1618  			// open pipe for parameter
  1619  			pr, pw := io.Pipe()
  1620  			if v.pr == nil {
  1621  				v.pr = pr
  1622  			}
  1623  
  1624  			_, err := startShellProcess(v.pts, v.pr, v.utmpHost, &v.conf)
  1625  			// fmt.Printf("%#v\n", err)
  1626  
  1627  			// validate error
  1628  			if !strings.Contains(err.Error(), v.errStr) {
  1629  				t.Errorf("%q should report %q, got %q\n", v.label, v.errStr, err)
  1630  			}
  1631  
  1632  			pr.Close()
  1633  			pw.Close()
  1634  			ptmx.Close()
  1635  			pts.Close()
  1636  		})
  1637  	}
  1638  }
  1639  
  1640  func TestOpenPTS(t *testing.T) {
  1641  
  1642  	tc := []struct {
  1643  		label  string
  1644  		ws     unix.Winsize
  1645  		errStr string
  1646  	}{
  1647  		{"invalid parameter error", unix.Winsize{}, "invalid parameter"},
  1648  		{"invalid parameter error", unix.Winsize{Row: 4, Col: 4}, ""},
  1649  	}
  1650  
  1651  	for i, v := range tc {
  1652  		t.Run(v.label, func(t *testing.T) {
  1653  			var ptmx, pts *os.File
  1654  			var err error
  1655  			if i == 0 {
  1656  				ptmx, pts, err = openPTS(nil)
  1657  			} else {
  1658  				ptmx, pts, err = openPTS(&v.ws)
  1659  			}
  1660  			defer ptmx.Close()
  1661  			defer pts.Close()
  1662  			if i == 0 {
  1663  				if !strings.Contains(err.Error(), v.errStr) {
  1664  					t.Errorf("%q should report %q, got %q\n", v.label, v.errStr, err)
  1665  					fmt.Printf("%#v\n", err)
  1666  				}
  1667  			} else {
  1668  				if err != nil {
  1669  					t.Errorf("%q expect no error, got %s\n", v.label, err)
  1670  				}
  1671  			}
  1672  		})
  1673  	}
  1674  }
  1675  
  1676  // func testGetCurrentUser(t *testing.T) {
  1677  // 	// normal invocation
  1678  // 	userCurrentTest = false
  1679  // 	uid := fmt.Sprintf("%d", os.Getuid())
  1680  // 	expect, _ := user.LookupId(uid)
  1681  //
  1682  // 	got := getCurrentUser()
  1683  // 	if len(got) == 0 || expect.Username != got {
  1684  // 		t.Errorf("#test getCurrentUser expect %s, got %s\n", expect.Username, got)
  1685  // 	}
  1686  //
  1687  // 	// getCurrentUser fail
  1688  // 	old := userCurrentTest
  1689  // 	defer func() {
  1690  // 		userCurrentTest = old
  1691  // 	}()
  1692  //
  1693  // 	// intercept log output
  1694  // 	var b strings.Builder
  1695  // 	util.Logger.CreateLogger(&b, true, slog.LevelDebug)
  1696  //
  1697  // 	userCurrentTest = true
  1698  // 	got = getCurrentUser()
  1699  // 	if got != "" {
  1700  // 		t.Errorf("#test getCurrentUser expect empty string, got %s\n", got)
  1701  // 	}
  1702  // 	// restore logW
  1703  // 	// logW = log.New(os.Stdout, "WARN: ", log.Ldate|log.Ltime|log.Lshortfile)
  1704  // }
  1705  
  1706  func TestGetAvailablePort(t *testing.T) {
  1707  	tc := []struct {
  1708  		label      string
  1709  		max        int // pre-condition before getAvailabePort
  1710  		expectPort int
  1711  		expectMax  int
  1712  		workers    map[int]*workhorse
  1713  	}{
  1714  		{
  1715  			"empty worker list", 6001, 6001, 6002,
  1716  			map[int]*workhorse{},
  1717  		},
  1718  		{
  1719  			"lart gap empty worker", 6008, 6001, 6002,
  1720  			map[int]*workhorse{},
  1721  		},
  1722  		{
  1723  			"add one port", 6002, 6002, 6003,
  1724  			map[int]*workhorse{6001: {}},
  1725  		},
  1726  		{
  1727  			"shrink max", 6013, 6002, 6003,
  1728  			map[int]*workhorse{6001: {}},
  1729  		},
  1730  		{
  1731  			"right most", 6004, 6004, 6005,
  1732  			map[int]*workhorse{6001: {}, 6002: {}, 6003: {}},
  1733  		},
  1734  		{
  1735  			"left most", 6006, 6001, 6006,
  1736  			map[int]*workhorse{6003: {}, 6004: {}, 6005: {}},
  1737  		},
  1738  		{
  1739  			"middle hole", 6009, 6004, 6009,
  1740  			map[int]*workhorse{6001: {}, 6002: {}, 6003: {}, 6008: {}},
  1741  		},
  1742  		{
  1743  			"border shape hole", 6019, 6002, 6019,
  1744  			map[int]*workhorse{6001: {}, 6018: {}},
  1745  		},
  1746  	}
  1747  
  1748  	conf := &Config{desiredPort: "6000"}
  1749  
  1750  	for _, v := range tc {
  1751  		t.Run(v.label, func(t *testing.T) {
  1752  			// intercept log output
  1753  			util.Logger.CreateLogger(io.Discard, true, slog.LevelDebug)
  1754  
  1755  			srv := newMainSrv(conf)
  1756  			srv.workers = v.workers
  1757  			srv.maxPort = v.max
  1758  
  1759  			got := srv.getAvailabePort()
  1760  
  1761  			if got != v.expectPort {
  1762  				t.Errorf("%q expect port=%d, got %d\n", v.label, v.expectPort, got)
  1763  			}
  1764  
  1765  			if srv.maxPort != v.expectMax {
  1766  				t.Errorf("%q expect maxPort=%d, got %d\n", v.label, v.expectMax, srv.maxPort)
  1767  			}
  1768  		})
  1769  	}
  1770  }
  1771  
  1772  // func TestIsPortExist(t *testing.T) {
  1773  // 	tc := []struct {
  1774  // 		label string
  1775  // 		port  int
  1776  // 		ret   bool
  1777  // 	}{
  1778  // 		{"port exist", 101, true},
  1779  // 		{"port does not exist", 10, false},
  1780  // 	}
  1781  //
  1782  // 	// prepare workers data
  1783  // 	conf := &Config{desiredPort: "6000"}
  1784  //
  1785  // 	srv := newMainSrv(conf, mockRunWorker)
  1786  // 	srv.workers[100] = &workhorse{nil, os.Stderr}
  1787  // 	srv.workers[101] = &workhorse{nil, os.Stdout}
  1788  // 	srv.workers[111] = &workhorse{nil, os.Stdin}
  1789  //
  1790  // 	for _, v := range tc {
  1791  // 		t.Run(v.label, func(t *testing.T) {
  1792  // 			got := srv.isPortExist(v.port)
  1793  // 			if got != v.ret {
  1794  // 				t.Errorf("%q port %d: expect %t, got %t\n", v.label, v.port, v.ret, got)
  1795  // 			}
  1796  //
  1797  // 		})
  1798  // 	}
  1799  // }
  1800  
  1801  func BenchmarkGetAvailablePort(b *testing.B) {
  1802  
  1803  	conf := &Config{desiredPort: "100"}
  1804  	srv := newMainSrv(conf)
  1805  	srv.workers[100] = &workhorse{}
  1806  	srv.workers[101] = &workhorse{}
  1807  	srv.workers[102] = &workhorse{}
  1808  
  1809  	srv.maxPort = 102
  1810  
  1811  	for i := 0; i < b.N; i++ {
  1812  		srv.getAvailabePort()
  1813  		srv.maxPort-- // hedge maxPort++ in getAvailabePort
  1814  	}
  1815  }
  1816  
  1817  func TestCheckPortAvailable(t *testing.T) {
  1818  	tc := []struct {
  1819  		label  string
  1820  		port   int
  1821  		expect bool
  1822  	}{
  1823  		{"wrong port number", -200, false},
  1824  		{"duplicate por number", 8022, false},
  1825  	}
  1826  
  1827  	cfg := &Config{desiredPort: "8022"}
  1828  	ms := newMainSrv(cfg)
  1829  	for _, v := range tc {
  1830  		t.Run(v.label, func(t *testing.T) {
  1831  			// take the port
  1832  			ms.listen(cfg)
  1833  
  1834  			// validate tc
  1835  			got := checkPortAvailable(v.port)
  1836  			if got != v.expect {
  1837  				t.Errorf("%s expect %t, got %t\n", v.label, v.expect, got)
  1838  			}
  1839  			// clear port
  1840  			ms.conn.Close()
  1841  		})
  1842  	}
  1843  }
  1844  
  1845  func TestHandleMessage(t *testing.T) {
  1846  
  1847  	tc := []struct {
  1848  		label   string
  1849  		content string
  1850  		reason  string
  1851  	}{
  1852  		{"no colon", "no colon", "lack of ':'"},
  1853  		{"no comma", "no:comma", "lack of ','"},
  1854  		{"wrong port number", "no:comma,x", "invalid port number"},
  1855  		{"non-existence port number", "no:6000,x", "non-existence port number"},
  1856  		{"invalid serve shutdown", _ServeHeader + ":8100,not shutdown", "invalid shutdown"},
  1857  		{"kill shell process failed", _ServeHeader + ":8100,shutdown", "kill shell process failed"},
  1858  		{"invalid run shutdown", _RunHeader + ":8100,not shutdown", "invalid shutdown"},
  1859  		{"invalid shell pid", _ShellHeader + ":8100,x", "invalid shell pid"},
  1860  		{"unknown header", "unknow:8100,x", "unknown header"},
  1861  	}
  1862  
  1863  	cfg := &Config{desiredPort: "8022"}
  1864  	ms := newMainSrv(cfg)
  1865  	ms.workers[8100] = &workhorse{shellPid: 0}
  1866  	// ms.workers[8110] = &workhorse{shellPid: os.Getpid()}
  1867  
  1868  	for _, v := range tc {
  1869  		t.Run(v.label, func(t *testing.T) {
  1870  			_, err := ms.handleMessage(v.content)
  1871  			var messagError *messageError
  1872  
  1873  			if errors.As(err, &messagError) {
  1874  				if messagError.reason != v.reason {
  1875  					t.Errorf("%s expect %q, got %q\n", v.label, v.reason, messagError.reason)
  1876  					// } else {
  1877  					// 	t.Logf("go error %#v\n", messagError.err)
  1878  				}
  1879  			} else {
  1880  				t.Errorf("%s expect %v, got %v\n", v.label, messagError, err)
  1881  			}
  1882  		})
  1883  	}
  1884  }
  1885  
  1886  func TestBeginChild(t *testing.T) {
  1887  	tc := []struct {
  1888  		label      string
  1889  		pause      int    // pause between client send and read
  1890  		resp       string // response	for beginClientConn().
  1891  		shutdown   int    // pause before shutdown message
  1892  		clientConf Config
  1893  		conf       Config
  1894  	}{
  1895  		{
  1896  			"normal beginClientConn", 100, frontend.AprilshMsgOpen + "7101,", 150,
  1897  			Config{desiredPort: "7100", term: "xterm-256color", destination: getCurrentUser() + "@localhost"},
  1898  			Config{
  1899  				version: false, server: false, desiredIP: "", desiredPort: "7100",
  1900  				locales:     localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"},
  1901  				commandPath: "/bin/sh", commandArgv: []string{"/bin/sh"}, withMotd: false,
  1902  				// addSource: false, verbose: util.TraceLevel,
  1903  			},
  1904  		},
  1905  	}
  1906  
  1907  	for _, v := range tc {
  1908  		t.Run(v.label, func(t *testing.T) {
  1909  			util.Logger.CreateLogger(io.Discard, true, slog.LevelDebug)
  1910  			// util.Logger.CreateLogger(os.Stderr, true, slog.LevelDebug)
  1911  
  1912  			srv := newMainSrv(&v.conf)
  1913  			// send shutdown message after some time
  1914  			timer1 := time.NewTimer(time.Duration(v.shutdown) * time.Millisecond)
  1915  			go func() {
  1916  				<-timer1.C
  1917  				// prepare to shudown the mainSrv
  1918  				// syscall.Kill(syscall.Getpid(), syscall.SIGTERM)
  1919  				srv.downChan <- true
  1920  			}()
  1921  
  1922  			srv.start(&v.conf)
  1923  
  1924  			// intercept stdout
  1925  			saveStdout := os.Stdout
  1926  			r, w, _ := os.Pipe()
  1927  			os.Stdout = w
  1928  
  1929  			beginChild(&v.clientConf)
  1930  
  1931  			// restore stdout
  1932  			w.Close()
  1933  			output, _ := io.ReadAll(r)
  1934  			os.Stdout = saveStdout
  1935  			r.Close()
  1936  
  1937  			// validate the result.
  1938  			resp := strings.TrimSpace(string(output))
  1939  			// fmt.Printf("output from beginChild= %q\n", resp)
  1940  			if !strings.HasPrefix(resp, v.resp) {
  1941  				t.Errorf("#test beginChild expect start with %q got %q\n", v.resp, resp)
  1942  			}
  1943  			srv.wait()
  1944  		})
  1945  	}
  1946  }
  1947  
  1948  func TestMainBeginChild(t *testing.T) {
  1949  	tc := []struct {
  1950  		label    string
  1951  		resp     string // response for beginChild().
  1952  		shutdown int    // pause before shutdown message
  1953  		args     []string
  1954  		conf     Config
  1955  	}{
  1956  		{
  1957  			"main begin child", frontend.AprilshMsgOpen + "7151,", 150,
  1958  			[]string{"/usr/bin/apshd", "-b", "-destination", getCurrentUser() + "@localhost",
  1959  				"-p", "7150", "-t", "xterm-256color", "-vv"},
  1960  			Config{
  1961  				desiredIP: "", desiredPort: "7150", // autoStop: 1,
  1962  				locales:     localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"},
  1963  				commandPath: "/bin/sh", commandArgv: []string{"/bin/sh"}, withMotd: false,
  1964  				// addSource: false,  verbose: util.TraceLevel,
  1965  			},
  1966  		},
  1967  	}
  1968  
  1969  	for _, v := range tc {
  1970  		t.Run(v.label, func(t *testing.T) {
  1971  			r, w, _ := os.Pipe()
  1972  			// save stdout
  1973  			oldStdout := os.Stdout
  1974  
  1975  			// util.Logger.CreateLogger(os.Stderr, true, slog.LevelDebug)
  1976  			util.Logger.CreateLogger(io.Discard, true, slog.LevelDebug)
  1977  
  1978  			srv := newMainSrv(&v.conf)
  1979  			srv.start(&v.conf)
  1980  
  1981  			// send shutdown message after some time
  1982  			timer1 := time.NewTimer(time.Duration(v.shutdown) * time.Millisecond)
  1983  			go func() {
  1984  				<-timer1.C
  1985  				// prepare to shudown the mainSrv
  1986  				srv.downChan <- true
  1987  			}()
  1988  
  1989  			testFunc := func() {
  1990  				os.Args = v.args
  1991  				os.Stdout = w
  1992  				main()
  1993  
  1994  				// restore stdout
  1995  				os.Stdout = oldStdout
  1996  			}
  1997  
  1998  			testFunc()
  1999  			srv.wait()
  2000  
  2001  			// close pipe writer, get the output
  2002  			w.Close()
  2003  			output, _ := io.ReadAll(r)
  2004  			r.Close()
  2005  
  2006  			// validate the result.
  2007  			resp := string(output)
  2008  			if !strings.Contains(resp, v.resp) {
  2009  				t.Errorf("%q expect start with %q got \n%s\n", v.label, v.resp, resp)
  2010  			}
  2011  		})
  2012  	}
  2013  }
  2014  
  2015  // https://coralogix.com/blog/optimizing-a-golang-service-to-reduce-over-40-cpu/
  2016  func TestRunChild(t *testing.T) {
  2017  	portStr := "7200"
  2018  	port, _ := strconv.Atoi(portStr)
  2019  	serverPortStr := "7100"
  2020  
  2021  	tc := []struct {
  2022  		label     string
  2023  		shutdown  int    // pause before shutdown message
  2024  		conf      Config // config for mainSrv
  2025  		childConf Config // config for child
  2026  	}{
  2027  		{
  2028  			"early shutdown", 100,
  2029  			Config{
  2030  				desiredIP: "", desiredPort: serverPortStr,
  2031  				locales:     localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"},
  2032  				commandPath: "/bin/sh", commandArgv: []string{"/bin/sh"}, withMotd: false,
  2033  				addSource: true, verbose: util.DebugLevel,
  2034  			},
  2035  			Config{desiredPort: portStr, term: "xterm", destination: getCurrentUser() + "@localhost",
  2036  				serve: serve, verbose: 0, addSource: false},
  2037  		},
  2038  		{
  2039  			"skip pipe lock", 100,
  2040  			Config{
  2041  				desiredIP: "", desiredPort: serverPortStr,
  2042  				locales:     localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"},
  2043  				commandPath: "/bin/sh", commandArgv: []string{"/bin/sh"}, withMotd: false,
  2044  				addSource: true, verbose: util.DebugLevel,
  2045  			},
  2046  			Config{desiredPort: portStr, destination: getCurrentUser() + "@localhost",
  2047  				commandPath: "/bin/sh", commandArgv: []string{"/bin/sh"}, withMotd: true,
  2048  				flowControl: _FC_SKIP_PIPE_LOCK, serve: serve, verbose: 0, addSource: false},
  2049  		},
  2050  		{
  2051  			"skip start shell", 100,
  2052  			Config{
  2053  				desiredIP: "", desiredPort: serverPortStr,
  2054  				locales:     localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"},
  2055  				commandPath: "/bin/sh", commandArgv: []string{"/bin/sh"}, withMotd: false,
  2056  				addSource: true, verbose: util.DebugLevel,
  2057  			},
  2058  			Config{desiredPort: portStr, destination: getCurrentUser() + "@localhost",
  2059  				commandPath: "/bin/sh", commandArgv: []string{"/bin/sh"}, withMotd: false,
  2060  				flowControl: _FC_SKIP_START_SHELL, serve: serve, verbose: 0, addSource: false},
  2061  		},
  2062  		{
  2063  			"open pts failed", 100,
  2064  			Config{
  2065  				desiredIP: "", desiredPort: serverPortStr,
  2066  				locales:     localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"},
  2067  				commandPath: "/bin/sh", commandArgv: []string{"/bin/sh"}, withMotd: false,
  2068  				addSource: true, verbose: util.DebugLevel,
  2069  			},
  2070  			Config{desiredPort: portStr, term: "xterm", destination: getCurrentUser() + "@localhost",
  2071  				commandPath: "/bin/sh", commandArgv: []string{"/bin/sh"}, withMotd: false,
  2072  				flowControl: _FC_OPEN_PTS_FAIL, serve: serve, verbose: 0, addSource: false},
  2073  		},
  2074  	}
  2075  
  2076  	for _, v := range tc {
  2077  		t.Run(v.label, func(t *testing.T) {
  2078  			util.Logger.CreateLogger(io.Discard, true, slog.LevelDebug)
  2079  			// util.Logger.CreateLogger(os.Stderr, true, slog.LevelDebug)
  2080  
  2081  			srv := newMainSrv(&v.conf)
  2082  
  2083  			// listen UDS
  2084  			uxConn, err := srv.uxListen()
  2085  			if err != nil {
  2086  				util.Logger.Warn("listen unix domain socket failed", "error", err)
  2087  				return
  2088  			}
  2089  
  2090  			// receive UDS feed
  2091  			srv.wg.Add(1)
  2092  			go func() {
  2093  				srv.uxServe(uxConn, 2, func(c chan string, resp string) {
  2094  					ret, err := srv.handleMessage(resp)
  2095  					if err != nil {
  2096  						util.Logger.Warn("fake uxServe failed", "error", err)
  2097  						return
  2098  					}
  2099  
  2100  					if ret != "" {
  2101  						util.Logger.Debug("fake uxServe got key", "key", ret)
  2102  						return
  2103  					}
  2104  
  2105  					// stop uxServe if the worker is done
  2106  					if resp == _RunHeader+":"+portStr+",shutdown" {
  2107  						srv.uxdownChan <- true
  2108  					}
  2109  
  2110  					// stop shell process once we got shell pid
  2111  					if strings.HasPrefix(resp, _ShellHeader+":"+portStr) {
  2112  						if srv.workers[port].shellPid > 0 {
  2113  							util.Logger.Debug("fake uxServe kill the shell", "shellPid", srv.workers[port].shellPid)
  2114  							shell, err := os.FindProcess(srv.workers[port].shellPid)
  2115  							if err = shell.Kill(); err != nil {
  2116  								util.Logger.Debug("fake uxServe", "error", err)
  2117  							}
  2118  						}
  2119  					}
  2120  				})
  2121  				srv.wg.Done()
  2122  			}()
  2123  
  2124  			// start runChild
  2125  			srv.wg.Add(1)
  2126  			go func() {
  2127  				// add this worker
  2128  				srv.workers[port] = &workhorse{}
  2129  				runChild(&v.childConf)
  2130  				srv.wg.Done()
  2131  			}()
  2132  
  2133  			if strings.Contains(v.label, "shutdown") {
  2134  				// send shutdown message after some time
  2135  				timer1 := time.NewTimer(time.Duration(v.shutdown) * time.Millisecond)
  2136  				go func() {
  2137  					<-timer1.C
  2138  					// prepare to shudown the mainSrv
  2139  					syscall.Kill(syscall.Getpid(), syscall.SIGTERM)
  2140  					srv.uxdownChan <- true
  2141  				}()
  2142  			}
  2143  
  2144  			// validate if we can quit this test
  2145  			srv.wait()
  2146  		})
  2147  	}
  2148  }
  2149  
  2150  func TestRunFail(t *testing.T) {
  2151  	m := mainSrv{}
  2152  	cfg := &Config{}
  2153  	m.run(cfg)
  2154  	// run return if m.conn is nil
  2155  }
  2156  
  2157  func TestUxListenFail(t *testing.T) {
  2158  	old := unixsockAddr
  2159  	defer func() {
  2160  		unixsockAddr = old
  2161  	}()
  2162  
  2163  	unixsockAddr = "/etc/hosts"
  2164  	m := mainSrv{}
  2165  	_, err := m.uxListen()
  2166  	if err == nil {
  2167  		t.Errorf("uxListen expect error got nil\n")
  2168  	}
  2169  }
  2170  
  2171  func TestRunChildFail(t *testing.T) {
  2172  	old := unixsockAddr
  2173  	defer func() {
  2174  		unixsockAddr = old
  2175  	}()
  2176  
  2177  	unixsockAddr = "/etc/hosts"
  2178  	err := runChild(&Config{})
  2179  	if err == nil {
  2180  		t.Errorf("uxListen expect error got nil\n")
  2181  	}
  2182  }
  2183  
  2184  func TestMainRunChildFail(t *testing.T) {
  2185  	old := unixsockAddr
  2186  	defer func() {
  2187  		unixsockAddr = old
  2188  	}()
  2189  
  2190  	args := []string{"/usr/bin/apshd", "-c", "-p", "6160", "-vv"}
  2191  
  2192  	r, w, _ := os.Pipe()
  2193  	// save stdout
  2194  	oldStderr := os.Stderr
  2195  
  2196  	// error condition
  2197  	unixsockAddr = "/etc/hosts"
  2198  
  2199  	// run the test
  2200  	testFunc := func() {
  2201  		os.Args = args
  2202  		os.Stderr = w
  2203  		main()
  2204  
  2205  		// restore stdout
  2206  		os.Stderr = oldStderr
  2207  	}
  2208  	testFunc()
  2209  
  2210  	// close pipe writer, get the output
  2211  	w.Close()
  2212  	output, _ := io.ReadAll(r)
  2213  	r.Close()
  2214  
  2215  	// validate the result
  2216  	got := string(output)
  2217  	expect := "init uds client failed"
  2218  	if !strings.Contains(got, expect) {
  2219  		t.Errorf("runChild expect %q got %q\n", expect, got)
  2220  	}
  2221  }
  2222  
  2223  func TestStartFail2(t *testing.T) {
  2224  
  2225  	// intercept log
  2226  	var w strings.Builder
  2227  	util.Logger.CreateLogger(&w, true, slog.LevelDebug)
  2228  
  2229  	cfg := &Config{desiredPort: "7230"}
  2230  	m := mainSrv{}
  2231  
  2232  	// this will cause  uxListen failed
  2233  	old := unixsockAddr
  2234  	defer func() {
  2235  		unixsockAddr = old
  2236  	}()
  2237  
  2238  	// change unixsocke to error file
  2239  	unixsockAddr = "/etc/hosts"
  2240  	m.start(cfg)
  2241  	// close udp connection
  2242  	m.conn.Close()
  2243  
  2244  	//check the log
  2245  	got := w.String()
  2246  	expect := "listen unix domain socket failed"
  2247  	if !strings.Contains(got, expect) {
  2248  		t.Errorf("mainSrv.start() expect %q, got \n%s\n", expect, got)
  2249  	}
  2250  }
  2251  
  2252  func TestStartChildFail(t *testing.T) {
  2253  	tc := []struct {
  2254  		label  string
  2255  		req    string
  2256  		conf   Config
  2257  		expect string
  2258  	}{
  2259  		{"destination without @", "a:b,cd",
  2260  			Config{desiredPort: "6510"}, "open aprilsh:malform destination"},
  2261  		{"startShellProcess failed: DebugLevel", "open aprilsh:xterm-fake," + getCurrentUser() + "@fakehost",
  2262  			Config{desiredPort: "6511", verbose: util.DebugLevel},
  2263  			"start child got key timeout"},
  2264  		{"startShellProcess failed: TraceLevel", "open aprilsh:xterm-fake," + getCurrentUser() + "@fakehost",
  2265  			Config{desiredPort: "6512", verbose: util.TraceLevel},
  2266  			"start child got key timeout"},
  2267  		{"startShellProcess failed: addSource", "open aprilsh:xterm-fake," + getCurrentUser() + "@fakehost",
  2268  			Config{desiredPort: "6513", addSource: true},
  2269  			"start child got key timeout"},
  2270  	}
  2271  
  2272  	for _, v := range tc {
  2273  		t.Run(v.label, func(t *testing.T) {
  2274  			// prepare the server
  2275  			m := newMainSrv(&v.conf)
  2276  			m.timeout = 10
  2277  			m.listen(&v.conf)
  2278  
  2279  			var wg sync.WaitGroup
  2280  
  2281  			var out strings.Builder
  2282  			// util.Logger.CreateLogger(os.Stderr, true, slog.LevelDebug)
  2283  			util.Logger.CreateLogger(&out, true, slog.LevelDebug)
  2284  
  2285  			// reading and validate the message
  2286  			wg.Add(1)
  2287  			go func() {
  2288  				defer wg.Done()
  2289  
  2290  				buf := make([]byte, 128)
  2291  				shutdown := false
  2292  				for {
  2293  					select {
  2294  					case <-m.downChan:
  2295  						shutdown = true
  2296  					default:
  2297  					}
  2298  					if shutdown {
  2299  						util.Logger.Debug("fake receiver shudown")
  2300  						break
  2301  					}
  2302  
  2303  					m.conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(m.timeout)))
  2304  					m.conn.ReadFromUDP(buf)
  2305  				}
  2306  			}()
  2307  
  2308  			// run startChild
  2309  			addr, err := net.ResolveUDPAddr("udp", "localhost:"+v.conf.desiredPort)
  2310  			if err != nil {
  2311  				t.Errorf("startChild failed")
  2312  			} else {
  2313  				old := os.Getenv("SHELL")
  2314  				os.Setenv("SHELL", "")
  2315  				m.startChild(v.req, addr, v.conf)
  2316  				os.Setenv("SHELL", old)
  2317  			}
  2318  
  2319  			// shudown reader
  2320  			m.downChan <- true
  2321  			wg.Wait()
  2322  			m.conn.Close()
  2323  
  2324  			// validate the result
  2325  			got := out.String()
  2326  			if !strings.Contains(got, v.expect) {
  2327  				t.Errorf("startChild expect %q, got \n%s\n", v.expect, got)
  2328  			}
  2329  		})
  2330  	}
  2331  }
  2332  
  2333  func TestBuildConfig2(t *testing.T) {
  2334  	cfg := &Config{flowControl: _FC_NON_UTF8_LOCALE}
  2335  
  2336  	r, w, _ := os.Pipe()
  2337  	// save stdout
  2338  	olderr := os.Stderr
  2339  	oldout := os.Stdout
  2340  	os.Stderr = w
  2341  	os.Stdout = w
  2342  
  2343  	_, ok := cfg.buildConfig()
  2344  
  2345  	// close pipe writer, get the output
  2346  	w.Close()
  2347  	output, _ := io.ReadAll(r)
  2348  	r.Close()
  2349  
  2350  	os.Stderr = olderr
  2351  	os.Stdout = oldout
  2352  
  2353  	// validate the result
  2354  	got := string(output)
  2355  	expect := "needs a UTF-8 native locale to run"
  2356  	if !ok && strings.Contains(got, expect) {
  2357  	} else {
  2358  		t.Errorf("runChild expect %q got \n%s\n", expect, got)
  2359  	}
  2360  }
  2361  
  2362  func TestMessageError(t *testing.T) {
  2363  	tc := []struct {
  2364  		label  string
  2365  		e      *messageError
  2366  		expect string
  2367  	}{
  2368  		{"nil error", &messageError{}, "<nil>"},
  2369  		{"reason + error", &messageError{reason: "got apple", err: errors.New("bad apple")}, "got apple: bad apple"},
  2370  		{"only error", &messageError{err: errors.New("just apple")}, ": just apple"},
  2371  	}
  2372  
  2373  	for _, v := range tc {
  2374  		t.Run(v.label, func(t *testing.T) {
  2375  			got := v.e.Error()
  2376  			if got != v.expect {
  2377  				t.Errorf("messageError sould return %q got %q\n", v.expect, got)
  2378  			}
  2379  		})
  2380  	}
  2381  }
  2382  
  2383  func TestCloseChild(t *testing.T) {
  2384  	tc := []struct {
  2385  		label   string
  2386  		req     string
  2387  		holders []int
  2388  		conf    *Config
  2389  		expect  string
  2390  	}{
  2391  		{"placeHolder port", frontend.AprishMsgClose + "6252", []int{6252},
  2392  			&Config{desiredPort: "6250"}, "close port is a holder"},
  2393  		{"wrong port number", frontend.AprishMsgClose + "625a", nil,
  2394  			&Config{desiredPort: "6250"}, "wrong port number"},
  2395  		{"port doesn't exist", frontend.AprishMsgClose + "6252", nil,
  2396  			&Config{desiredPort: "6250"}, "port does not exist"},
  2397  	}
  2398  
  2399  	for _, v := range tc {
  2400  		t.Run(v.label, func(t *testing.T) {
  2401  			// prepare the server
  2402  			m := newMainSrv(v.conf)
  2403  			m.listen(v.conf)
  2404  
  2405  			var wg sync.WaitGroup
  2406  
  2407  			var out strings.Builder
  2408  			// util.Logger.CreateLogger(os.Stderr, true, slog.LevelDebug)
  2409  			util.Logger.CreateLogger(&out, true, slog.LevelDebug)
  2410  
  2411  			// create place holders data
  2412  			for _, value := range v.holders {
  2413  				m.workers[value] = &workhorse{}
  2414  			}
  2415  			// reading the udp response
  2416  			wg.Add(1)
  2417  			go func() {
  2418  				defer wg.Done()
  2419  
  2420  				buf := make([]byte, 128)
  2421  				shutdown := false
  2422  				for {
  2423  					select { // waiting for shutdown
  2424  					case <-m.downChan:
  2425  						shutdown = true
  2426  					default:
  2427  					}
  2428  					if shutdown {
  2429  						util.Logger.Debug("fake receiver shudown")
  2430  						break
  2431  					}
  2432  
  2433  					m.conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(m.timeout)))
  2434  					m.conn.ReadFromUDP(buf)
  2435  				}
  2436  			}()
  2437  
  2438  			// run closeChild
  2439  			addr, err := net.ResolveUDPAddr("udp", "localhost:"+v.conf.desiredPort)
  2440  			if err != nil {
  2441  				t.Errorf("get address fail: %s\n", err)
  2442  			} else {
  2443  				m.closeChild(v.req, addr)
  2444  			}
  2445  
  2446  			// shudown reader
  2447  			m.downChan <- true
  2448  			wg.Wait()
  2449  			m.conn.Close()
  2450  
  2451  			// validate the result
  2452  			got := out.String()
  2453  			// fmt.Println(got)
  2454  			if !strings.Contains(got, v.expect) {
  2455  				t.Errorf("startChild expect %q, got \n%s\n", v.expect, got)
  2456  			}
  2457  		})
  2458  	}
  2459  }