github.com/ericwq/aprilsh@v0.0.0-20240517091432-958bc568daa0/frontend/client/client_test.go (about) 1 // Copyright 2022~2024 wangqi. All rights reserved. 2 // Use of this source code is governed by a MIT-style 3 // license that can be found in the LICENSE file. 4 5 package main 6 7 import ( 8 "errors" 9 "io" 10 "os" 11 "strings" 12 "sync" 13 "testing" 14 "time" 15 16 "github.com/creack/pty" 17 "github.com/ericwq/aprilsh/frontend" 18 ) 19 20 func TestPrintColors(t *testing.T) { 21 tc := []struct { 22 label string 23 term string 24 expect []string 25 }{ 26 {"lookup terminfo failed", "NotExist", []string{"Dynamic load terminfo failed."}}, 27 {"TERM is empty", "", []string{"The TERM is empty string."}}, 28 {"TERM doesn't exit", "-remove", []string{"The TERM doesn't exist."}}, 29 {"normal found", "xterm-256color", []string{"xterm-256color", "256"}}, 30 // {"dynamic found", "xfce", []string{"xfce 8 (dynamic)"}}, 31 {"dynamic not found", "xxx", []string{"Dynamic load terminfo failed."}}, 32 } 33 34 for _, v := range tc { 35 t.Run(v.label, func(t *testing.T) { 36 // intercept stdout 37 saveStdout := os.Stdout 38 r, w, _ := os.Pipe() 39 os.Stdout = w 40 // save original TERM 41 term := os.Getenv("TERM") 42 43 // set TERM according to test case 44 if v.term == "-remove" { 45 os.Unsetenv("TERM") 46 } else { 47 os.Setenv("TERM", v.term) 48 } 49 50 printColors() 51 52 // restore stdout 53 w.Close() 54 b, _ := io.ReadAll(r) 55 os.Stdout = saveStdout 56 r.Close() 57 58 // validate the result 59 result := string(b) 60 found := 0 61 for i := range v.expect { 62 if strings.Contains(result, v.expect[i]) { 63 found++ 64 } 65 } 66 if found != len(v.expect) { 67 t.Errorf("#test %s expect %q, got %q\n", v.label, v.expect, result) 68 } 69 70 // restore original TERM 71 os.Setenv("TERM", term) 72 }) 73 } 74 } 75 76 func TestMainRun_Parameters(t *testing.T) { 77 tc := []struct { 78 label string 79 args []string 80 term string 81 expect []string 82 }{ 83 { 84 "no parameters", 85 []string{frontend.CommandClientName}, 86 "xterm-256color", 87 []string{"destination (user@host[:port]) is mandatory."}, 88 }, 89 { 90 "just version", 91 []string{frontend.CommandClientName, "-version"}, 92 "xterm-256color", 93 []string{ 94 frontend.CommandClientName, frontend.AprilshPackageName, 95 "Copyright (c) 2022~2024 wangqi <ericwq057@qq.com>", "remote shell support intermittent or mobile network.", 96 }, 97 }, 98 { 99 "just help", 100 []string{frontend.CommandClientName, "-h"}, 101 "xterm-256color", 102 []string{ 103 "Usage:", frontend.CommandClientName, "Options:", "-c", "--colors", 104 "print the number of terminal color", 105 }, 106 }, 107 { 108 "just colors", 109 []string{frontend.CommandClientName, "-c", "-v"}, 110 "xterm-256color", 111 []string{"xterm-256color", "256"}, 112 }, 113 { 114 "invalid target parameter", 115 []string{frontend.CommandClientName, "invalid", "target", "parameter"}, 116 "xterm-256color", 117 []string{"only one destination (user@host[:port]) is allowed."}, 118 }, 119 { 120 "destination no second part", 121 []string{frontend.CommandClientName, "malform@"}, 122 "xterm-256color", 123 []string{"destination should be in the form of user@host[:port]"}, 124 }, 125 { 126 "destination no first part", 127 []string{frontend.CommandClientName, "@malform"}, 128 "xterm-256color", 129 []string{"destination should be in the form of user@host[:port]"}, 130 }, 131 { 132 "infvalid port number", 133 []string{frontend.CommandClientName, "-p", "7s"}, 134 "xterm-256color", 135 []string{"invalid value \"7s\" for flag -p: parse error"}, 136 }, 137 } 138 139 for _, v := range tc { 140 t.Run(v.label, func(t *testing.T) { 141 // intercept stdout 142 saveStdout := os.Stdout 143 r, w, _ := os.Pipe() 144 os.Stdout = w 145 146 // prepare data 147 os.Args = v.args 148 os.Setenv("TERM", v.term) 149 // test main 150 main() 151 152 // restore stdout 153 w.Close() 154 out, _ := io.ReadAll(r) 155 os.Stdout = saveStdout 156 r.Close() 157 158 // validate the result 159 result := string(out) 160 found := 0 161 for i := range v.expect { 162 if strings.Contains(result, v.expect[i]) { 163 // fmt.Printf("found %s\n", expect[i]) 164 found++ 165 } 166 } 167 if found != len(v.expect) { 168 t.Errorf("#test expect %s, got \n%s\n", v.expect, result) 169 } 170 }) 171 } 172 } 173 174 func TestBuildConfig(t *testing.T) { 175 targetMsg := "destination should be in the form of user@host[:port]" 176 modeMsg := _PREDICTION_DISPLAY + " unknown prediction mode." 177 tc := []struct { 178 label string 179 target string 180 predictMode string 181 expect string 182 ok bool 183 }{ 184 {"valid target, empty mode", "usr@localhost", "", "", true}, 185 {"valid target, lack of mode", "gig@factory", "mode", modeMsg, false}, 186 {"valid target, valid mode", "vfab@factory", "aLwaYs", "", true}, 187 {"invalid target", "factory", "", targetMsg, false}, 188 {"invalid @target", "@factory", "", targetMsg, false}, 189 {"invalid target@", "factory@", "", targetMsg, false}, 190 } 191 192 for _, v := range tc { 193 t.Run(v.label, func(t *testing.T) { 194 var conf Config 195 conf.destination = []string{v.target} 196 197 // prepare parse result 198 var host string 199 var user string 200 idx := strings.Index(v.target, "@") 201 if idx > 0 && idx < len(v.target)-1 { 202 host = v.target[idx+1:] 203 user = v.target[:idx] 204 } 205 206 os.Setenv(_PREDICTION_DISPLAY, v.predictMode) 207 208 got, ok := conf.buildConfig() 209 if got != v.expect { 210 t.Errorf("#test buildConfig() %s expect %q, got %s\n", v.label, v.expect, got) 211 } 212 if conf.user != user || conf.host != host { 213 t.Errorf("#test buildConfig() %q config.user expect %s, got %s\n", v.label, user, conf.user) 214 t.Errorf("#test buildConfig() %q config.host expect %s, got %s\n", v.label, host, conf.host) 215 } 216 if conf.predictMode != strings.ToLower(v.predictMode) { 217 t.Errorf("#test buildConfig() conf.predictMode expect %q, got %q\n", v.predictMode, conf.predictMode) 218 } 219 if ok != v.ok { 220 t.Errorf("#test buildConfig() expect %t, got %t\n", v.ok, ok) 221 } 222 }) 223 } 224 } 225 226 func TestBuildConfig2(t *testing.T) { 227 tc := []struct { 228 label string 229 conf *Config 230 expectStr string 231 ok bool 232 }{ 233 {"destination without port", &Config{destination: []string{"usr@host"}}, "", true}, 234 {"destination with port", &Config{destination: []string{"usr@host:23"}}, "", true}, 235 {"destination with wrong port", 236 &Config{destination: []string{"usr@host:a23"}}, "please check destination, illegal port number.", false}, 237 } 238 for _, v := range tc { 239 t.Run(v.label, func(t *testing.T) { 240 got, ok := v.conf.buildConfig() 241 if ok != v.ok || got != v.expectStr { 242 t.Errorf("%q expect (%s,%t) got (%s,%t)\n", v.label, v.expectStr, v.ok, got, ok) 243 } 244 }) 245 } 246 } 247 248 // func TestFetchKey(t *testing.T) { 249 // tc := []struct { 250 // label string 251 // conf *Config 252 // pwd string 253 // msg string 254 // }{ 255 // {"wrong host", &Config{user: "ide", host: "wrong", port: 60000}, "password", "dial tcp"}, 256 // } 257 // for _, v := range tc { 258 // t.Run(v.label, func(t *testing.T) { 259 // v.conf.pwd = v.pwd 260 // got := v.conf.fetchKey() 261 // if !strings.Contains(got.Error(), v.msg) { 262 // t.Errorf("#test %q expect %q contains %q.\n", v.label, got, v.msg) 263 // } 264 // }) 265 // } 266 // } 267 268 func TestGetPassword(t *testing.T) { 269 270 tc := []struct { 271 label string 272 conf *Config 273 pwd string //input 274 expect string 275 }{ 276 {"normal get password", &Config{}, "password\n", "password"}, 277 {"just CR", &Config{}, "\n", ""}, 278 } 279 for _, v := range tc { 280 t.Run(v.label, func(t *testing.T) { 281 // intercept stdout 282 saveStdout := os.Stdout 283 r, w, _ := os.Pipe() 284 os.Stdout = w 285 286 // get password require pts file. 287 ptmx, pts, err := pty.Open() 288 if err != nil { 289 err = errors.New("invalid parameter") 290 } 291 292 // prepare input data 293 ptmx.WriteString(v.pwd) 294 295 got, err := getPassword("password", pts) 296 297 ptmx.Close() 298 pts.Close() 299 300 // restore stdout 301 w.Close() 302 out, _ := io.ReadAll(r) 303 os.Stdout = saveStdout 304 r.Close() 305 306 // validate the result. 307 if err != nil { 308 t.Errorf("#test %q report %s\n", v.label, err) 309 } 310 if got != v.expect { 311 t.Errorf("#test %q expect %q, got %q. out=%s\n", v.label, v.expect, got, out) 312 } 313 314 }) 315 } 316 } 317 318 func TestGetPasswordFail(t *testing.T) { 319 // conf := &Config{} 320 321 // intercept stdout 322 saveStdout := os.Stdout 323 r, w, _ := os.Pipe() 324 os.Stdout = w 325 326 got, err := getPassword("password", r) 327 328 // restore stdout 329 w.Close() 330 out, _ := io.ReadAll(r) 331 os.Stdout = saveStdout 332 r.Close() 333 334 // validate, for non-tty input, getPassword return err: inappropriate ioctl for device 335 if err == nil { 336 t.Errorf("#test getPassword fail expt %q, got=%q, err=%s, out=%s\n", "", got, err, out) 337 } 338 } 339 340 func TestGetPasswordFail2(t *testing.T) { 341 // store stdout/in, open pts pair 342 ptmx, pts, err := pty.Open() 343 if err != nil { 344 t.Errorf("failed to open pts, %s\n", err) 345 return 346 } 347 saveStdout := os.Stdout 348 saveStdin := os.Stdin 349 os.Stdout = pts 350 os.Stdin = pts 351 352 expect := "hello world" 353 354 // provide the input 355 var wg sync.WaitGroup 356 wg.Add(1) 357 go func() { 358 defer wg.Done() 359 // make sure we provide input after the getPassword() 360 timer := time.NewTimer(time.Duration(2) * time.Millisecond) 361 <-timer.C 362 ptmx.WriteString(expect + "\n") // \n is important for getPassword() 363 }() 364 365 // waiting for the input 366 wg.Add(1) 367 var got string 368 var err2 error 369 go func() { 370 defer wg.Done() 371 got, err2 = getPassword("password", pts) 372 }() 373 wg.Wait() 374 375 // close pts paire and restore stdou/stdin 376 ptmx.Close() 377 pts.Close() 378 os.Stdout = saveStdout 379 os.Stdin = saveStdin 380 381 // validate, for non-tty input, getPassword return err: inappropriate ioctl for device 382 if err2 != nil || got != expect { 383 t.Errorf("#test getPassword fail expt %q, got=%q, err=%s\n", expect, got, err) 384 } 385 } 386 387 func TestSshAgentFail(t *testing.T) { 388 tc := []struct { 389 label string 390 env bool 391 expect string 392 }{ 393 {"lack of SSH_AUTH_SOCK", false, "Failed to connect ssh agent."}, 394 } 395 for _, v := range tc { 396 t.Run(v.label, func(t *testing.T) { 397 old := os.Getenv("SSH_AUTH_SOCK") 398 defer os.Setenv("SSH_AUTH_SOCK", old) 399 400 // intercept stdout 401 saveStdout := os.Stdout 402 r, w, _ := os.Pipe() 403 os.Stdout = w 404 405 // clear SSH_AUTH_SOCK 406 if !v.env { 407 os.Unsetenv("SSH_AUTH_SOCK") 408 } 409 // run the test 410 sshAgent() 411 412 // restore stdout 413 w.Close() 414 out, _ := io.ReadAll(r) 415 os.Stdout = saveStdout 416 r.Close() 417 418 got := string(out) 419 if !strings.HasPrefix(got, v.expect) { 420 t.Errorf("%q expect %q got %q\n", v.label, v.expect, got) 421 } 422 }) 423 } 424 } 425 426 func TestErrors(t *testing.T) { 427 tc := []struct { 428 label string 429 error error 430 expect string 431 }{ 432 {"hostkeyChangeError", &hostkeyChangeError{hostname: "some.where"}, 433 "REMOTE HOST IDENTIFICATION HAS CHANGED"}, 434 {"responseErr without error", &responseError{}, "<nil>"}, 435 {"responseErr error", &responseError{Msg: "hello", Err: errors.New("world")}, "hello, world"}, 436 } 437 for _, v := range tc { 438 t.Run(v.label, func(t *testing.T) { 439 440 got := v.error.Error() 441 if !strings.Contains(got, v.expect) { 442 t.Errorf("%q expect %q got %q\n", v.label, v.expect, got) 443 } 444 445 }) 446 } 447 } 448 449 func TestPublicKeyFileFail(t *testing.T) { 450 tc := []struct { 451 label string 452 file string 453 expect string 454 }{ 455 {"file doesn't exist", "/do/es/not/exist", "Unable to read private key"}, 456 {"is not private key", "/etc/hosts", "Unable to parse private key"}, 457 } 458 for _, v := range tc { 459 t.Run(v.label, func(t *testing.T) { 460 461 // intercept stdout 462 saveStdout := os.Stdout 463 r, w, _ := os.Pipe() 464 os.Stdout = w 465 466 // run the test 467 publicKeyFile(v.file) 468 469 // restore stdout 470 w.Close() 471 out, _ := io.ReadAll(r) 472 os.Stdout = saveStdout 473 r.Close() 474 475 // validate the output 476 got := string(out) 477 if !strings.Contains(got, v.expect) { 478 t.Errorf("%q expect %q got %q\n", v.label, v.expect, got) 479 } 480 }) 481 } 482 }