github.com/ericwq/aprilsh@v0.0.0-20240517091432-958bc568daa0/frontend/server/server_test.go (about) 1 // Copyright 2022~2024 wangqi. All rights reserved. 2 // Use of this source code is governed by a MIT-style 3 // license that can be found in the LICENSE file. 4 5 package main 6 7 import ( 8 "bytes" 9 "errors" 10 "flag" 11 "fmt" 12 "io" 13 "net" 14 "os" 15 "os/exec" 16 "reflect" 17 "runtime" 18 "strconv" 19 "strings" 20 "sync" 21 "syscall" 22 "testing" 23 "time" 24 25 "log/slog" 26 27 "github.com/creack/pty" 28 "github.com/ericwq/aprilsh/frontend" 29 "github.com/ericwq/aprilsh/network" 30 "github.com/ericwq/aprilsh/statesync" 31 "github.com/ericwq/aprilsh/util" 32 "golang.org/x/sys/unix" 33 ) 34 35 func TestPrintMotd(t *testing.T) { 36 // darwin doesn't has the following motd files, so we add /etc/hosts for testing. 37 files := []string{"/run/motd.dynamic", "/var/run/motd.dynamic", "/etc/motd", "/etc/hosts"} 38 39 var output bytes.Buffer 40 41 found := false 42 for i := range files { 43 output.Reset() 44 if printMotd(&output, files[i]) { 45 if output.Len() > 0 { // we got and print the file content 46 found = true 47 break 48 } 49 } 50 } 51 52 // validate the result 53 if !found { 54 t.Errorf("#test expect found %s, found nothing\n", files) 55 } 56 57 output.Reset() 58 59 // creat a .hide file and write long token into it 60 fName := ".hide" 61 hide, _ := os.Create(fName) 62 for i := 0; i < 1025; i++ { 63 data := bytes.Repeat([]byte{'s'}, 64) 64 hide.Write(data) 65 } 66 hide.Close() 67 68 if printMotd(&output, fName) { 69 t.Errorf("#test printMotd should return false, instead it return true.") 70 } 71 72 os.Remove(fName) 73 } 74 75 func TestPrintVersion(t *testing.T) { 76 // intercept stdout 77 saveStdout := os.Stdout 78 r, w, _ := os.Pipe() 79 os.Stdout = w 80 // initLog() 81 82 expect := []string{frontend.CommandServerName, "version", "git commit", "wangqi <ericwq057@qq.com>"} 83 84 printVersion() 85 86 // restore stdout 87 w.Close() 88 b, _ := io.ReadAll(r) 89 
os.Stdout = saveStdout 90 r.Close() 91 92 // validate the result 93 result := string(b) 94 found := 0 95 for i := range expect { 96 if strings.Contains(result, expect[i]) { 97 found++ 98 } 99 } 100 if found != len(expect) { 101 t.Errorf("#test printVersion expect %q, got %q\n", expect, result) 102 } 103 } 104 105 var cmdOptions = "[-s] [-v[v]] [-i LOCALADDR] [-p PORT[:PORT2]] [-l NAME=VALUE] [-- command...]" 106 107 func TestPrintUsage(t *testing.T) { 108 tc := []struct { 109 label string 110 hints string 111 expect []string 112 }{ 113 {"no hint", "", []string{"Usage:", frontend.CommandServerName, cmdOptions}}, 114 {"some hints", "some hints", []string{"Usage:", frontend.CommandServerName, "some hints", cmdOptions}}, 115 } 116 117 for _, v := range tc { 118 t.Run(v.label, func(t *testing.T) { 119 120 out := captureOutputRun(func() { 121 frontend.PrintUsage(v.hints, usage) 122 }) 123 124 // validate the result 125 result := string(out) 126 found := 0 127 for i := range v.expect { 128 if strings.Contains(result, v.expect[i]) { 129 found++ 130 } 131 } 132 if found != len(v.expect) { 133 t.Errorf("#test printUsage expect %s, got %s\n", v.expect, result) 134 } 135 }) 136 } 137 } 138 139 func TestChdirHomedir(t *testing.T) { 140 // save the current dir 141 oldPwd := os.Getenv("PWD") 142 143 // use the HOME 144 got := "" 145 if !chdirHomedir("") { 146 got = os.Getenv("PWD") 147 t.Errorf("#test chdirHomedir expect change to home directory, got %s\n", got) 148 } 149 150 // validate the PWD 151 got = os.Getenv("PWD") 152 // fmt.Printf("#test chdirHomedir home=%q\n", got) 153 if got == oldPwd { 154 t.Errorf("#test chdirHomedir home dir %q, is different from old dir %q\n", got, oldPwd) 155 } 156 157 // unset HOME 158 os.Unsetenv("HOME") 159 // validate the false 160 if chdirHomedir("") { 161 t.Errorf("#test chdirHomedir return false.\n") 162 } 163 164 // use the parameter as HOME 165 if chdirHomedir("/does/not/exist") { 166 t.Errorf("#test chdirHomedir should return false\n") 
167 } 168 169 // restore the current dir and PWD 170 os.Chdir(oldPwd) 171 os.Setenv("PWD", oldPwd) 172 } 173 174 func TestGetHomeDir(t *testing.T) { 175 tc := []struct { 176 label string 177 env string 178 expect string 179 }{ 180 {"normal case", "/home/aprish", "/home/aprish"}, 181 {"no HOME case", "", ""}, // for unix anc macOS, no HOME means getHomeDir() return "" 182 } 183 184 for _, v := range tc { 185 oldHome := os.Getenv("HOME") 186 if v.env == "" { // unset HOME env 187 os.Unsetenv("HOME") 188 } else { 189 os.Setenv("HOME", v.env) 190 } 191 got := getHomeDir() 192 193 if got != v.expect { 194 t.Errorf("%s getHomeDir() expect %q got %q\n", v.label, v.expect, got) 195 } 196 os.Setenv("HOME", oldHome) 197 } 198 } 199 200 func TestMotdHushed(t *testing.T) { 201 label := "#test motdHushed " 202 if motdHushed() != false { 203 t.Errorf("%s should report false, got %t\n", label, motdHushed()) 204 } 205 206 cmd := exec.Command("touch", ".hushlogin") 207 if err := cmd.Run(); err != nil { 208 t.Errorf("%s create .hushlogin failed, %s\n", label, err) 209 } 210 if motdHushed() != true { 211 t.Errorf("%s should report true, got %t\n", label, motdHushed()) 212 } 213 214 cmd = exec.Command("rm", ".hushlogin") 215 if err := cmd.Run(); err != nil { 216 t.Errorf("%s delete .hushlogin failed, %s\n", label, err) 217 } 218 } 219 220 func TestMainHelp(t *testing.T) { 221 testHelpFunc := func() { 222 // prepare data 223 os.Args = []string{frontend.CommandServerName, "--help"} 224 // test help 225 main() 226 } 227 228 out := captureOutputRun(testHelpFunc) 229 230 // validate result 231 expect := []string{"Usage:", frontend.CommandServerName, cmdOptions} 232 233 // validate the result 234 result := string(out) 235 found := 0 236 for i := range expect { 237 if strings.Contains(result, expect[i]) { 238 found++ 239 } 240 } 241 if found != len(expect) { 242 t.Errorf("#test printUsage expect %q, got %q\n", expect, result) 243 } 244 } 245 246 // capture the stdout and run the 247 func 
captureOutputRun(f func()) []byte { 248 // save the stdout,stderr and create replaced pipe 249 stderr := os.Stderr 250 stdout := os.Stdout 251 r, w, _ := os.Pipe() 252 // replace stdout,stderr with pipe writer 253 // alll the output to stdout,stderr is captured 254 os.Stderr = w 255 os.Stdout = w 256 257 util.Logger.CreateLogger(w, true, slog.LevelDebug) 258 259 // os.Args is a "global variable", so keep the state from before the test, and restore it after. 260 oldArgs := os.Args 261 defer func() { os.Args = oldArgs }() 262 263 f() 264 265 // close pipe writer 266 w.Close() 267 // get the output 268 out, _ := io.ReadAll(r) 269 os.Stderr = stderr 270 os.Stdout = stdout 271 r.Close() 272 273 return out 274 } 275 276 func TestMainVersion(t *testing.T) { 277 278 testHelpFunc := func() { 279 // prepare data 280 os.Args = []string{frontend.CommandServerName, "--version"} 281 // test 282 main() 283 284 } 285 286 out := captureOutputRun(testHelpFunc) 287 288 // validate result 289 expect := []string{frontend.CommandServerName, "go version", "git commit", "wangqi <ericwq057@qq.com>", 290 "remote shell support intermittent or mobile network."} 291 result := string(out) 292 found := 0 293 for i := range expect { 294 if strings.Contains(result, expect[i]) { 295 found++ 296 } 297 } 298 if found != len(expect) { 299 t.Errorf("#test printVersion expect %q, got %q\n", expect, result) 300 } 301 } 302 303 func TestMainParseFlagsError(t *testing.T) { 304 testFunc := func() { 305 // prepare data 306 os.Args = []string{frontend.CommandServerName, "--foo"} 307 // test 308 main() 309 } 310 311 out := captureOutputRun(testFunc) 312 313 // validate result 314 expect := []string{"flag provided but not defined: -foo"} 315 found := 0 316 for i := range expect { 317 if strings.Contains(string(out), expect[i]) { 318 found++ 319 } 320 } 321 if found != len(expect) { 322 t.Errorf("#test parserError expect %q, got \n%s\n", expect, out) 323 } 324 } 325 326 func TestParseFlagsUsage(t *testing.T) { 
327 usageArgs := []string{"-help", "-h", "--help"} 328 329 for _, arg := range usageArgs { 330 t.Run(arg, func(t *testing.T) { 331 conf, output, err := parseFlags("prog", []string{arg}) 332 if err != flag.ErrHelp { 333 t.Errorf("err got %v, want ErrHelp", err) 334 } 335 if conf != nil { 336 t.Errorf("conf got %v, want nil", conf) 337 } 338 if strings.Index(output, "Usage of") < 0 { 339 t.Errorf("output can't find \"Usage of\": %q", output) 340 } 341 }) 342 } 343 } 344 345 func TestMainRun(t *testing.T) { 346 tc := []struct { 347 label string 348 args []string 349 expect []string 350 }{ 351 {"run main and killed by signal", 352 []string{frontend.CommandServerName, "-locale", 353 "LC_ALL=en_US.UTF-8", "-p", "6100", "--", "/bin/sh", "-sh"}, 354 []string{frontend.CommandServerName, "start listening on", "gitTag", 355 /* "got signal: SIGHUP", */ "got signal: SIGTERM or SIGINT", 356 "stop listening", "6100"}}, 357 {"main killed by -a", // auto stop after 1 second 358 []string{frontend.CommandServerName, "-verbose", "-auto", "1", "-locale", 359 "LC_ALL=en_US.UTF-8", "-p", "6200", "--", "/bin/sh", "-sh"}, 360 []string{frontend.CommandServerName, "start listening on", "gitTag", 361 "stop listening", "6200"}}, 362 {"main killed by -a, write to syslog", // auto stop after 1 second 363 []string{frontend.CommandServerName, "-auto", "1", "-locale", 364 "LC_ALL=en_US.UTF-8", "-p", "6300", "--", "/bin/sh", "-sh"}, 365 []string{}}, // log write to syslog, we can't get anything 366 } 367 368 for _, v := range tc { 369 370 if strings.Contains(v.label, "by signal") { 371 // shutdown after 15ms 372 time.AfterFunc(time.Duration(15)*time.Millisecond, func() { 373 util.Logger.Debug("#test kill process by signal") 374 syscall.Kill(os.Getpid(), syscall.SIGTERM) 375 // syscall.Kill(os.Getpid(), syscall.SIGHUP) 376 }) 377 } 378 379 testFunc := func() { 380 os.Args = v.args 381 main() 382 } 383 384 out := captureOutputRun(testFunc) 385 386 // validate the result from printWelcome 387 result := 
string(out) 388 found := 0 389 for i := range v.expect { 390 if strings.Contains(result, v.expect[i]) { 391 // fmt.Printf("found %s\n", expect[i]) 392 found++ 393 } 394 } 395 if found != len(v.expect) { 396 t.Errorf("#test expect %q, got %s\n", v.expect, result) 397 } 398 // fmt.Printf("###\n%s\n###\n", string(out)) 399 } 400 } 401 402 func testMainBuildConfigFail(t *testing.T) { 403 testFunc := func() { 404 // prepare parameter 405 os.Args = []string{frontend.CommandServerName, "-locale", "LC_ALL=en_US.UTF-8", 406 "-p", "6100", "--", "/bin/sh", "-sh"} 407 // test 408 main() 409 } 410 411 // prepare for buildConfig fail 412 // buildConfigTest = true 413 out := captureOutputRun(testFunc) 414 415 // restore the condition 416 // buildConfigTest = false 417 418 // validate the result 419 expect := []string{"needs a UTF-8 native locale to run"} 420 result := string(out) 421 found := 0 422 for i := range expect { 423 if strings.Contains(result, expect[i]) { 424 found++ 425 } 426 } 427 if found != len(expect) { 428 t.Errorf("#test buildConfig() expect %q, got %s\n", expect, result) 429 } 430 } 431 432 func TestParseFlagsCorrect(t *testing.T) { 433 tc := []struct { 434 args []string 435 conf Config 436 }{ 437 { 438 []string{"-locale", "ALL=en_US.UTF-8", "-l", "LANG=UTF-8"}, 439 Config{ 440 version: false, server: false, verbose: 0, desiredIP: "", desiredPort: "8100", 441 locales: localeFlag{"ALL": "en_US.UTF-8", "LANG": "UTF-8"}, 442 commandPath: "", commandArgv: []string{}, withMotd: false, 443 }, 444 }, 445 { 446 []string{"--", "/bin/sh", "-sh"}, 447 Config{ 448 version: false, server: false, verbose: 0, desiredIP: "", desiredPort: "8100", 449 locales: localeFlag{}, 450 commandPath: "", commandArgv: []string{"/bin/sh", "-sh"}, withMotd: false, 451 }, 452 }, 453 { 454 []string{"--", ""}, 455 Config{ 456 version: false, server: false, verbose: 0, desiredIP: "", desiredPort: "8100", 457 locales: localeFlag{}, 458 commandPath: "", commandArgv: []string{""}, withMotd: false, 
459 }, 460 }, 461 } 462 463 for _, v := range tc { 464 t.Run(strings.Join(v.args, " "), func(t *testing.T) { 465 conf, output, err := parseFlags("prog", v.args) 466 if err != nil { 467 t.Errorf("err got %v, want nil", err) 468 } 469 if output != "" { 470 t.Errorf("output got %q, want empty", output) 471 } 472 if !reflect.DeepEqual(*conf, v.conf) { 473 t.Logf("#test parseFlags got commandArgv=%+v\n", conf.commandArgv) 474 t.Errorf("conf got \n%+v, want \n%+v", *conf, v.conf) 475 } 476 }) 477 } 478 } 479 480 func TestGetShell(t *testing.T) { 481 tc := []struct { 482 label string 483 expect string 484 }{ 485 {"get unix shell from cmd", "fill later"}, 486 } 487 488 var err error 489 tc[0].expect, err = util.GetShell() 490 if err != nil { 491 t.Errorf("#test getShell() reports %q\n", err) 492 } 493 494 for _, v := range tc { 495 if got, _ := util.GetShell(); got != v.expect { 496 if got != v.expect { 497 t.Errorf("#test getShell() %s expect %q, got %q\n", v.label, v.expect, got) 498 } 499 } 500 } 501 } 502 503 func TestParseFlagsError(t *testing.T) { 504 tests := []struct { 505 args []string 506 errstr string 507 }{ 508 {[]string{"-foo"}, "flag provided but not defined"}, 509 // {[]string{"-color", "joe"}, "invalid value"}, 510 {[]string{"-locale", "a=b=c"}, "malform locale parameter"}, 511 } 512 513 for _, tt := range tests { 514 t.Run(strings.Join(tt.args, " "), func(t *testing.T) { 515 conf, output, err := parseFlags("prog", tt.args) 516 if conf != nil { 517 t.Errorf("conf got %v, want nil", conf) 518 } 519 if strings.Index(err.Error(), tt.errstr) < 0 { 520 t.Errorf("err got %q, want to find %q", err.Error(), tt.errstr) 521 } 522 if strings.Index(output, "Usage of prog") < 0 { 523 t.Errorf("output got %q", output) 524 } 525 }) 526 } 527 } 528 529 // func TestMainParameters(t *testing.T) { 530 // // flag is a global variable, reset it before test 531 // flag.CommandLine = flag.NewFlagSet("TestMainParameters", flag.ExitOnError) 532 // testParaFunc := func() { 533 // // 
prepare data 534 // os.Args = []string{COMMAND_NAME, "--", "/bin/sh","-sh"} //"-l LC_ALL=en_US.UTF-8", "--"} 535 // // test 536 // main() 537 // } 538 // 539 // out := captureStdoutRun(testParaFunc) 540 // 541 // // validate result 542 // expect := []string{"main", "commandPath=", "commandArgv=", "withMotd=", "locales=", "color="} 543 // result := string(out) 544 // found := 0 545 // for i := range expect { 546 // if strings.Contains(result, expect[i]) { 547 // found++ 548 // } 549 // } 550 // if found != len(expect) { 551 // t.Errorf("#test main() expect %s, got %s\n", expect, result) 552 // } 553 // } 554 555 func TestMainServerPortrangeError(t *testing.T) { 556 testFunc := func() { 557 os.Args = []string{frontend.CommandServerName, "-s", "-p=3a"} 558 os.Setenv("SSH_CONNECTION", "172.17.0.1 58774 172.17.0.2 22") 559 main() 560 } 561 562 out := captureOutputRun(testFunc) 563 // validate port range check 564 expect := "Bad UDP port" 565 got := string(out) 566 if !strings.Contains(got, expect) { 567 t.Errorf("#test --port should contains %q, got %s\n", expect, got) 568 } 569 } 570 571 func TestGetSSHip(t *testing.T) { 572 tc := []struct { 573 label string 574 env string 575 expect string 576 ok bool 577 }{ 578 {"no env variable", "", "Warning: SSH_CONNECTION not found; binding to any interface.", false}, 579 {"ipv4 address", "172.17.0.1 58774 172.17.0.2 22", "172.17.0.2", true}, 580 {"malform variable", " 1 2 3 4", 581 "Warning: Could not parse SSH_CONNECTION; binding to any interface.", false}, 582 {"ipv6 address", "fe80::14d5:1215:f8c9:11fa%en0 42000 fe80::aede:48ff:fe00:1122%en5 22", 583 "fe80::aede:48ff:fe00:1122%en5", true}, 584 {"ipv4 mapped address", "::FFFF:172.17.0.1 42200 ::FFFF:129.144.52.38 22", "129.144.52.38", true}, 585 } 586 587 for _, v := range tc { 588 589 os.Setenv("SSH_CONNECTION", v.env) 590 got, ok := getSSHip() 591 if got != v.expect || ok != v.ok { 592 t.Errorf("%q expect %q, got %q, ok=%t\n", v.label, v.expect, got, ok) 593 } 594 } 595 } 
596 597 func TestGetShellNameFrom(t *testing.T) { 598 tc := []struct { 599 label string 600 shellPath string 601 shellName string 602 }{ 603 {"normal", "/bin/sh", "-sh"}, 604 {"no slash sign", "noslash", "-noslash"}, 605 } 606 607 for _, v := range tc { 608 got := getShellNameFrom(v.shellPath) 609 if got != v.shellName { 610 t.Errorf("%q expect %q, got %q\n", v.label, v.shellName, got) 611 } 612 } 613 } 614 615 func TestGetTimeFrom(t *testing.T) { 616 tc := []struct { 617 lable string 618 key, value string 619 expect int64 620 }{ 621 {"positive int64", "ENV1", "123", 123}, 622 {"malform int64", "ENV2", "123a", 0}, 623 {"negative int64", "ENV3", "-123", 0}, 624 } 625 626 // save the stdout and create replaced pipe 627 rescueStdout := os.Stdout 628 r, w, _ := os.Pipe() 629 os.Stdout = w 630 // initLog() 631 632 oldArgs := os.Args 633 defer func() { os.Args = oldArgs }() 634 635 for _, v := range tc { 636 os.Setenv(v.key, v.value) 637 638 got := getTimeFrom(v.key, 0) 639 if got != v.expect { 640 t.Errorf("%s expct %d, got %d\n", v.lable, v.expect, got) 641 } 642 } 643 644 // read and restore the stdout 645 w.Close() 646 io.ReadAll(r) 647 os.Stdout = rescueStdout 648 } 649 650 /* 651 func testPTY() error { 652 // Create arbitrary command. 653 c := exec.Command("bash") 654 655 // Start the command with a pty. 656 ptmx, err := pty.Start(c) 657 if err != nil { 658 return err 659 } 660 // Make sure to close the pty at the end. 661 defer func() { _ = ptmx.Close() }() // Best effort. 662 663 // Handle pty size. 664 ch := make(chan os.Signal, 1) 665 signal.Notify(ch, syscall.SIGWINCH) 666 go func() { 667 for range ch { 668 if err := pty.InheritSize(os.Stdin, ptmx); err != nil { 669 log.Printf("error resizing pty: %s", err) 670 } 671 } 672 }() 673 ch <- syscall.SIGWINCH // Initial resize. 674 defer func() { signal.Stop(ch); close(ch) }() // Cleanup signals when done. 675 676 // Set stdin in raw mode. 
677 oldState, err := term.MakeRaw(int(os.Stdin.Fd())) 678 if err != nil { 679 panic(err) 680 } 681 defer func() { _ = term.Restore(int(os.Stdin.Fd()), oldState) }() // Best effort. 682 683 // Copy stdin to the pty and the pty to stdout. 684 // NOTE: The goroutine will keep reading until the next keystroke before returning. 685 go func() { _, _ = io.Copy(ptmx, os.Stdin) }() 686 _, _ = io.Copy(os.Stdout, ptmx) 687 688 return nil 689 } 690 */ 691 692 func TestMainSrvStart(t *testing.T) { 693 tc := []struct { 694 label string 695 pause int // pause between client send and read 696 resp string // response client read 697 shutdown int // pause before shutdown message 698 conf Config 699 }{ 700 { 701 "start normally", 100, frontend.AprilshMsgOpen + "7101,", 150, 702 Config{ 703 version: false, server: true, verbose: 0, desiredIP: "", desiredPort: "7100", 704 locales: localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"}, 705 commandPath: "/bin/sh", commandArgv: []string{"/bin/sh"}, withMotd: false, 706 addSource: false, 707 }, 708 }, 709 } 710 711 if runtime.GOARCH == "riscv64" { 712 t.Skip("riscv64 timer is not as accurate as other platform, skip this test.") 713 } 714 // the test start child process, which is /usr/bin/apshd 715 // which means you need to compile /usr/bin/apshd before test 716 for _, v := range tc { 717 t.Run(v.label, func(t *testing.T) { 718 // init log 719 // util.Logger.CreateLogger(io.Discard, true, slog.LevelDebug) 720 util.Logger.CreateLogger(os.Stderr, true, slog.LevelDebug) 721 722 srv := newMainSrv(&v.conf) 723 724 // send shutdown message after some time 725 timer1 := time.NewTimer(time.Duration(v.shutdown) * time.Millisecond) 726 go func() { 727 <-timer1.C 728 // fmt.Printf("#test start PID:%d\n", os.Getpid()) 729 // all the go test run in the same process 730 // syscall.Kill(os.Getpid(), syscall.SIGHUP) 731 // syscall.Kill(os.Getpid(), syscall.SIGTERM) 732 srv.downChan <- true 733 // stop the worker correctly, because mockRunWorker2 
failed to 734 // do it on purpose. 735 // srv.exChan <- fmt.Sprintf("%d", srv.maxPort) 736 }() 737 738 srv.start(&v.conf) 739 740 // mock client operation 741 // fmt.Printf("#test mark=%d\n", 100) 742 resp := mockClient(v.conf.desiredPort, v.pause, frontend.AprilshMsgOpen) 743 // fmt.Printf("#test mark=%s\n", resp) 744 if !strings.Contains(resp, v.resp) { 745 t.Errorf("#test run expect %q got %q\n", v.resp, resp) 746 } 747 748 srv.wait() 749 // e, err := os.Executable() 750 // fmt.Fprintf(os.Stderr, "Executable=%s, err=%s\n", e, err) 751 // fmt.Fprintf(os.Stderr, "Args[0] =%s\n", os.Args[0]) 752 // fmt.Fprintf(os.Stderr, "CWD =%s\n", os.Args[0]) 753 }) 754 } 755 } 756 757 func TestStartFail(t *testing.T) { 758 tc := []struct { 759 label string 760 pause int // pause between client send and read 761 resp string // response client read 762 finish int // pause before shutdown message 763 conf Config 764 }{ 765 { 766 "illegal port", 20, "", 150, 767 Config{ 768 version: false, server: true, verbose: 0, desiredIP: "", desiredPort: "7000a", 769 locales: localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"}, 770 commandPath: "/bin/sh", commandArgv: []string{"/bin/sh"}, withMotd: false, 771 }, 772 }, 773 } 774 775 for _, v := range tc { 776 t.Run(v.label, func(t *testing.T) { 777 // intercept logW 778 var b strings.Builder 779 util.Logger.CreateLogger(&b, true, slog.LevelDebug) 780 781 // srv := newMainSrv(&v.conf, mockRunWorker) 782 m := newMainSrv(&v.conf) 783 784 // defer func() { 785 // logW = log.New(os.Stdout, "WARN: ", log.Ldate|log.Ltime|log.Lshortfile) 786 // }() 787 788 // start mainserver 789 m.start(&v.conf) 790 // fmt.Println("#test start fail!") 791 792 // validate result: result contains WARN and COMMAND_NAME 793 expect := []string{"WARN", "listen failed"} 794 result := b.String() 795 found := 0 796 for i := range expect { 797 if strings.Contains(result, expect[i]) { 798 found++ 799 } 800 } 801 if found != 2 { 802 t.Errorf("#test start() expect %q, got 
%q\n", expect, result) 803 } 804 }) 805 } 806 } 807 808 // the mock runWorker send the key, pause some time and close the 809 // worker by send finish message 810 func mockRunWorker(conf *Config, exChan chan string, whChan chan workhorse) error { 811 // send the mock key 812 // fmt.Println("#mockRunWorker send mock key to run().") 813 exChan <- "This is the mock key" 814 815 // pause some time 816 time.Sleep(time.Duration(2) * time.Millisecond) 817 818 whChan <- workhorse{} 819 820 // notify the server 821 // fmt.Println("#mockRunWorker finish run().") 822 exChan <- conf.desiredPort 823 return nil 824 } 825 826 // the mock runWorker send the key, pause some time and try to close the 827 // worker by send wrong finish message: port+"x" 828 func mockRunWorker2(conf *Config, exChan chan string, whChan chan workhorse) error { 829 // send the mock key 830 exChan <- "mock key from mockRunWorker2" 831 832 // pause some time 833 time.Sleep(time.Duration(2) * time.Millisecond) 834 835 // fail to stop the worker on purpose 836 exChan <- conf.desiredPort + "x" 837 838 whChan <- workhorse{} 839 840 return nil 841 } 842 843 // mock client connect to the port, send handshake message, pause some time 844 // return the response message. 
845 func mockClient(port string, pause int, action string, ex ...string) string { 846 server_addr, _ := net.ResolveUDPAddr("udp", "localhost:"+port) 847 local_addr, _ := net.ResolveUDPAddr("udp", "localhost:0") 848 conn, _ := net.DialUDP("udp", local_addr, server_addr) 849 850 defer conn.Close() 851 852 // send handshake message based on action & port 853 var txbuf []byte 854 switch action { 855 case frontend.AprilshMsgOpen: 856 switch len(ex) { 857 case 0: 858 txbuf = []byte(frontend.AprilshMsgOpen + "xterm," + getCurrentUser() + "@localhost") 859 case 1: 860 // the request missing the ',' 861 txbuf = []byte(fmt.Sprintf("%s%s", frontend.AprilshMsgOpen, ex[0])) 862 } 863 case frontend.AprishMsgClose: 864 p, _ := strconv.Atoi(port) 865 switch len(ex) { 866 case 0: 867 txbuf = []byte(fmt.Sprintf("%s%d", frontend.AprishMsgClose, p+1)) 868 case 1: 869 p2, err := strconv.Atoi(ex[0]) 870 if err == nil { 871 txbuf = []byte(fmt.Sprintf("%s%d", frontend.AprishMsgClose, p2)) // 1 digital parameter: wrong port 872 } else { 873 txbuf = []byte(fmt.Sprintf("%s%s", frontend.AprishMsgClose, ex[0])) // 1 str parameter: malform port 874 } 875 case 2: 876 txbuf = []byte(fmt.Sprintf("%s%d", "unknow header:", p+1)) // 2 parameters: unknow header 877 } 878 } 879 880 _, err := conn.Write(txbuf) 881 // fmt.Printf("#mockClient send %q to server: %v from %v\n", txbuf, server_addr, conn.LocalAddr()) 882 if err != nil { 883 fmt.Printf("#mockClient send %s, error %s\n", string(txbuf), err) 884 } 885 886 // pause some time 887 time.Sleep(time.Duration(pause) * time.Millisecond) 888 889 // read the response 890 rxbuf := make([]byte, 512) 891 n, _, err := conn.ReadFromUDP(rxbuf) 892 893 // fmt.Printf("#mockClient read %q from server: %v\n", rxbuf[0:n], server_addr) 894 return string(rxbuf[0:n]) 895 } 896 897 func TestPrintWelcome(t *testing.T) { 898 // open pts master and slave first. 
899 pty, tty, err := pty.Open() 900 if err != nil { 901 t.Errorf("#test printWelcome Open %s\n", err) 902 } 903 904 // clean pts fd 905 defer func() { 906 if err != nil { 907 pty.Close() 908 tty.Close() 909 } 910 }() 911 912 // pty master doesn't support IUTF8 913 flag, err := util.CheckIUTF8(int(pty.Fd())) 914 if flag { 915 t.Errorf("#test printWelcome master got %t, expect %t\n", flag, false) 916 } 917 918 expect := []string{"Warning: termios IUTF8 flag not defined."} 919 920 tc := []struct { 921 label string 922 tty *os.File 923 }{ 924 {"tty doesn't support IUTF8 flag", pty}, 925 {"tty failed with checkIUTF8", os.Stdin}, 926 } 927 928 for _, v := range tc { 929 // intercept stdout 930 saveStdout := os.Stdout 931 r, w, _ := os.Pipe() 932 os.Stdout = w 933 util.Logger.CreateLogger(w, true, slog.LevelDebug) 934 935 // printWelcome(os.Getpid(), 6000, v.tty) 936 printWelcome(v.tty) 937 938 // restore stdout 939 w.Close() 940 b, _ := io.ReadAll(r) 941 os.Stdout = saveStdout 942 r.Close() 943 944 // validate the result 945 result := string(b) 946 found := 0 947 for i := range expect { 948 if strings.Contains(result, expect[i]) { 949 found++ 950 } 951 } 952 if found != len(expect) { 953 t.Errorf("#test printWelcome expect %q, got %s\n", expect, result) 954 } 955 } 956 } 957 958 func TestListenFail(t *testing.T) { 959 tc := []struct { 960 label string 961 port string 962 repeat bool // if true, will listen twice. 
963 }{ 964 {"illegal port number", "22a", false}, 965 {"port already in use", "60001", true}, // 60001 is the docker port on macOS 966 } 967 for _, v := range tc { 968 conf := &Config{desiredPort: v.port} 969 // s := newMainSrv(conf, mockRunWorker) 970 s := newMainSrv(conf) 971 972 var e error 973 e = s.listen(conf) 974 // fmt.Printf("#test %q got 1st error: %q\n", v.label, e) 975 if v.repeat { 976 e = s.listen(conf) 977 // fmt.Printf("#test %q got 2nd error: %q\n", v.label, e) 978 } 979 980 // check the error does happens 981 if e == nil { 982 t.Errorf("#test %q expect error return, got nil\n", v.label) 983 } 984 985 // close the listen port 986 if v.repeat { 987 s.exChan <- conf.desiredPort 988 } 989 } 990 } 991 992 // func testRunFail(t *testing.T) { 993 // tc := []struct { 994 // label string 995 // pause int // pause between client send and read 996 // resp string // response client read 997 // finish int // pause before shutdown message 998 // conf Config 999 // }{ 1000 // { 1001 // "worker failed with wrong port number", 100, frontend.AprilshMsgOpen + "7101,mock key from mockRunWorker2\n", 30, 1002 // Config{ 1003 // version: false, server: true, verbose: 1, desiredIP: "", desiredPort: "7100", 1004 // locales: localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"}, 1005 // commandPath: "/bin/sh", commandArgv: []string{"/bin/sh"}, withMotd: false, 1006 // addSource: false, 1007 // }, 1008 // }, 1009 // } 1010 // 1011 // for _, v := range tc { 1012 // t.Run(v.label, func(t *testing.T) { 1013 // // intercept stdout 1014 // saveStdout := os.Stdout 1015 // r, w, _ := os.Pipe() 1016 // os.Stdout = w 1017 // // initLog() 1018 // 1019 // // util.Logger.CreateLogger(w, true, slog.LevelDebug) 1020 // util.Logger.CreateLogger(os.Stderr, true, slog.LevelDebug) 1021 // 1022 // // srv := newMainSrv(&v.conf, mockRunWorker2) 1023 // srv := newMainSrv(&v.conf) 1024 // 1025 // // send shutdown message after some time 1026 // timer1 := time.NewTimer(time.Duration(v.finish) 
* time.Millisecond) 1027 // go func() { 1028 // <-timer1.C 1029 // // prepare to shudown the mainSrv 1030 // // syscall.Kill(syscall.Getpid(), syscall.SIGTERM) 1031 // srv.downChan <- true 1032 // // stop the worker correctly, because mockRunWorker2 failed to 1033 // // do it on purpose. 1034 // port, _ := strconv.Atoi(v.conf.desiredPort) 1035 // srv.exChan <- fmt.Sprintf("%d", port+1) 1036 // util.Logger.Debug("send port to exChan", "port", port+1) 1037 // }() 1038 // // fmt.Println("#test start timer for shutdown") 1039 // 1040 // srv.start(&v.conf) 1041 // 1042 // // mock client operation 1043 // resp := mockClient(v.conf.desiredPort, v.pause, frontend.AprilshMsgOpen) 1044 // 1045 // // validate the result. 1046 // if resp != v.resp { 1047 // t.Errorf("#test run expect %q got %q\n", v.resp, resp) 1048 // } 1049 // 1050 // srv.wait() 1051 // 1052 // // restore stdout 1053 // w.Close() 1054 // io.ReadAll(r) 1055 // os.Stdout = saveStdout 1056 // r.Close() 1057 // }) 1058 // } 1059 // 1060 // // test case for run() without connection 1061 // 1062 // srv2 := &mainSrv{} 1063 // srv2.run(&Config{}) 1064 // } 1065 1066 func TestRunFail2(t *testing.T) { 1067 tc := []struct { 1068 label string 1069 pause int // pause between client send and read 1070 resp string // response client read 1071 finish int // pause before shutdown message 1072 conf Config 1073 }{ 1074 { 1075 "read udp error", 20, "7101,This is the mock key", 150, 1076 Config{ 1077 version: false, server: true, verbose: 0, desiredIP: "", desiredPort: "7100", 1078 locales: localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"}, 1079 commandPath: "/bin/sh", commandArgv: []string{"/bin/sh"}, withMotd: false, 1080 }, 1081 }, 1082 } 1083 1084 for _, v := range tc { 1085 t.Run(v.label, func(t *testing.T) { 1086 // intercept stdout 1087 saveStdout := os.Stdout 1088 r, w, _ := os.Pipe() 1089 os.Stdout = w 1090 // initLog() 1091 util.Logger.CreateLogger(w, true, slog.LevelDebug) 1092 1093 // srv := 
newMainSrv(&v.conf, mockRunWorker) 1094 srv := newMainSrv(&v.conf) 1095 1096 // send shutdown message after some time 1097 timer1 := time.NewTimer(time.Duration(v.finish) * time.Millisecond) 1098 go func() { 1099 <-timer1.C 1100 srv.downChan <- true 1101 // syscall.Kill(syscall.Getpid(), syscall.SIGTERM) 1102 }() 1103 // fmt.Println("#test start timer for shutdown") 1104 1105 srv.start(&v.conf) 1106 1107 // close the connection, this will cause read error: use of closed network connection. 1108 srv.conn.Close() 1109 1110 srv.wait() 1111 1112 // restore stdout 1113 w.Close() 1114 io.ReadAll(r) 1115 os.Stdout = saveStdout 1116 r.Close() 1117 }) 1118 } 1119 } 1120 1121 func TestMaxPortLimit(t *testing.T) { 1122 tc := []struct { 1123 label string 1124 maxPortLimit int 1125 pause int // pause between client send and read 1126 resp string // response client read 1127 shutdownTime int // pause before shutdown message 1128 conf Config 1129 }{ 1130 { 1131 "run() over max port", 0, 20, "over max port limit", 150, 1132 Config{ 1133 version: false, server: true, verbose: 0, desiredIP: "", desiredPort: "7700", 1134 locales: localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"}, 1135 commandPath: "/bin/sh", commandArgv: []string{"/bin/sh"}, withMotd: false, 1136 }, 1137 }, 1138 } 1139 1140 for _, v := range tc { 1141 t.Run(v.label, func(t *testing.T) { 1142 // intercept stdout 1143 1144 util.Logger.CreateLogger(io.Discard, true, slog.LevelDebug) 1145 1146 // init mainSrv and workers 1147 // m := newMainSrv(&v.conf, runWorker) 1148 m := newMainSrv(&v.conf) 1149 1150 // save maxPortLimit 1151 old := maxPortLimit 1152 maxPortLimit = v.maxPortLimit 1153 1154 // send shutdown message after some time 1155 timer1 := time.NewTimer(time.Duration(v.shutdownTime) * time.Millisecond) 1156 go func() { 1157 <-timer1.C 1158 m.downChan <- true 1159 }() 1160 1161 // start mainserver 1162 m.start(&v.conf) 1163 1164 // mock client operation 1165 resp := mockClient(v.conf.desiredPort, v.pause, 
frontend.AprilshMsgOpen) 1166 1167 m.wait() 1168 1169 if !strings.Contains(resp, v.resp) { 1170 t.Errorf("%q expect response %q, got %q\n ", v.label, v.resp, resp) 1171 } 1172 1173 // restore maxPortLimit 1174 maxPortLimit = old 1175 }) 1176 } 1177 } 1178 1179 func TestMalformRequest(t *testing.T) { 1180 tc := []struct { 1181 label string 1182 pause int // pause between client send and read 1183 resp string // response client read 1184 shutdownTime int // pause before shutdown message 1185 conf Config 1186 }{ 1187 { 1188 "run() malform request", 20, "malform request", 150, 1189 Config{ 1190 version: false, server: true, verbose: 0, desiredIP: "", desiredPort: "7700", 1191 locales: localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"}, 1192 commandPath: "/bin/sh", commandArgv: []string{"/bin/sh"}, withMotd: false, 1193 }, 1194 }, 1195 } 1196 1197 for _, v := range tc { 1198 // intercept stdout 1199 1200 util.Logger.CreateLogger(io.Discard, true, slog.LevelDebug) 1201 1202 // init mainSrv and workers 1203 // m := newMainSrv(&v.conf, runWorker) 1204 m := newMainSrv(&v.conf) 1205 1206 // send shutdown message after some time 1207 timer1 := time.NewTimer(time.Duration(v.shutdownTime) * time.Millisecond) 1208 go func() { 1209 <-timer1.C 1210 syscall.Kill(os.Getpid(), syscall.SIGHUP) // add SIGHUP test condition 1211 time.Sleep(time.Duration(v.shutdownTime+5) * time.Millisecond) 1212 m.downChan <- true 1213 }() 1214 1215 // start mainserver 1216 m.start(&v.conf) 1217 1218 // mock client operation 1219 resp := mockClient(v.conf.desiredPort, v.pause, frontend.AprilshMsgOpen, "extraParam") 1220 1221 m.wait() 1222 1223 if !strings.Contains(resp, v.resp) { 1224 t.Errorf("%q expect response %q, got %q\n ", v.label, v.resp, resp) 1225 } 1226 } 1227 } 1228 1229 func mockServe(ptmx *os.File, pts *os.File, pw *io.PipeWriter, terminal *statesync.Complete, // x chan bool, 1230 network *network.Transport[*statesync.Complete, *statesync.UserStream], 1231 networkTimeout int64, 
networkSignaledTimeout int64, user string) error { 1232 time.Sleep(10 * time.Millisecond) 1233 // x <- true 1234 return nil 1235 } 1236 1237 // the mock runWorker send empty key, pause some time and close the worker 1238 func failRunWorker(conf *Config, exChan chan string, whChan chan *workhorse) error { 1239 // send the empty key 1240 // fmt.Println("#mockRunWorker send mock key to run().") 1241 exChan <- "" 1242 1243 // pause some time 1244 time.Sleep(time.Duration(2) * time.Millisecond) 1245 1246 // notify this worker is done 1247 defer func() { 1248 exChan <- conf.desiredPort 1249 }() 1250 1251 whChan <- &workhorse{} 1252 return errors.New("failed worker.") 1253 } 1254 1255 func TestRunWorkerKillSignal(t *testing.T) { 1256 tc := []struct { 1257 label string 1258 pause int // pause between client send and read 1259 resp string // response client read 1260 finish int // pause before shutdown message 1261 conf Config 1262 }{ 1263 { 1264 "runWorker stopped by signal kill", 10, frontend.AprilshMsgOpen + "7101,", 150, 1265 Config{ 1266 version: false, server: true, flowControl: _FC_SKIP_PIPE_LOCK, desiredIP: "", desiredPort: "7100", 1267 locales: localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"}, 1268 commandPath: "/bin/sh", commandArgv: []string{"-sh"}, withMotd: true, 1269 }, 1270 }, 1271 } 1272 1273 for _, v := range tc { 1274 t.Run(v.label, func(t *testing.T) { 1275 1276 // intercept stdout 1277 saveStdout := os.Stdout 1278 r, w, _ := os.Pipe() 1279 os.Stdout = w 1280 1281 util.Logger.CreateLogger(w, true, slog.LevelDebug) 1282 // util.Logger.CreateLogger(os.Stderr, true, slog.LevelDebug) 1283 1284 // set serve func and runWorker func 1285 v.conf.serve = mockServe 1286 // srv := newMainSrv(&v.conf, runWorker) 1287 srv := newMainSrv(&v.conf) 1288 1289 /// set commandPath and commandArgv based on environment 1290 v.conf.commandPath = os.Getenv("SHELL") 1291 v.conf.commandArgv = []string{getShellNameFrom(v.conf.commandPath)} 1292 1293 // send kill signal 
after some time (finish ms) 1294 timer1 := time.NewTimer(time.Duration(v.finish) * time.Millisecond) 1295 go func() { 1296 <-timer1.C 1297 srv.downChan <- true 1298 }() 1299 1300 srv.start(&v.conf) 1301 1302 // mock client operation 1303 resp := mockClient(v.conf.desiredPort, v.pause, frontend.AprilshMsgOpen) 1304 if !strings.HasPrefix(resp, v.resp) { 1305 t.Errorf("#test run expect %q got %q\n", v.resp, resp) 1306 } 1307 1308 srv.wait() 1309 1310 // restore stdout 1311 w.Close() 1312 io.ReadAll(r) 1313 os.Stdout = saveStdout 1314 r.Close() 1315 }) 1316 } 1317 } 1318 1319 // func testRunWorkerFail(t *testing.T) { 1320 // tc := []struct { 1321 // label string 1322 // conf Config 1323 // }{ 1324 // { 1325 // "openPTS fail", Config{ 1326 // version: false, server: true, flowControl: _FC_OPEN_PTS_FAIL, desiredIP: "", desiredPort: "7100", 1327 // locales: localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"}, term: "kitty", 1328 // commandPath: "/bin/xxxsh", commandArgv: []string{"-sh"}, withMotd: false, 1329 // }, 1330 // }, 1331 // { 1332 // "startShell fail", Config{ 1333 // version: false, server: true, flowControl: _FC_SKIP_START_SHELL, desiredIP: "", desiredPort: "7200", 1334 // locales: localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"}, term: "kitty", 1335 // commandPath: "/bin/xxxsh", commandArgv: []string{"-sh"}, withMotd: false, 1336 // }, 1337 // }, 1338 // // { 1339 // // "shell.Wait fail", Config{ 1340 // // version: false, server: true, verbose: _VERBOSE_SKIP_READ_PIPE, desiredIP: "", desiredPort: "7300", 1341 // // locales: localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"}, term: "kitty", 1342 // // commandPath: "echo", commandArgv: []string{"2"}, withMotd: false, 1343 // // }, 1344 // // }, 1345 // } 1346 // 1347 // exChan := make(chan string, 1) 1348 // whChan := make(chan workhorse, 1) 1349 // 1350 // for _, v := range tc { 1351 // t.Run(v.label, func(t *testing.T) { 1352 // 1353 // // intercept log output 1354 // 
util.Logger.CreateLogger(io.Discard, true, slog.LevelDebug) 1355 // 1356 // var wg sync.WaitGroup 1357 // var hasWorkhorse bool 1358 // v.conf.serve = mockServe 1359 // if strings.Contains(v.label, "shell.Wait fail") { 1360 // v.conf.commandPath, _ = exec.LookPath(v.conf.commandPath) 1361 // hasWorkhorse = true // last one has effective work horse. 1362 // } 1363 // 1364 // wg.Add(1) 1365 // go func() { 1366 // defer wg.Done() 1367 // <-exChan // get the key 1368 // wh := <-whChan // get the workhorse 1369 // if hasWorkhorse { 1370 // if wh.child == nil { 1371 // t.Errorf("#test runWorker fail should return empty workhorse\n") 1372 // } 1373 // wh.child.Kill() 1374 // } else if strings.Contains(v.label, "openPTS fail") { 1375 // if wh.child != nil { 1376 // t.Errorf("#test runWorker fail should return empty workhorse\n") 1377 // } 1378 // msg := <-exChan // get the done message 1379 // if msg != v.conf.desiredPort { 1380 // t.Errorf("#test runWorker fail should return %s, got %s\n", v.conf.desiredPort, msg) 1381 // } 1382 // } else if strings.Contains(v.label, "startShell fail") { 1383 // if wh.child != nil { 1384 // t.Errorf("#test runWorker fail should return empty workhorse\n") 1385 // } 1386 // msg := <-exChan // get the done message 1387 // if msg != v.conf.desiredPort+":shutdown" { 1388 // t.Errorf("#test runWorker fail should return %s, got %s\n", v.conf.desiredPort, msg) 1389 // } 1390 // } 1391 // }() 1392 // 1393 // // TODO disable it for the time being 1394 // // if hasWorkhorse { 1395 // // if err := runWorker(&v.conf, exChan, whChan); err != nil { 1396 // // t.Errorf("#test runWorker should not report error.\n") 1397 // // } 1398 // // } else { 1399 // // if err := runWorker(&v.conf, exChan, whChan); err == nil { 1400 // // t.Errorf("#test runWorker should report error.\n") 1401 // // } 1402 // // } 1403 // 1404 // wg.Wait() 1405 // }) 1406 // } 1407 // } 1408 1409 func TestRunCloseFail(t *testing.T) { 1410 tc := []struct { 1411 label string 1412 pause 
int // pause between client send and read 1413 resp1 string // response of start action 1414 resp2 string // response of stop action 1415 exp []string // ex parameter 1416 finish int // pause before shutdown message 1417 conf Config 1418 }{ 1419 // { 1420 // "runWorker stopped by " + frontend.AprishMsgClose, 20, frontend.AprilshMsgOpen + "7111,", frontend.AprishMsgClose + "done", 1421 // []string{}, 1422 // 150, 1423 // Config{ 1424 // version: false, server: true, flowControl: _FC_SKIP_PIPE_LOCK, desiredIP: "", desiredPort: "7110", 1425 // locales: localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"}, 1426 // commandPath: "/bin/sh", commandArgv: []string{"-sh"}, withMotd: true, 1427 // }, 1428 // }, 1429 // { 1430 // "runWorker stop port not exist", 5, frontend.AprilshMsgOpen + "7121,", frontend.AprishMsgClose + "port does not exist", 1431 // []string{"7100"}, 1432 // 150, 1433 // Config{ 1434 // version: false, server: true, flowControl: _FC_SKIP_PIPE_LOCK, desiredIP: "", desiredPort: "7120", 1435 // locales: localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"}, 1436 // commandPath: "/bin/sh", commandArgv: []string{"-sh"}, withMotd: true, 1437 // }, 1438 // }, 1439 // { 1440 // "runWorker stop wrong port number", 5, frontend.AprilshMsgOpen + "7131,", frontend.AprishMsgClose + "wrong port number", 1441 // []string{"7121x"}, 1442 // 150, 1443 // Config{ 1444 // version: false, server: true, flowControl: _FC_SKIP_PIPE_LOCK, desiredIP: "", desiredPort: "7130", 1445 // locales: localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"}, 1446 // commandPath: "/bin/sh", commandArgv: []string{"-sh"}, withMotd: true, 1447 // }, 1448 // }, 1449 { 1450 "runWorker stop unknow request", 25, frontend.AprilshMsgOpen + "7141,", frontend.AprishMsgClose + "unknow request", 1451 []string{"two", "params"}, 1452 150, 1453 Config{ 1454 version: false, server: true, flowControl: _FC_SKIP_PIPE_LOCK, desiredIP: "", desiredPort: "7140", 1455 locales: localeFlag{"LC_ALL": 
"en_US.UTF-8", "LANG": "en_US.UTF-8"}, 1456 commandPath: "/bin/sh", commandArgv: []string{"-sh"}, withMotd: true, 1457 }, 1458 }, 1459 } 1460 1461 if runtime.GOARCH == "s390x" { 1462 t.Skip("for s390x, skip this test.") 1463 } 1464 for _, v := range tc { 1465 t.Run(v.label, func(t *testing.T) { 1466 1467 util.Logger.CreateLogger(io.Discard, true, slog.LevelDebug) 1468 // util.Logger.CreateLogger(os.Stderr, true, slog.LevelDebug) 1469 1470 // set serve func and runWorker func 1471 v.conf.serve = mockServe 1472 // srv := newMainSrv(&v.conf, runWorker) 1473 srv := newMainSrv(&v.conf) 1474 1475 /// set commandPath and commandArgv based on environment 1476 v.conf.commandPath = os.Getenv("SHELL") 1477 v.conf.commandArgv = []string{getShellNameFrom(v.conf.commandPath)} 1478 1479 // send shutdown message after some time (finish ms) 1480 timer1 := time.NewTimer(time.Duration(v.finish) * time.Millisecond) 1481 go func() { 1482 <-timer1.C 1483 srv.downChan <- true 1484 }() 1485 1486 srv.start(&v.conf) 1487 1488 // start a new connection 1489 resp1 := mockClient(v.conf.desiredPort, v.pause, frontend.AprilshMsgOpen) 1490 if !strings.HasPrefix(resp1, v.resp1) { 1491 t.Errorf("#test run expect %q got %q\n", v.resp1, resp1) 1492 } 1493 // fmt.Printf("#test got response resp1=%s\n", resp1) 1494 1495 time.Sleep(10 * time.Millisecond) 1496 1497 // stop the new connection 1498 resp2 := mockClient(v.conf.desiredPort, v.pause, frontend.AprishMsgClose, v.exp...) 
1499 if !strings.HasPrefix(resp2, v.resp2) { 1500 t.Errorf("#test run expect %q got %q\n", v.resp1, resp2) 1501 } 1502 1503 // fmt.Printf("#test got response resp2=%s\n", resp2) 1504 // stop the connection 1505 if len(v.exp) > 0 { 1506 expect := frontend.AprishMsgClose + "done" 1507 resp2 := mockClient(v.conf.desiredPort, v.pause, frontend.AprishMsgClose) 1508 if !strings.HasPrefix(resp2, expect) { 1509 t.Errorf("#test run stop the connection expect %q got %q\n", v.resp1, resp2) 1510 } 1511 } 1512 1513 // fmt.Printf("#test got stop response resp2=%s\n", resp2) 1514 srv.wait() 1515 }) 1516 } 1517 } 1518 1519 func TestRunWith2Clients(t *testing.T) { 1520 tc := []struct { 1521 label string 1522 pause int // pause between client send and read 1523 resp1 string // response of start action 1524 resp2 string // response of stop action 1525 resp3 string // response of additinoal open request 1526 exp []string // ex parameter 1527 finish int // pause before shutdown message 1528 conf Config 1529 }{ 1530 { 1531 "open aprilsh with duplicate request", 20, frontend.AprilshMsgOpen + "7101,", frontend.AprishMsgClose + "done", 1532 frontend.AprilshMsgOpen + "7102", []string{}, 150, 1533 Config{ 1534 version: false, server: true, flowControl: _FC_SKIP_PIPE_LOCK, desiredIP: "", desiredPort: "7100", 1535 locales: localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"}, 1536 commandPath: "/bin/sh", commandArgv: []string{"-sh"}, withMotd: true, 1537 }, 1538 }, 1539 } 1540 1541 for _, v := range tc { 1542 t.Run(v.label, func(t *testing.T) { 1543 1544 // intercept stdout 1545 util.Logger.CreateLogger(io.Discard, true, slog.LevelDebug) 1546 1547 // set serve func and runWorker func 1548 v.conf.serve = mockServe 1549 // srv := newMainSrv(&v.conf, runWorker) 1550 srv := newMainSrv(&v.conf) 1551 1552 /// set commandPath and commandArgv based on environment 1553 v.conf.commandPath = os.Getenv("SHELL") 1554 v.conf.commandArgv = []string{getShellNameFrom(v.conf.commandPath)} 1555 1556 
srv.start(&v.conf) 1557 1558 // start a new connection 1559 resp1 := mockClient(v.conf.desiredPort, v.pause, frontend.AprilshMsgOpen) 1560 if !strings.HasPrefix(resp1, v.resp1) { 1561 t.Errorf("#test first client start expect %q got %q\n", v.resp1, resp1) 1562 } 1563 // fmt.Printf("#test got 1 response %q\n", resp1) 1564 1565 // start a new connection 1566 resp3 := mockClient(v.conf.desiredPort, v.pause, frontend.AprilshMsgOpen) 1567 if !strings.HasPrefix(resp3, v.resp3) { 1568 t.Errorf("#test second client start expect %q got %q\n", v.resp3, resp3) 1569 } 1570 // fmt.Printf("#test got 3 response %q\n", resp3) 1571 1572 // stop the new connection 1573 resp2 := mockClient(v.conf.desiredPort, v.pause, frontend.AprishMsgClose, v.exp...) 1574 if !strings.HasPrefix(resp2, v.resp2) { 1575 t.Errorf("#test firt client stop expect %q got %q\n", v.resp1, resp2) 1576 } 1577 // fmt.Printf("#test got 2 response %q\n", resp2) 1578 1579 // send shutdown message after some time (finish ms) 1580 timer1 := time.NewTimer(time.Duration(v.finish) * time.Millisecond) 1581 go func() { 1582 <-timer1.C 1583 srv.downChan <- true 1584 }() 1585 1586 srv.wait() 1587 }) 1588 } 1589 } 1590 1591 func TestStartShellError(t *testing.T) { 1592 tc := []struct { 1593 label string 1594 errStr string 1595 pts *os.File 1596 pr *io.PipeReader 1597 utmpHost string 1598 conf Config 1599 }{ 1600 {"first error return", "fail to start shell", os.Stdout, nil, "", 1601 Config{flowControl: _FC_SKIP_START_SHELL}, 1602 }, 1603 {"IUTF8 error return", strENOTTY, os.Stdin, nil, "", 1604 Config{}, 1605 }, // os.Stdin doesn't support IUTF8 flag, startShell should failed 1606 } 1607 1608 util.Logger.CreateLogger(io.Discard, true, slog.LevelDebug) 1609 1610 for _, v := range tc { 1611 t.Run(v.label, func(t *testing.T) { 1612 // open pty master and slave 1613 ptmx, pts, _ := pty.Open() 1614 if v.pts == nil { 1615 v.pts = pts 1616 } 1617 1618 // open pipe for parameter 1619 pr, pw := io.Pipe() 1620 if v.pr == nil { 1621 
v.pr = pr 1622 } 1623 1624 _, err := startShellProcess(v.pts, v.pr, v.utmpHost, &v.conf) 1625 // fmt.Printf("%#v\n", err) 1626 1627 // validate error 1628 if !strings.Contains(err.Error(), v.errStr) { 1629 t.Errorf("%q should report %q, got %q\n", v.label, v.errStr, err) 1630 } 1631 1632 pr.Close() 1633 pw.Close() 1634 ptmx.Close() 1635 pts.Close() 1636 }) 1637 } 1638 } 1639 1640 func TestOpenPTS(t *testing.T) { 1641 1642 tc := []struct { 1643 label string 1644 ws unix.Winsize 1645 errStr string 1646 }{ 1647 {"invalid parameter error", unix.Winsize{}, "invalid parameter"}, 1648 {"invalid parameter error", unix.Winsize{Row: 4, Col: 4}, ""}, 1649 } 1650 1651 for i, v := range tc { 1652 t.Run(v.label, func(t *testing.T) { 1653 var ptmx, pts *os.File 1654 var err error 1655 if i == 0 { 1656 ptmx, pts, err = openPTS(nil) 1657 } else { 1658 ptmx, pts, err = openPTS(&v.ws) 1659 } 1660 defer ptmx.Close() 1661 defer pts.Close() 1662 if i == 0 { 1663 if !strings.Contains(err.Error(), v.errStr) { 1664 t.Errorf("%q should report %q, got %q\n", v.label, v.errStr, err) 1665 fmt.Printf("%#v\n", err) 1666 } 1667 } else { 1668 if err != nil { 1669 t.Errorf("%q expect no error, got %s\n", v.label, err) 1670 } 1671 } 1672 }) 1673 } 1674 } 1675 1676 // func testGetCurrentUser(t *testing.T) { 1677 // // normal invocation 1678 // userCurrentTest = false 1679 // uid := fmt.Sprintf("%d", os.Getuid()) 1680 // expect, _ := user.LookupId(uid) 1681 // 1682 // got := getCurrentUser() 1683 // if len(got) == 0 || expect.Username != got { 1684 // t.Errorf("#test getCurrentUser expect %s, got %s\n", expect.Username, got) 1685 // } 1686 // 1687 // // getCurrentUser fail 1688 // old := userCurrentTest 1689 // defer func() { 1690 // userCurrentTest = old 1691 // }() 1692 // 1693 // // intercept log output 1694 // var b strings.Builder 1695 // util.Logger.CreateLogger(&b, true, slog.LevelDebug) 1696 // 1697 // userCurrentTest = true 1698 // got = getCurrentUser() 1699 // if got != "" { 1700 // 
t.Errorf("#test getCurrentUser expect empty string, got %s\n", got) 1701 // } 1702 // // restore logW 1703 // // logW = log.New(os.Stdout, "WARN: ", log.Ldate|log.Ltime|log.Lshortfile) 1704 // } 1705 1706 func TestGetAvailablePort(t *testing.T) { 1707 tc := []struct { 1708 label string 1709 max int // pre-condition before getAvailabePort 1710 expectPort int 1711 expectMax int 1712 workers map[int]*workhorse 1713 }{ 1714 { 1715 "empty worker list", 6001, 6001, 6002, 1716 map[int]*workhorse{}, 1717 }, 1718 { 1719 "lart gap empty worker", 6008, 6001, 6002, 1720 map[int]*workhorse{}, 1721 }, 1722 { 1723 "add one port", 6002, 6002, 6003, 1724 map[int]*workhorse{6001: {}}, 1725 }, 1726 { 1727 "shrink max", 6013, 6002, 6003, 1728 map[int]*workhorse{6001: {}}, 1729 }, 1730 { 1731 "right most", 6004, 6004, 6005, 1732 map[int]*workhorse{6001: {}, 6002: {}, 6003: {}}, 1733 }, 1734 { 1735 "left most", 6006, 6001, 6006, 1736 map[int]*workhorse{6003: {}, 6004: {}, 6005: {}}, 1737 }, 1738 { 1739 "middle hole", 6009, 6004, 6009, 1740 map[int]*workhorse{6001: {}, 6002: {}, 6003: {}, 6008: {}}, 1741 }, 1742 { 1743 "border shape hole", 6019, 6002, 6019, 1744 map[int]*workhorse{6001: {}, 6018: {}}, 1745 }, 1746 } 1747 1748 conf := &Config{desiredPort: "6000"} 1749 1750 for _, v := range tc { 1751 t.Run(v.label, func(t *testing.T) { 1752 // intercept log output 1753 util.Logger.CreateLogger(io.Discard, true, slog.LevelDebug) 1754 1755 srv := newMainSrv(conf) 1756 srv.workers = v.workers 1757 srv.maxPort = v.max 1758 1759 got := srv.getAvailabePort() 1760 1761 if got != v.expectPort { 1762 t.Errorf("%q expect port=%d, got %d\n", v.label, v.expectPort, got) 1763 } 1764 1765 if srv.maxPort != v.expectMax { 1766 t.Errorf("%q expect maxPort=%d, got %d\n", v.label, v.expectMax, srv.maxPort) 1767 } 1768 }) 1769 } 1770 } 1771 1772 // func TestIsPortExist(t *testing.T) { 1773 // tc := []struct { 1774 // label string 1775 // port int 1776 // ret bool 1777 // }{ 1778 // {"port exist", 101, true}, 
1779 // {"port does not exist", 10, false}, 1780 // } 1781 // 1782 // // prepare workers data 1783 // conf := &Config{desiredPort: "6000"} 1784 // 1785 // srv := newMainSrv(conf, mockRunWorker) 1786 // srv.workers[100] = &workhorse{nil, os.Stderr} 1787 // srv.workers[101] = &workhorse{nil, os.Stdout} 1788 // srv.workers[111] = &workhorse{nil, os.Stdin} 1789 // 1790 // for _, v := range tc { 1791 // t.Run(v.label, func(t *testing.T) { 1792 // got := srv.isPortExist(v.port) 1793 // if got != v.ret { 1794 // t.Errorf("%q port %d: expect %t, got %t\n", v.label, v.port, v.ret, got) 1795 // } 1796 // 1797 // }) 1798 // } 1799 // } 1800 1801 func BenchmarkGetAvailablePort(b *testing.B) { 1802 1803 conf := &Config{desiredPort: "100"} 1804 srv := newMainSrv(conf) 1805 srv.workers[100] = &workhorse{} 1806 srv.workers[101] = &workhorse{} 1807 srv.workers[102] = &workhorse{} 1808 1809 srv.maxPort = 102 1810 1811 for i := 0; i < b.N; i++ { 1812 srv.getAvailabePort() 1813 srv.maxPort-- // hedge maxPort++ in getAvailabePort 1814 } 1815 } 1816 1817 func TestCheckPortAvailable(t *testing.T) { 1818 tc := []struct { 1819 label string 1820 port int 1821 expect bool 1822 }{ 1823 {"wrong port number", -200, false}, 1824 {"duplicate por number", 8022, false}, 1825 } 1826 1827 cfg := &Config{desiredPort: "8022"} 1828 ms := newMainSrv(cfg) 1829 for _, v := range tc { 1830 t.Run(v.label, func(t *testing.T) { 1831 // take the port 1832 ms.listen(cfg) 1833 1834 // validate tc 1835 got := checkPortAvailable(v.port) 1836 if got != v.expect { 1837 t.Errorf("%s expect %t, got %t\n", v.label, v.expect, got) 1838 } 1839 // clear port 1840 ms.conn.Close() 1841 }) 1842 } 1843 } 1844 1845 func TestHandleMessage(t *testing.T) { 1846 1847 tc := []struct { 1848 label string 1849 content string 1850 reason string 1851 }{ 1852 {"no colon", "no colon", "lack of ':'"}, 1853 {"no comma", "no:comma", "lack of ','"}, 1854 {"wrong port number", "no:comma,x", "invalid port number"}, 1855 {"non-existence port 
number", "no:6000,x", "non-existence port number"}, 1856 {"invalid serve shutdown", _ServeHeader + ":8100,not shutdown", "invalid shutdown"}, 1857 {"kill shell process failed", _ServeHeader + ":8100,shutdown", "kill shell process failed"}, 1858 {"invalid run shutdown", _RunHeader + ":8100,not shutdown", "invalid shutdown"}, 1859 {"invalid shell pid", _ShellHeader + ":8100,x", "invalid shell pid"}, 1860 {"unknown header", "unknow:8100,x", "unknown header"}, 1861 } 1862 1863 cfg := &Config{desiredPort: "8022"} 1864 ms := newMainSrv(cfg) 1865 ms.workers[8100] = &workhorse{shellPid: 0} 1866 // ms.workers[8110] = &workhorse{shellPid: os.Getpid()} 1867 1868 for _, v := range tc { 1869 t.Run(v.label, func(t *testing.T) { 1870 _, err := ms.handleMessage(v.content) 1871 var messagError *messageError 1872 1873 if errors.As(err, &messagError) { 1874 if messagError.reason != v.reason { 1875 t.Errorf("%s expect %q, got %q\n", v.label, v.reason, messagError.reason) 1876 // } else { 1877 // t.Logf("go error %#v\n", messagError.err) 1878 } 1879 } else { 1880 t.Errorf("%s expect %v, got %v\n", v.label, messagError, err) 1881 } 1882 }) 1883 } 1884 } 1885 1886 func TestBeginChild(t *testing.T) { 1887 tc := []struct { 1888 label string 1889 pause int // pause between client send and read 1890 resp string // response for beginClientConn(). 
1891 shutdown int // pause before shutdown message 1892 clientConf Config 1893 conf Config 1894 }{ 1895 { 1896 "normal beginClientConn", 100, frontend.AprilshMsgOpen + "7101,", 150, 1897 Config{desiredPort: "7100", term: "xterm-256color", destination: getCurrentUser() + "@localhost"}, 1898 Config{ 1899 version: false, server: false, desiredIP: "", desiredPort: "7100", 1900 locales: localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"}, 1901 commandPath: "/bin/sh", commandArgv: []string{"/bin/sh"}, withMotd: false, 1902 // addSource: false, verbose: util.TraceLevel, 1903 }, 1904 }, 1905 } 1906 1907 for _, v := range tc { 1908 t.Run(v.label, func(t *testing.T) { 1909 util.Logger.CreateLogger(io.Discard, true, slog.LevelDebug) 1910 // util.Logger.CreateLogger(os.Stderr, true, slog.LevelDebug) 1911 1912 srv := newMainSrv(&v.conf) 1913 // send shutdown message after some time 1914 timer1 := time.NewTimer(time.Duration(v.shutdown) * time.Millisecond) 1915 go func() { 1916 <-timer1.C 1917 // prepare to shudown the mainSrv 1918 // syscall.Kill(syscall.Getpid(), syscall.SIGTERM) 1919 srv.downChan <- true 1920 }() 1921 1922 srv.start(&v.conf) 1923 1924 // intercept stdout 1925 saveStdout := os.Stdout 1926 r, w, _ := os.Pipe() 1927 os.Stdout = w 1928 1929 beginChild(&v.clientConf) 1930 1931 // restore stdout 1932 w.Close() 1933 output, _ := io.ReadAll(r) 1934 os.Stdout = saveStdout 1935 r.Close() 1936 1937 // validate the result. 1938 resp := strings.TrimSpace(string(output)) 1939 // fmt.Printf("output from beginChild= %q\n", resp) 1940 if !strings.HasPrefix(resp, v.resp) { 1941 t.Errorf("#test beginChild expect start with %q got %q\n", v.resp, resp) 1942 } 1943 srv.wait() 1944 }) 1945 } 1946 } 1947 1948 func TestMainBeginChild(t *testing.T) { 1949 tc := []struct { 1950 label string 1951 resp string // response for beginChild(). 
1952 shutdown int // pause before shutdown message 1953 args []string 1954 conf Config 1955 }{ 1956 { 1957 "main begin child", frontend.AprilshMsgOpen + "7151,", 150, 1958 []string{"/usr/bin/apshd", "-b", "-destination", getCurrentUser() + "@localhost", 1959 "-p", "7150", "-t", "xterm-256color", "-vv"}, 1960 Config{ 1961 desiredIP: "", desiredPort: "7150", // autoStop: 1, 1962 locales: localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"}, 1963 commandPath: "/bin/sh", commandArgv: []string{"/bin/sh"}, withMotd: false, 1964 // addSource: false, verbose: util.TraceLevel, 1965 }, 1966 }, 1967 } 1968 1969 for _, v := range tc { 1970 t.Run(v.label, func(t *testing.T) { 1971 r, w, _ := os.Pipe() 1972 // save stdout 1973 oldStdout := os.Stdout 1974 1975 // util.Logger.CreateLogger(os.Stderr, true, slog.LevelDebug) 1976 util.Logger.CreateLogger(io.Discard, true, slog.LevelDebug) 1977 1978 srv := newMainSrv(&v.conf) 1979 srv.start(&v.conf) 1980 1981 // send shutdown message after some time 1982 timer1 := time.NewTimer(time.Duration(v.shutdown) * time.Millisecond) 1983 go func() { 1984 <-timer1.C 1985 // prepare to shudown the mainSrv 1986 srv.downChan <- true 1987 }() 1988 1989 testFunc := func() { 1990 os.Args = v.args 1991 os.Stdout = w 1992 main() 1993 1994 // restore stdout 1995 os.Stdout = oldStdout 1996 } 1997 1998 testFunc() 1999 srv.wait() 2000 2001 // close pipe writer, get the output 2002 w.Close() 2003 output, _ := io.ReadAll(r) 2004 r.Close() 2005 2006 // validate the result. 
2007 resp := string(output) 2008 if !strings.Contains(resp, v.resp) { 2009 t.Errorf("%q expect start with %q got \n%s\n", v.label, v.resp, resp) 2010 } 2011 }) 2012 } 2013 } 2014 2015 // https://coralogix.com/blog/optimizing-a-golang-service-to-reduce-over-40-cpu/ 2016 func TestRunChild(t *testing.T) { 2017 portStr := "7200" 2018 port, _ := strconv.Atoi(portStr) 2019 serverPortStr := "7100" 2020 2021 tc := []struct { 2022 label string 2023 shutdown int // pause before shutdown message 2024 conf Config // config for mainSrv 2025 childConf Config // config for child 2026 }{ 2027 { 2028 "early shutdown", 100, 2029 Config{ 2030 desiredIP: "", desiredPort: serverPortStr, 2031 locales: localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"}, 2032 commandPath: "/bin/sh", commandArgv: []string{"/bin/sh"}, withMotd: false, 2033 addSource: true, verbose: util.DebugLevel, 2034 }, 2035 Config{desiredPort: portStr, term: "xterm", destination: getCurrentUser() + "@localhost", 2036 serve: serve, verbose: 0, addSource: false}, 2037 }, 2038 { 2039 "skip pipe lock", 100, 2040 Config{ 2041 desiredIP: "", desiredPort: serverPortStr, 2042 locales: localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"}, 2043 commandPath: "/bin/sh", commandArgv: []string{"/bin/sh"}, withMotd: false, 2044 addSource: true, verbose: util.DebugLevel, 2045 }, 2046 Config{desiredPort: portStr, destination: getCurrentUser() + "@localhost", 2047 commandPath: "/bin/sh", commandArgv: []string{"/bin/sh"}, withMotd: true, 2048 flowControl: _FC_SKIP_PIPE_LOCK, serve: serve, verbose: 0, addSource: false}, 2049 }, 2050 { 2051 "skip start shell", 100, 2052 Config{ 2053 desiredIP: "", desiredPort: serverPortStr, 2054 locales: localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"}, 2055 commandPath: "/bin/sh", commandArgv: []string{"/bin/sh"}, withMotd: false, 2056 addSource: true, verbose: util.DebugLevel, 2057 }, 2058 Config{desiredPort: portStr, destination: getCurrentUser() + "@localhost", 2059 commandPath: 
"/bin/sh", commandArgv: []string{"/bin/sh"}, withMotd: false, 2060 flowControl: _FC_SKIP_START_SHELL, serve: serve, verbose: 0, addSource: false}, 2061 }, 2062 { 2063 "open pts failed", 100, 2064 Config{ 2065 desiredIP: "", desiredPort: serverPortStr, 2066 locales: localeFlag{"LC_ALL": "en_US.UTF-8", "LANG": "en_US.UTF-8"}, 2067 commandPath: "/bin/sh", commandArgv: []string{"/bin/sh"}, withMotd: false, 2068 addSource: true, verbose: util.DebugLevel, 2069 }, 2070 Config{desiredPort: portStr, term: "xterm", destination: getCurrentUser() + "@localhost", 2071 commandPath: "/bin/sh", commandArgv: []string{"/bin/sh"}, withMotd: false, 2072 flowControl: _FC_OPEN_PTS_FAIL, serve: serve, verbose: 0, addSource: false}, 2073 }, 2074 } 2075 2076 for _, v := range tc { 2077 t.Run(v.label, func(t *testing.T) { 2078 util.Logger.CreateLogger(io.Discard, true, slog.LevelDebug) 2079 // util.Logger.CreateLogger(os.Stderr, true, slog.LevelDebug) 2080 2081 srv := newMainSrv(&v.conf) 2082 2083 // listen UDS 2084 uxConn, err := srv.uxListen() 2085 if err != nil { 2086 util.Logger.Warn("listen unix domain socket failed", "error", err) 2087 return 2088 } 2089 2090 // receive UDS feed 2091 srv.wg.Add(1) 2092 go func() { 2093 srv.uxServe(uxConn, 2, func(c chan string, resp string) { 2094 ret, err := srv.handleMessage(resp) 2095 if err != nil { 2096 util.Logger.Warn("fake uxServe failed", "error", err) 2097 return 2098 } 2099 2100 if ret != "" { 2101 util.Logger.Debug("fake uxServe got key", "key", ret) 2102 return 2103 } 2104 2105 // stop uxServe if the worker is done 2106 if resp == _RunHeader+":"+portStr+",shutdown" { 2107 srv.uxdownChan <- true 2108 } 2109 2110 // stop shell process once we got shell pid 2111 if strings.HasPrefix(resp, _ShellHeader+":"+portStr) { 2112 if srv.workers[port].shellPid > 0 { 2113 util.Logger.Debug("fake uxServe kill the shell", "shellPid", srv.workers[port].shellPid) 2114 shell, err := os.FindProcess(srv.workers[port].shellPid) 2115 if err = shell.Kill(); err 
!= nil {
								util.Logger.Debug("fake uxServe", "error", err)
							}
						}
					}
				})
				srv.wg.Done()
			}()

			// start runChild
			srv.wg.Add(1)
			go func() {
				// add this worker
				srv.workers[port] = &workhorse{}
				runChild(&v.childConf)
				srv.wg.Done()
			}()

			if strings.Contains(v.label, "shutdown") {
				// send shutdown message after some time
				timer1 := time.NewTimer(time.Duration(v.shutdown) * time.Millisecond)
				go func() {
					<-timer1.C
					// prepare to shudown the mainSrv
					syscall.Kill(syscall.Getpid(), syscall.SIGTERM)
					srv.uxdownChan <- true
				}()
			}

			// validate if we can quit this test
			srv.wait()
		})
	}
}

// TestRunFail checks that mainSrv.run returns instead of blocking when the
// server has no UDP connection (m.conn is nil).
func TestRunFail(t *testing.T) {
	m := mainSrv{}
	cfg := &Config{}
	m.run(cfg)
	// reaching this point means run returned because m.conn is nil
}

// TestUxListenFail checks that uxListen reports an error when unixsockAddr
// points at an existing regular file (/etc/hosts) instead of a usable socket path.
func TestUxListenFail(t *testing.T) {
	old := unixsockAddr
	defer func() {
		unixsockAddr = old
	}()

	unixsockAddr = "/etc/hosts"
	m := mainSrv{}
	if _, err := m.uxListen(); err == nil {
		t.Errorf("uxListen expect error got nil\n")
	}
}

// TestRunChildFail checks that runChild returns an error when unixsockAddr
// points at an existing regular file (/etc/hosts) instead of a usable socket path.
func TestRunChildFail(t *testing.T) {
	old := unixsockAddr
	defer func() {
		unixsockAddr = old
	}()

	unixsockAddr = "/etc/hosts"
	if err := runChild(&Config{}); err == nil {
		t.Errorf("runChild expect error got nil\n")
	}
}

// TestMainRunChildFail runs main() with the "-c" flag (presumably child mode —
// confirm against the flag definitions) while unixsockAddr is unusable, and
// checks that the failure is reported on stderr.
func TestMainRunChildFail(t *testing.T) {
	oldAddr := unixsockAddr
	oldArgs := os.Args
	defer func() {
		// main() consumes os.Args, so put the real ones back for later tests
		unixsockAddr = oldAddr
		os.Args = oldArgs
	}()

	r, w, err := os.Pipe()
	if err != nil {
		t.Fatalf("create pipe failed: %s", err)
	}
	defer r.Close()

	// error condition
	unixsockAddr = "/etc/hosts"

	// run main with stderr redirected into the pipe, restoring stderr as soon
	// as main returns
	func() {
		oldStderr := os.Stderr
		defer func() { os.Stderr = oldStderr }()

		os.Args = []string{"/usr/bin/apshd", "-c", "-p", "6160", "-vv"}
		os.Stderr = w
		main()
	}()

	// close pipe writer so ReadAll sees EOF, then collect the output
	w.Close()
	output, _ := io.ReadAll(r)

	// validate the result
	got := string(output)
	expect := "init uds client failed"
	if !strings.Contains(got, expect) {
		t.Errorf("runChild expect %q got %q\n", expect, got)
	}
}

// TestStartFail2 checks that mainSrv.start logs a unix domain socket listen
// failure when unixsockAddr points at an existing regular file (/etc/hosts).
func TestStartFail2(t *testing.T) {
	// intercept log
	var w strings.Builder
	util.Logger.CreateLogger(&w, true, slog.LevelDebug)

	cfg := &Config{desiredPort: "7230"}
	m := mainSrv{}

	// this will cause uxListen failed
	old := unixsockAddr
	defer func() {
		unixsockAddr = old
	}()

	// change unixsocke to error file
	unixsockAddr = "/etc/hosts"
	m.start(cfg)

	// close udp connection; start may have failed before opening it
	if m.conn != nil {
		m.conn.Close()
	}

	// check the log
	got := w.String()
	expect := "listen unix domain socket failed"
	if !strings.Contains(got, expect) {
		t.Errorf("mainSrv.start() expect %q, got \n%s\n", expect, got)
	}
}

// TestStartChildFail drives mainSrv.startChild with a malformed destination and
// with requests whose shell start is expected to fail (SHELL is cleared for the
// call), then checks the debug log for the expected message.
func TestStartChildFail(t *testing.T) {
	tc := []struct {
		label  string
		req    string
		conf   Config
		expect string
	}{
		{"destination without @", "a:b,cd",
			Config{desiredPort: "6510"}, "open aprilsh:malform destination"},
		{"startShellProcess failed: DebugLevel", "open aprilsh:xterm-fake," + getCurrentUser() + "@fakehost",
			Config{desiredPort: "6511", verbose: util.DebugLevel},
			"start child got key timeout"},
		{"startShellProcess failed: TraceLevel", "open aprilsh:xterm-fake," + getCurrentUser() + "@fakehost",
			Config{desiredPort: "6512", verbose: util.TraceLevel},
			"start child got key timeout"},
		{"startShellProcess failed: addSource", "open aprilsh:xterm-fake," + getCurrentUser() + "@fakehost",
			Config{desiredPort: "6513", addSource: true},
			"start child got key timeout"},
	}

	for _, v := range tc {
		t.Run(v.label, func(t *testing.T) {
			// prepare the server
			m := newMainSrv(&v.conf)
			m.timeout = 10
			m.listen(&v.conf)

			var wg sync.WaitGroup

			var out strings.Builder
			// util.Logger.CreateLogger(os.Stderr, true, slog.LevelDebug)
			util.Logger.CreateLogger(&out, true, slog.LevelDebug)

			// drain UDP responses until a value arrives on m.downChan
			wg.Add(1)
			go func() {
				defer wg.Done()

				buf := make([]byte, 128)
				for {
					select {
					case <-m.downChan:
						util.Logger.Debug("fake receiver shudown")
						return
					default:
					}

					m.conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(m.timeout)))
					m.conn.ReadFromUDP(buf)
				}
			}()

			// run startChild with an empty SHELL
			addr, err := net.ResolveUDPAddr("udp", "localhost:"+v.conf.desiredPort)
			if err != nil {
				t.Errorf("resolve udp address failed: %s\n", err)
			} else {
				old := os.Getenv("SHELL")
				os.Setenv("SHELL", "")
				m.startChild(v.req, addr, v.conf)
				os.Setenv("SHELL", old)
			}

			// shutdown reader
			m.downChan <- true
			wg.Wait()
			m.conn.Close()

			// validate the result
			got := out.String()
			if !strings.Contains(got, v.expect) {
				t.Errorf("startChild expect %q, got \n%s\n", v.expect, got)
			}
		})
	}
}

// TestBuildConfig2 checks that buildConfig fails when the non-UTF-8 locale path
// is forced via _FC_NON_UTF8_LOCALE, and that it prints the UTF-8 diagnostic to
// stdout or stderr.
func TestBuildConfig2(t *testing.T) {
	cfg := &Config{flowControl: _FC_NON_UTF8_LOCALE}

	r, w, err := os.Pipe()
	if err != nil {
		t.Fatalf("create pipe failed: %s", err)
	}
	defer r.Close()

	// capture both stdout and stderr while buildConfig runs
	oldErr, oldOut := os.Stderr, os.Stdout
	os.Stderr = w
	os.Stdout = w

	_, ok := cfg.buildConfig()

	os.Stderr = oldErr
	os.Stdout = oldOut

	// close pipe writer so ReadAll sees EOF, then collect the output
	w.Close()
	output, _ := io.ReadAll(r)

	// validate the result: buildConfig must fail and explain why
	got := string(output)
	expect := "needs a UTF-8 native locale to run"
	if ok || !strings.Contains(got, expect) {
		t.Errorf("buildConfig expect %q got \n%s\n", expect, got)
	}
}

// TestMessageError table-tests the string produced by messageError.Error().
func TestMessageError(t *testing.T) {
	tc := []struct {
		label  string
		e      *messageError
		expect string
	}{
		{"nil error", &messageError{}, "<nil>"},
		{"reason + error", &messageError{reason: "got apple", err: errors.New("bad apple")}, "got apple: bad apple"},
		{"only error", &messageError{err: errors.New("just apple")}, ": just apple"},
	}

	for _, v := range tc {
		t.Run(v.label, func(t *testing.T) {
			got := v.e.Error()
			if got != v.expect {
				t.Errorf("messageError should return %q got %q\n", v.expect, got)
			}
		})
	}
}

// TestCloseChild checks closeChild's handling of a place-holder port, a
// malformed port number and an unknown port by inspecting the debug log.
func TestCloseChild(t *testing.T) {
	tc := []struct {
		label   string
		req     string
		holders []int
		conf    *Config
		expect  string
	}{
		{"placeHolder port", frontend.AprishMsgClose + "6252", []int{6252},
			&Config{desiredPort: "6250"}, "close port is a holder"},
		{"wrong port number", frontend.AprishMsgClose + "625a", nil,
			&Config{desiredPort: "6250"}, "wrong port number"},
		{"port doesn't exist", frontend.AprishMsgClose + "6252", nil,
			&Config{desiredPort: "6250"}, "port does not exist"},
	}

	for _, v := range tc {
		t.Run(v.label, func(t *testing.T) {
			// prepare the server
			m := newMainSrv(v.conf)
			m.listen(v.conf)

			var wg sync.WaitGroup

			var out strings.Builder
			// util.Logger.CreateLogger(os.Stderr, true, slog.LevelDebug)
			util.Logger.CreateLogger(&out, true, slog.LevelDebug)

			// create place holders data
			for _, value := range v.holders {
				m.workers[value] = &workhorse{}
			}

			// drain UDP responses until a value arrives on m.downChan
			wg.Add(1)
			go func() {
				defer wg.Done()

				buf := make([]byte, 128)
				for {
					select {
					case <-m.downChan:
						util.Logger.Debug("fake receiver shudown")
						return
					default:
					}

					m.conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(m.timeout)))
					m.conn.ReadFromUDP(buf)
				}
			}()

			// run closeChild
			addr, err := net.ResolveUDPAddr("udp", "localhost:"+v.conf.desiredPort)
			if err != nil {
				t.Errorf("get address fail: %s\n", err)
			} else {
				m.closeChild(v.req, addr)
			}

			// shutdown reader
			m.downChan <- true
			wg.Wait()
			m.conn.Close()

			// validate the result
			got := out.String()
			if !strings.Contains(got, v.expect) {
				t.Errorf("closeChild expect %q, got \n%s\n", v.expect, got)
			}
		})
	}
}