github.com/ericwq/aprilsh@v0.0.0-20240517091432-958bc568daa0/frontend/client/client.go (about) 1 // Copyright 2022~2024 wangqi. All rights reserved. 2 // Use of this source code is governed by a MIT-style 3 // license that can be found in the LICENSE file. 4 5 package main 6 7 import ( 8 "bytes" 9 "errors" 10 "flag" 11 "fmt" 12 "io" 13 "net" 14 "os" 15 "os/signal" 16 "path/filepath" 17 "strconv" 18 "strings" 19 "syscall" 20 "time" 21 22 "log/slog" 23 24 "github.com/ericwq/aprilsh/encrypt" 25 "github.com/ericwq/aprilsh/frontend" 26 "github.com/ericwq/aprilsh/network" 27 "github.com/ericwq/aprilsh/statesync" 28 "github.com/ericwq/aprilsh/terminal" 29 "github.com/ericwq/aprilsh/util" 30 "github.com/rivo/uniseg" 31 "github.com/skeema/knownhosts" 32 "golang.org/x/crypto/ssh" 33 "golang.org/x/crypto/ssh/agent" 34 xknownhosts "golang.org/x/crypto/ssh/knownhosts" 35 "golang.org/x/sync/errgroup" 36 "golang.org/x/sys/unix" 37 "golang.org/x/term" 38 ) 39 40 const ( 41 _APRILSH_KEY = "APRISH_KEY" 42 _PREDICTION_DISPLAY = "APRISH_PREDICTION_DISPLAY" 43 _PREDICTION_OVERWRITE = "APRISH_PREDICTION_OVERWRITE" 44 ) 45 46 var ( 47 usage = `Usage: 48 ` + frontend.CommandClientName + ` [--version] [--help] [--colors] 49 ` + frontend.CommandClientName + ` [-v[v]] [--port PORT] [-i identity_file] destination 50 Options: 51 --------------------------------------------------------------------------------------------------- 52 -h, --help print this message 53 -c, --colors print the number of terminal color 54 --version print version information 55 --------------------------------------------------------------------------------------------------- 56 -p, --port apshd server port (default 8100) 57 destination in the form of user@host[:port], here the port is ssh server port (default 22) 58 -i ssh client identity (private key) (default $HOME/.ssh/id_rsa) 59 -v, --verbose verbose log output (debug level, default info level) 60 -vv verbose log output (trace level) 61 -m, --mapping container port mapping (default 0, new port = returned port + mapping) 62 --------------------------------------------------------------------------------------------------- 63 ` 64 predictionValues = []string{"always", "never", "adaptive", "experimental"} 65 defaultSSHClientID = filepath.Join(os.Getenv("HOME"), ".ssh", "id_rsa") 66 signals frontend.Signals 67 ) 68 69 func printVersion() { 70 fmt.Printf("%s package : %s client, %s\n", 71 frontend.AprilshPackageName, frontend.AprilshPackageName, frontend.CommandClientName) 72 frontend.PrintVersion() 73 } 74 75 func printColors() { 76 value, ok := os.LookupEnv("TERM") 77 if ok { 78 if value != "" { 79 // ti, err := terminfo.LookupTerminfo(value) 80 ti, err := terminal.LookupTerminfo(value) 81 if err == nil { 82 fmt.Printf("%s %d\n", value, ti.Colors) 83 } else { 84 fmt.Printf("Dynamic load terminfo failed. %s Install infocmp (ncurses package) first.\n", err) 85 } 86 } else { 87 fmt.Println("The TERM is empty string.") 88 } 89 } else { 90 fmt.Println("The TERM doesn't exist.") 91 } 92 } 93 94 func parseFlags(progname string, args []string) (config *Config, output string, err error) { 95 // https://eli.thegreenplace.net/2020/testing-flag-parsing-in-go-programs/ 96 flagSet := flag.NewFlagSet(progname, flag.ContinueOnError) 97 var buf bytes.Buffer 98 flagSet.SetOutput(&buf) 99 100 var conf Config 101 102 var v1, v2 bool 103 flagSet.BoolVar(&v1, "v", false, "verbose log output debug level") 104 flagSet.BoolVar(&v1, "verbose", false, "verbose log output debug levle") 105 flagSet.BoolVar(&v2, "vv", false, "verbose log output trace level") 106 107 flagSet.BoolVar(&conf.version, "version", false, "print version information") 108 109 flagSet.BoolVar(&conf.addSource, "source", false, "add source info to log") 110 111 flagSet.IntVar(&conf.port, "port", frontend.DefaultPort, frontend.CommandServerName+" server port") 112 flagSet.IntVar(&conf.port, "p", frontend.DefaultPort, frontend.CommandServerName+" server port") 113 114 flagSet.BoolVar(&conf.colors, "color", false, "terminal colors number") 115 flagSet.BoolVar(&conf.colors, "c", false, "terminal colors number") 116 117 flagSet.StringVar(&conf.sshClientID, "i", defaultSSHClientID, "ssh client identity file") 118 flagSet.IntVar(&conf.mapping, "mapping", 0, "container port mapping") 119 flagSet.IntVar(&conf.mapping, "m", 0, "container port mapping") 120 121 err = flagSet.Parse(args) 122 if err != nil { 123 return nil, buf.String(), err 124 } 125 126 // get the non-flag command-line arguments. 127 conf.destination = flagSet.Args() 128 129 // detremine verbose level 130 if v1 { 131 conf.verbose = util.DebugLevel 132 } else if v2 { 133 conf.verbose = util.TraceLevel 134 } 135 return &conf, buf.String(), nil 136 } 137 138 type Config struct { 139 version bool 140 destination []string // raw parameter 141 host string // target host/server 142 user string // target user 143 port int // first server port, then target port 144 verbose int 145 colors bool 146 key string 147 predictMode string 148 predictOverwrite string 149 sshClientID string // ssh client identity, for SSH public key authentication 150 sshPort string // ssh port, default 22 151 addSource bool // add source file to log 152 mapping int // container(such as docker) port mapping value 153 } 154 155 var errNoResponse = errors.New("no response, please make sure the server is running.") 156 157 type hostkeyChangeError struct { 158 hostname string 159 } 160 161 func (e *hostkeyChangeError) Error() string { 162 return "REMOTE HOST IDENTIFICATION HAS CHANGED for host '" + 163 e.hostname + "' ! This may indicate a MITM attack." 164 } 165 166 // func (e *hostkeyChangeError) Hostname() string { return e.hostname } 167 168 type responseError struct { 169 Err error 170 Msg string 171 } 172 173 func (e *responseError) Error() string { 174 if e.Err == nil { 175 return "<nil>" 176 } 177 return e.Msg + ", " + e.Err.Error() 178 } 179 180 // utilize ssh to fetch the key from remote server and start a server. 181 // return empty string if success, otherwise return error info. 182 // 183 // For alpine, ssh is provided by openssh package, nc and echo is provided by busybox. 184 // % ssh ide@localhost "echo 'open aprilsh:' | nc localhost 6000 -u -w 1" 185 // 186 // ssh-keygen -t ed25519 187 // ssh-copy-id -i ~/.ssh/id_ed25519.pub root@localhost 188 // ssh-copy-id -i ~/.ssh/id_ed25519.pub ide@localhost 189 // ssh-add ~/.ssh/id_ed25519 190 func (c *Config) fetchKey() error { 191 var auth []ssh.AuthMethod 192 auth = make([]ssh.AuthMethod, 0) 193 194 if c.sshClientID != defaultSSHClientID { 195 if am := publicKeyFile(c.sshClientID); am != nil { 196 auth = append(auth, am) // public key first 197 fmt.Printf("public key first, %s, %s\n", am, c.sshClientID) 198 } 199 if am := sshAgent(); am != nil { 200 auth = append(auth, am) // ssh agent second 201 fmt.Printf("ssh agent second, %s\n", am) 202 } 203 } else { 204 if am := sshAgent(); am != nil { 205 auth = append(auth, am) // ssh agent first 206 fmt.Printf("ssh agent first, %s\n", am) 207 } 208 if am := publicKeyFile(c.sshClientID); am != nil { 209 auth = append(auth, am) // public key second 210 fmt.Printf("public key second, %s, %s\n", am, c.sshClientID) 211 } 212 } 213 214 if len(auth) == 0 { 215 // get password if we don't have any authenticate method 216 pwd, err := getPassword("password", os.Stdin) 217 if err != nil { 218 return err 219 } 220 221 // password authentication is the last resort 222 if am := ssh.Password(pwd); am != nil { 223 auth = append(auth, am) 224 fmt.Printf("password auth last, %s\n", am) 225 } 226 } 227 228 fmt.Printf("c.sshClientID=%s, defaultSSHClientID=%s, eq=%t\n", c.sshClientID, defaultSSHClientID, 229 c.sshClientID == defaultSSHClientID) 230 231 // prepare for knownhosts 232 sshHost := net.JoinHostPort(c.host, c.sshPort) 233 khPath := filepath.Join(os.Getenv("HOME"), ".ssh", "known_hosts") 234 if _, err := os.Stat(khPath); err != nil { 235 kh, err2 := os.Create(khPath) 236 if err2 != nil { 237 return err 238 } 239 kh.Close() 240 } 241 kh, err := knownhosts.New(khPath) 242 if err != nil { 243 return err 244 } 245 246 // https://github.com/skeema/knownhosts 247 // https://github.com/golang/go/issues/29286 248 // 249 // Create a custom permissive hostkey callback which still errors on hosts 250 // with changed keys, but allows unknown hosts and adds them to known_hosts 251 cb := ssh.HostKeyCallback(func(hostname string, remote net.Addr, key ssh.PublicKey) (err error) { 252 err = kh(hostname, remote, key) 253 if knownhosts.IsHostKeyChanged(err) { 254 return &hostkeyChangeError{hostname: hostname} 255 } else if knownhosts.IsHostUnknown(err) { 256 257 hint := "The authenticity of host '%s (%s)' can't be established.\n" + 258 "%s key fingerprint is %s.\n" + 259 "This key is not known by any other names\n" + 260 "Are you sure you want to continue connecting (yes/no/[fingerprint])?" 261 fmt.Printf(hint, hostname, remote, strings.ToUpper(key.Type()), ssh.FingerprintSHA256(key)) 262 263 var answer string 264 fmt.Scanln(&answer) 265 switch answer { 266 case "yes", "y": 267 f, ferr := os.OpenFile(khPath, os.O_APPEND|os.O_WRONLY, 0600) 268 if ferr == nil { 269 defer f.Close() 270 ferr = knownhosts.WriteKnownHost(f, hostname, remote, key) 271 } 272 if ferr == nil { 273 fmt.Printf("Warning: Permanently added '%s' (%s) to the list of known hosts.\n", 274 hostname, strings.ToUpper(key.Type())) 275 err = nil // permit previously-unknown hosts (warning: may be insecure) 276 } else { 277 fmt.Printf("Failed to add host %s to known_hosts: %v\n", hostname, ferr) 278 err = ferr 279 } 280 case "no", "n": 281 fallthrough 282 default: 283 fmt.Println("Host key verification failed.") 284 } 285 } 286 return 287 }) 288 289 // https://betterprogramming.pub/a-simple-cross-platform-ssh-client-in-100-lines-of-go-280644d8beea 290 // https://blog.ralch.com/articles/golang-ssh-connection/ 291 // https://www.ssh.com/blog/what-are-ssh-host-keys 292 clientConfig := &ssh.ClientConfig{ 293 User: c.user, 294 Auth: auth, 295 HostKeyCallback: cb, 296 HostKeyAlgorithms: kh.HostKeyAlgorithms(sshHost), 297 Timeout: time.Duration(3) * time.Second, 298 } 299 300 // TODO understand ssh login session, is that possible to replace the sshd depdends? 301 client, err := ssh.Dial("tcp", sshHost, clientConfig) 302 if err != nil { 303 return err 304 } 305 defer client.Close() 306 307 // Each ClientConn can support multiple interactive sessions, 308 // represented by a Session. 309 session, err := client.NewSession() 310 if err != nil { 311 return err 312 } 313 defer session.Close() 314 315 // Once a Session is created, you can execute a single command on 316 // the remote side using the Run method. 317 // before fetchKey() it's the server port, after it's target port 318 var b []byte 319 cmd := fmt.Sprintf("/usr/bin/apshd -b -t %s -destination %s -p %d", 320 os.Getenv("TERM"), c.destination[0], c.port) 321 // fmt.Printf("cmd=%s\n", cmd) 322 323 if b, err = session.Output(cmd); err != nil { 324 return err 325 } 326 out := strings.TrimSpace(string(b)) 327 328 // open aprilsh:60001,31kR3xgfmNxhDESXQ8VIQw== 329 body := strings.Split(out, ":") 330 if len(body) != 2 || body[0] != frontend.AprilshMsgOpen[:12] { // [:12]remove the last ':' 331 return errors.New(fmt.Sprintf("response: %s", out)) 332 } 333 334 // parse port and key 335 content := strings.Split(body[1], ",") 336 if len(content) == 2 && len(content[0]) > 0 && len(content[1]) > 0 { 337 p, e := strconv.Atoi(content[0]) 338 if e != nil { 339 return errors.New("can't get port") 340 } 341 // calculate new port based on container mapping value 342 // new port = returned port + mapping 343 // 8201 = 8101 + 100 344 c.port = p + c.mapping 345 346 if encrypt.NewBase64Key2(content[1]) != nil { 347 c.key = content[1] 348 } else { 349 return errors.New("can't get key") 350 } 351 // fmt.Printf("fetchKey port=%d, key=%s\n", c.port, c.key) 352 } else { 353 return errors.New(fmt.Sprintf("response: %s", body[1])) 354 } 355 356 return nil 357 } 358 359 func (c *Config) buildConfig() (string, bool) { 360 // just need version info 361 if c.version { 362 return "", true 363 } 364 365 // just need terminal number of colors 366 if c.colors { 367 return "", true 368 } 369 370 if len(c.destination) == 0 { 371 return "destination (user@host[:port]) is mandatory.", false 372 } 373 374 if len(c.destination) != 1 { 375 return "only one destination (user@host[:port]) is allowed.", false 376 } 377 378 // check destination 379 first := strings.Split(c.destination[0], "@") 380 if len(first) == 2 && len(first[0]) > 0 && len(first[1]) > 0 { 381 c.user = first[0] 382 second := strings.Split(first[1], ":") 383 c.host = second[0] 384 if len(second) == 1 { 385 c.sshPort = "22" // default ssh port 386 } else { 387 if _, err := strconv.Atoi(second[1]); err != nil { 388 return "please check destination, illegal port number.", false 389 } 390 c.sshPort = second[1] 391 } 392 } else { 393 return "destination should be in the form of user@host[:port]", false 394 } 395 396 // Read key from environment 397 // c.key = os.Getenv(_APRILSH_KEY) 398 // if c.key == "" { 399 // return _APRILSH_KEY + " environment variable not found.", false 400 // } 401 // os.Unsetenv(_APRILSH_KEY) 402 403 // Read prediction preference, predictMode can be empty 404 foundInScope := false 405 c.predictMode = strings.ToLower(os.Getenv(_PREDICTION_DISPLAY)) 406 if c.predictMode != "" { 407 // if predictMode is not empty string, it's must be one of predictionValues 408 for i := range predictionValues { 409 if predictionValues[i] == c.predictMode { 410 foundInScope = true 411 } 412 } 413 if !foundInScope { 414 return _PREDICTION_DISPLAY + " unknown prediction mode.", false 415 } 416 } 417 418 // Read prediction insertion preference. can be "" 419 c.predictOverwrite = strings.ToLower(os.Getenv(_PREDICTION_OVERWRITE)) 420 421 return "", true 422 } 423 424 // read password from specified input source 425 func getPassword(prompt string, in *os.File) (string, error) { 426 fmt.Printf("%s: ", prompt) 427 bytepw, err := term.ReadPassword(int(in.Fd())) 428 defer fmt.Printf("\n") 429 430 if err != nil { 431 return "", err 432 } 433 434 return string(bytepw), nil 435 } 436 437 func sshAgent() ssh.AuthMethod { 438 sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")) 439 if err != nil { 440 fmt.Printf("Failed to connect ssh agent. %s\n", err) 441 return nil 442 } 443 return ssh.PublicKeysCallback(agent.NewClient(sshAgent).Signers) 444 } 445 446 func publicKeyFile(file string) ssh.AuthMethod { 447 key, err := os.ReadFile(file) 448 if err != nil { 449 fmt.Printf("Unable to read private key: %s\n", err) 450 return nil 451 } 452 453 signer, err := ssh.ParsePrivateKey(key) 454 if err != nil { 455 if strings.Contains(err.Error(), "private key is passphrase protected") { 456 passphrase, err2 := getPassword("passphrase", os.Stdin) 457 if err2 != nil { 458 fmt.Printf("Failed to get passphrase. %s\n", err2) 459 return nil // read passphrase error 460 } 461 signer, err = ssh.ParsePrivateKeyWithPassphrase(key, []byte(passphrase)) 462 if err != nil { 463 fmt.Printf("Failed to parse private key. %s\n", err) 464 return nil 465 } 466 } else { 467 fmt.Printf("Unable to parse private key: %s\n", err) 468 return nil 469 } 470 } 471 return ssh.PublicKeys(signer) // Use the PublicKeys method for remote authentication. 472 } 473 474 type STMClient struct { 475 ip string 476 port int 477 key string 478 479 escapeKey int 480 escapePassKey int 481 escapePassKey2 int 482 escapeRequireslf bool 483 escapeKeyHelp string 484 485 savedTermios *term.State // store the original termios, used for shutdown. 486 rawTermios *term.State // set IUTF8 flag, set raw terminal in raw mode, used for resume. 487 windowSize *unix.Winsize 488 489 localFramebuffer *terminal.Emulator 490 newState *terminal.Emulator 491 overlays *frontend.OverlayManager 492 network *network.Transport[*statesync.UserStream, *statesync.Complete] 493 display *terminal.Display 494 495 connectingNotification string 496 repaintRequested bool 497 lfEntered bool 498 quitSequenceStarted bool 499 cleanShutdown bool 500 verbose int 501 } 502 503 func newSTMClient(config *Config) *STMClient { 504 sc := STMClient{} 505 506 sc.ip = config.host 507 sc.port = config.port 508 sc.key = config.key 509 sc.escapeKey = 0x1E 510 sc.escapePassKey = '^' 511 sc.escapePassKey2 = '^' 512 sc.escapeRequireslf = false 513 sc.escapeKeyHelp = "?" 514 sc.overlays = frontend.NewOverlayManager() 515 516 var err error 517 sc.display, err = terminal.NewDisplay(true) 518 if err != nil { 519 return nil 520 } 521 522 sc.repaintRequested = false 523 sc.lfEntered = false 524 sc.quitSequenceStarted = false 525 sc.cleanShutdown = false 526 sc.verbose = config.verbose 527 528 if config.predictMode != "" { 529 switch config.predictMode { 530 case predictionValues[0]: // always 531 sc.overlays.GetPredictionEngine().SetDisplayPreference(frontend.Always) 532 case predictionValues[1]: // never 533 sc.overlays.GetPredictionEngine().SetDisplayPreference(frontend.Never) 534 case predictionValues[2]: // adaptive 535 sc.overlays.GetPredictionEngine().SetDisplayPreference(frontend.Adaptive) 536 case predictionValues[3]: // experimental 537 sc.overlays.GetPredictionEngine().SetDisplayPreference(frontend.Experimental) 538 } 539 } 540 541 if config.predictOverwrite == "yes" { 542 sc.overlays.GetPredictionEngine().SetPredictOverwrite(true) 543 } 544 return &sc 545 } 546 547 func (sc *STMClient) mainInit() error { 548 // get initial window size 549 col, row, err := term.GetSize(int(os.Stdin.Fd())) 550 if err != nil { 551 return err 552 } 553 util.Logger.Debug("client window size", "col", col, "row", row) 554 555 // local state 556 savedLines := terminal.SaveLinesRowsOption 557 sc.localFramebuffer = terminal.NewEmulator3(col, row, savedLines) 558 sc.newState = terminal.NewEmulator3(col, row, savedLines) 559 560 // initialize screen 561 // init := sc.display.NewFrame(true, sc.localFramebuffer, sc.localFramebuffer) 562 // CSI ? 1049l Use Normal Screen Buffer and restore cursor as in DECRC, xterm. 563 // CSI ? 1l Normal Cursor Keys (DECCKM) 564 // CSI ? 1004l Disable FocusIn/FocusOut 565 init := "\x1B[?1049l\x1B[?1l\x1B[?1004l" 566 os.Stdout.WriteString(init) 567 util.Logger.Debug("mainInit", "init", init) 568 569 // open network 570 blank := &statesync.UserStream{} 571 terminal, err := statesync.NewComplete(col, row, savedLines) 572 sc.network = network.NewTransportClient(blank, terminal, sc.key, sc.ip, fmt.Sprintf("%d", sc.port)) 573 574 // minimal delay on outgoing keystrokes 575 sc.network.SetSendDelay(1) 576 577 // tell server the size of the terminal 578 sc.network.GetCurrentState().PushBackResize(col, row) 579 580 // be noisy as necessary 581 sc.network.SetVerbose(uint(sc.verbose)) 582 583 return nil 584 } 585 586 func (sc *STMClient) processNetworkInput(s string) { 587 // sc.network.Recv() 588 if err := sc.network.ProcessPayload(s); err != nil { 589 util.Logger.Warn("ProcessPayload", "error", err) 590 } 591 592 // Now give hints to the overlays 593 rs := sc.network.GetLatestRemoteState() 594 sc.overlays.GetNotificationEngine().ServerHeard(rs.GetTimestamp()) 595 sc.overlays.GetNotificationEngine().ServerAcked(sc.network.GetSentStateAckedTimestamp()) 596 597 sc.overlays.GetPredictionEngine().SetLocalFrameAcked(sc.network.GetSentStateAcked()) 598 sc.overlays.GetPredictionEngine().SetSendInterval(sc.network.SentInterval()) 599 state := sc.network.GetLatestRemoteState() 600 lateAcked := state.GetState().GetEchoAck() 601 sc.overlays.GetPredictionEngine().SetLocalFrameLateAcked(lateAcked) 602 } 603 604 func (sc *STMClient) processUserInput(buf string) bool { 605 if sc.network.ShutdownInProgress() { 606 return true 607 } 608 sc.overlays.GetPredictionEngine().SetLocalFrameSent(sc.network.GetSentStateLast()) 609 610 // Don't predict for bulk data. 611 paste := len(buf) > 100 612 if paste { 613 sc.overlays.GetPredictionEngine().Reset() 614 } 615 616 util.Logger.Debug("processUserInput", "buf", buf) 617 var input []rune 618 graphemes := uniseg.NewGraphemes(buf) 619 for graphemes.Next() { 620 input = graphemes.Runes() 621 theByte := input[0] // the first byte 622 623 if !paste { 624 sc.overlays.GetPredictionEngine().NewUserInput(sc.localFramebuffer, input) 625 } 626 627 if sc.quitSequenceStarted { 628 if theByte == '.' { // Quit sequence is Ctrl-^ . 629 if sc.network.HasRemoteAddr() && !sc.network.ShutdownInProgress() { 630 sc.overlays.GetNotificationEngine().SetNotificationString( 631 "Exiting on user request...", true, true) 632 sc.network.StartShutdown() 633 return true 634 } else { 635 return false 636 } 637 } else if theByte == 0x1A { // Suspend sequence is escape_key Ctrl-Z 638 // Restore terminal and terminal-driver state 639 os.Stdout.WriteString(sc.display.Close()) 640 641 if err := term.Restore(int(os.Stdin.Fd()), sc.savedTermios); err != nil { 642 util.Logger.Error("restore terminal failed", "error", err) 643 return false 644 } 645 646 fmt.Printf("\n\033[37;44m[%s is suspended.]\033[m\n", frontend.CommandClientName) 647 648 // fflush(NULL) 649 // 650 /* actually suspend */ 651 // kill(0, SIGSTOP); 652 // TODO check SIGSTOP 653 654 sc.resume() 655 } else if theByte == rune(sc.escapePassKey) || theByte == rune(sc.escapePassKey2) { 656 // Emulation sequence to type escape_key is escape_key + 657 // escape_pass_key (that is escape key without Ctrl) 658 sc.network.GetCurrentState().PushBack([]rune{rune(sc.escapeKey)}) 659 } else { 660 // Escape key followed by anything other than . and ^ gets sent literally 661 sc.network.GetCurrentState().PushBack([]rune{rune(sc.escapeKey), theByte}) 662 } 663 664 sc.quitSequenceStarted = false 665 666 if sc.overlays.GetNotificationEngine().GetNotificationString() == sc.escapeKeyHelp { 667 sc.overlays.GetNotificationEngine().SetNotificationString("", false, true) 668 } 669 670 continue 671 } 672 673 sc.quitSequenceStarted = sc.escapeKey > 0 && theByte == rune(sc.escapeKey) && 674 (sc.lfEntered || !sc.escapeRequireslf) 675 676 if sc.quitSequenceStarted { 677 sc.lfEntered = false 678 sc.overlays.GetNotificationEngine().SetNotificationString(sc.escapeKeyHelp, true, false) 679 continue 680 } 681 682 sc.lfEntered = theByte == 0x0A || theByte == 0x0D // LineFeed, Ctrl-J, '\n' or CarriageReturn, Ctrl-M, '\r' 683 684 if theByte == 0x0C { // Ctrl-L 685 sc.repaintRequested = true 686 } 687 688 sc.network.GetCurrentState().PushBack(input) 689 } 690 691 return true 692 } 693 694 func (sc *STMClient) processResize() bool { 695 // get new size 696 col, row, err := term.GetSize(int(os.Stdin.Fd())) 697 if err != nil { 698 return false 699 } 700 701 // newSize := terminal.Resize{Width: col, Height: row} 702 // tell remote emulator 703 if !sc.network.ShutdownInProgress() { 704 sc.network.GetCurrentState().PushBackResize(col, row) 705 } 706 // note remote emulator will probably reply with its own Resize to adjust our state 707 708 // tell prediction engine 709 sc.overlays.GetPredictionEngine().Reset() 710 return true 711 } 712 713 func (sc *STMClient) outputNewFrame() { 714 // clean shutdown even when not initialized 715 if sc.network == nil { 716 return 717 } 718 719 // fetch target state 720 state := sc.network.GetLatestRemoteState() 721 sc.newState = state.GetState().GetEmulator() 722 // apply local overlays 723 sc.overlays.Apply(sc.newState) 724 725 // calculate minimal difference from where we are 726 // util.Log.SetLevel(slog.LevelInfo) 727 // diff := sc.display.NewFrame(!sc.repaintRequested, sc.localFramebuffer, sc.newState) 728 diff := state.GetState().GetDiff() 729 // util.Log.SetLevel(slog.LevelDebug) 730 os.Stdout.WriteString(diff) 731 if diff != "" { 732 util.Logger.Debug("outputNewFrame", "diff", diff) 733 } 734 735 sc.repaintRequested = false 736 sc.localFramebuffer = sc.newState 737 } 738 739 func (sc *STMClient) stillConnecting() bool { 740 // Initially, network == nil 741 return sc.network != nil && sc.network.GetRemoteStateNum() == 0 742 } 743 744 func (sc *STMClient) resume() { 745 // Restore termios state 746 if err := term.Restore(int(os.Stdin.Fd()), sc.rawTermios); err != nil { 747 os.Exit(1) 748 } 749 750 // Put terminal in application-cursor-key mode 751 os.Stdout.WriteString(sc.display.Open()) 752 753 // Flag that outer terminal state is unknown 754 sc.repaintRequested = true 755 } 756 757 func (sc *STMClient) init() error { 758 if !util.IsUtf8Locale() { 759 nativeType := util.GetCtype() 760 nativeCharset := util.LocaleCharset() 761 762 fmt.Printf("%s needs a UTF-8 native locale to run.\n\n", frontend.CommandClientName) 763 fmt.Printf("Unfortunately, the client's environment (%s) specifies\nthe character set %q.\n\n", 764 nativeType, nativeCharset) 765 return errors.New(frontend.CommandClientName + " requires UTF-8 environment.") 766 } 767 768 var err error 769 // Verify terminal configuration 770 sc.savedTermios, err = term.GetState(int(os.Stdin.Fd())) 771 if err != nil { 772 return err 773 } 774 775 // set IUTF8 if available 776 // term package doesn't allow us to access termios, we use util package to do that. 777 if err = util.SetIUTF8(int(os.Stdin.Fd())); err != nil { 778 return err 779 } 780 781 // Put terminal driver in raw mode 782 // https://learnku.com/go/t/23460/bit-operation-of-go 783 // &^ is used to clean the specified bit 784 _, err = term.MakeRaw(int(os.Stdin.Fd())) 785 if err != nil { 786 return err 787 } 788 // save raw + IUTF8 termios to rawTermios 789 sc.rawTermios, err = term.GetState(int(os.Stdin.Fd())) 790 if err != nil { 791 return err 792 } 793 794 // Put terminal in application-cursor-key mode 795 os.Stdout.WriteString(sc.display.Open()) 796 util.Logger.Info("open terminal", "seq", sc.display.Open()) 797 798 // Add our name to window title 799 prefix := os.Getenv("APRILSH_TITLE_PREFIX") 800 if prefix != "" { 801 sc.overlays.SetTitlePrefix(prefix) 802 } 803 804 // Set terminal escape key. 805 escapeKeyEnv := os.Getenv("APRILSH_ESCAPE_KEY") 806 if escapeKeyEnv != "" { 807 if len(escapeKeyEnv) == 1 { 808 sc.escapeKey = int(escapeKeyEnv[0]) 809 if sc.escapeKey > 0 && sc.escapeKey < 128 { 810 if sc.escapeKey < 32 { 811 // If escape is ctrl-something, pass it with repeating the key without ctrl. 812 sc.escapePassKey = sc.escapeKey + '@' 813 } else { 814 // If escape is something else, pass it with repeating the key itself. 815 sc.escapePassKey = sc.escapeKey 816 } 817 if sc.escapePassKey >= 'A' && sc.escapePassKey <= 'Z' { 818 // If escape pass is an upper case character, define optional version as lower case of the same. 819 sc.escapePassKey2 = sc.escapePassKey + 'a' - 'A' 820 } else { 821 sc.escapePassKey2 = sc.escapePassKey 822 } 823 } else { 824 sc.escapeKey = 0x1E 825 sc.escapePassKey = '^' 826 sc.escapePassKey2 = '^' 827 } 828 } else if len(escapeKeyEnv) == 0 { 829 sc.escapeKey = -1 830 } else { 831 sc.escapeKey = 0x1E 832 sc.escapePassKey = '^' 833 sc.escapePassKey2 = '^' 834 } 835 } else { 836 sc.escapeKey = 0x1E 837 sc.escapePassKey = '^' 838 sc.escapePassKey2 = '^' 839 } 840 841 // There are so many better ways to shoot oneself into leg than 842 // setting escape key to Ctrl-C, Ctrl-D, NewLine, Ctrl-L or CarriageReturn 843 // that we just won't allow that. 844 845 if sc.escapeKey == 0x03 || sc.escapeKey == 0x04 || sc.escapeKey == 0x0A || 846 sc.escapeKey == 0x0C || sc.escapeKey == 0x0D { 847 sc.escapeKey = 0x1E 848 sc.escapePassKey = '^' 849 sc.escapePassKey2 = '^' 850 } 851 852 // Adjust escape help differently if escape is a control character. 853 if sc.escapeKey > 0 { 854 var b strings.Builder 855 escapeKeyName := "" 856 escapePassName := fmt.Sprintf("\"%c\"", sc.escapePassKey) 857 if sc.escapeKey < 32 { 858 escapeKeyName = fmt.Sprintf("Ctrl-%c", sc.escapePassKey) 859 sc.escapeRequireslf = false 860 } else { 861 escapeKeyName = fmt.Sprintf("\"%c\"", sc.escapePassKey) 862 sc.escapeRequireslf = true 863 } 864 865 sc.escapeKeyHelp = fmt.Sprintf("Commands: Ctrl-Z suspends, \".\" quits, " + escapePassName + 866 " gives literal " + escapeKeyName) 867 sc.overlays.GetNotificationEngine().SetEscapeKeyString(b.String()) 868 } 869 sc.connectingNotification = fmt.Sprintf("Nothing received from server on UDP port %d.", sc.port) 870 871 return nil 872 } 873 874 func (sc *STMClient) shutdown() error { 875 // Restore screen state 876 sc.overlays.GetNotificationEngine().SetNotificationString("", false, true) 877 sc.overlays.GetNotificationEngine().ServerHeard(time.Now().UnixMilli()) 878 sc.overlays.SetTitlePrefix("") 879 880 sc.outputNewFrame() 881 882 // Restore terminal and terminal-driver state 883 os.Stdout.WriteString(sc.display.Close()) 884 util.Logger.Info("close terminal", "seq", sc.display.Close()) 885 886 if err := term.Restore(int(os.Stdin.Fd()), sc.savedTermios); err != nil { 887 util.Logger.Warn("restore terminal failed", "error", err) 888 return err 889 } 890 891 if sc.stillConnecting() { 892 fmt.Printf("%s did not make a successful connection to '%s:%d'.\n", 893 frontend.CommandClientName, sc.ip, sc.port) 894 fmt.Printf("Please verify that UDP port is not firewalled and %s can reach the server.\n", 895 frontend.CommandClientName) 896 fmt.Printf("By default, %s uses UDP port begin with %d, The -p option specifies base %s port.\n", 897 frontend.CommandClientName, frontend.DefaultPort+1, frontend.CommandServerName) 898 } else if sc.network != nil { 899 if !sc.cleanShutdown { 900 fmt.Printf("\n%s did not shut down cleanly.\n", frontend.CommandClientName) 901 fmt.Printf("Please verify that UDP port %d is not firewalled and can reach the server.\n", 902 sc.port) 903 } else { 904 fmt.Printf("Connection to %s:%d closed.\n", sc.ip, sc.port) 905 } 906 } 907 return nil 908 } 909 910 func (sc *STMClient) main() error { 911 // initialize signal handling and structures 912 sc.mainInit() 913 914 // /* Drop unnecessary privileges */ 915 // #ifdef HAVE_PLEDGE 916 // /* OpenBSD pledge() syscall */ 917 // if (pledge("stdio inet tty", NULL)) { 918 // perror("pledge() failed"); 919 // exit(1); 920 // } 921 // #endif 922 923 var networkChan chan frontend.Message 924 var fileChan chan frontend.Message 925 networkChan = make(chan frontend.Message, 1) 926 fileChan = make(chan frontend.Message, 1) 927 fileDownChan := make(chan any, 1) 928 networkDownChan := make(chan any, 1) 929 930 eg := errgroup.Group{} 931 // read from network 932 eg.Go(func() error { 933 frontend.ReadFromNetwork(1, networkChan, networkDownChan, sc.network.GetConnection()) 934 return nil 935 }) 936 937 // read from pty master file 938 eg.Go(func() error { 939 frontend.ReadFromFile(10, fileChan, fileDownChan, os.Stdin) 940 return nil 941 }) 942 943 // intercept signal 944 sigChan := make(chan os.Signal, 1) 945 signal.Notify(sigChan, syscall.SIGWINCH, syscall.SIGTERM, syscall.SIGINT, 946 syscall.SIGHUP, syscall.SIGPIPE, syscall.SIGCONT) 947 // shutdownChan := make(chan bool) 948 // eg.Go(func() error { 949 // for { 950 // select { 951 // case s := <-sigChan: 952 // util.Log.Debug("got signal","signal", s) 953 // signals.Handler(s) 954 // case <-shutdownChan: 955 // return nil 956 // } 957 // } 958 // }) 959 960 mainLoop: 961 for { 962 sc.outputNewFrame() 963 964 w0 := sc.network.WaitTime() 965 w1 := sc.overlays.WaitTime() 966 waitTime := min(w0, w1) 967 // waitTime := terminal.Min(sc.network.WaitTime(), sc.overlays.WaitTime()) 968 969 // Handle startup "Connecting..." message 970 if sc.stillConnecting() { 971 waitTime = min(250, waitTime) 972 } 973 974 timer := time.NewTimer(time.Duration(waitTime) * time.Millisecond) 975 util.Logger.Debug("mainLoop", "point", 100, 976 "network.WaitTime", w0, "overlays.WaitTime", w1, "timeout", waitTime) 977 select { 978 case <-timer.C: 979 // util.Log.Debug("mainLoop", "overlays", sc.overlays.WaitTime(), 980 // "network", sc.network.WaitTime(), "waitTime", waitTime) 981 case networkMsg := <-networkChan: 982 983 // got data from server 984 if networkMsg.Err != nil { 985 // quit asap for refused connection 986 if errors.Is(networkMsg.Err, syscall.ECONNREFUSED) { 987 break mainLoop 988 } 989 // if read from server failed, retry after 0.2 second 990 util.Logger.Warn("receive from network", "error", networkMsg.Err) 991 if !sc.network.ShutdownInProgress() { 992 sc.overlays.GetNotificationEngine().SetNetworkError(networkMsg.Err.Error()) 993 } 994 // TODO handle "use of closed network connection" error? 995 time.Sleep(time.Duration(200) * time.Millisecond) 996 continue mainLoop 997 } 998 // util.Log.Info("got from network", "data", networkMsg.Data) 999 sc.processNetworkInput(networkMsg.Data) 1000 1001 case fileMsg := <-fileChan: 1002 1003 // input from the user needs to be fed to the network 1004 if fileMsg.Err != nil || !sc.processUserInput(fileMsg.Data) { 1005 1006 // if read from local pts terminal failed, quit 1007 if fileMsg.Err != nil { 1008 util.Logger.Warn("read from file", "error", fileMsg.Err) 1009 } 1010 if !sc.network.HasRemoteAddr() { 1011 break mainLoop 1012 } else if !sc.network.ShutdownInProgress() { 1013 sc.overlays.GetNotificationEngine().SetNotificationString("Exiting...", true, true) 1014 sc.network.StartShutdown() 1015 } 1016 } 1017 case s := <-sigChan: 1018 util.Logger.Debug("got signal", "signal", s) 1019 signals.Handler(s) 1020 } 1021 1022 if signals.GotSignal(syscall.SIGWINCH) { 1023 // resize 1024 if !sc.processResize() { 1025 return nil 1026 } 1027 } 1028 1029 if signals.GotSignal(syscall.SIGCONT) { 1030 sc.resume() 1031 } 1032 1033 if signals.GotSignal(syscall.SIGTERM) || signals.GotSignal(syscall.SIGINT) || 1034 signals.GotSignal(syscall.SIGHUP) || signals.GotSignal(syscall.SIGPIPE) { 1035 // shutdown signal 1036 if !sc.network.HasRemoteAddr() { 1037 break 1038 } else if !sc.network.ShutdownInProgress() { 1039 util.Logger.Debug("start shutting down.") 1040 sc.overlays.GetNotificationEngine().SetNotificationString( 1041 "Signal received, shutting down...", true, true) 1042 sc.network.StartShutdown() 1043 } 1044 } 1045 1046 // quit if our shutdown has been acknowledged 1047 if sc.network.ShutdownInProgress() && sc.network.ShutdownAcknowledged() { 1048 sc.cleanShutdown = true 1049 break 1050 } 1051 1052 // quit after shutdown acknowledgement timeout 1053 if sc.network.ShutdownInProgress() && sc.network.ShutdownAckTimedout() { 1054 break 1055 } 1056 1057 // quit if we received and acknowledged a shutdown request 1058 if sc.network.CounterpartyShutdownAckSent() { 1059 sc.cleanShutdown = true 1060 break 1061 } 1062 1063 // write diagnostic message if can't reach server 1064 now := time.Now().UnixMilli() 1065 remoteState := sc.network.GetLatestRemoteState() 1066 sinceLastResponse := now - remoteState.GetTimestamp() 1067 if sc.stillConnecting() && !sc.network.ShutdownInProgress() && sinceLastResponse > 250 { 1068 if sinceLastResponse > frontend.TimeoutIfNoConnect { 1069 if !sc.network.ShutdownInProgress() { 1070 sc.overlays.GetNotificationEngine().SetNotificationString( 1071 "Timed out waiting for server...", true, true) 1072 // sc.network.StartShutdown() 1073 util.Logger.Warn("No connection within x seconds", "seconds", frontend.TimeoutIfNoConnect/1000) 1074 break 1075 } 1076 } else { 1077 sc.overlays.GetNotificationEngine().SetNotificationString( 1078 sc.connectingNotification, false, true) 1079 } 1080 } else if sc.network.GetRemoteStateNum() != 0 && 1081 sc.overlays.GetNotificationEngine().GetNotificationString() == sc.connectingNotification { 1082 sc.overlays.GetNotificationEngine().SetNotificationString("", false, true) 1083 } 1084 1085 // util.Log.Warn("mainLoop", "before", "tick") 1086 err := sc.network.Tick() 1087 if err != nil { 1088 util.Logger.Warn("tick send failed", "error", err) 1089 sc.overlays.GetNotificationEngine().SetNetworkError(err.Error()) 1090 // if errors.Is(err, syscall.ECONNREFUSED) { 1091 sc.network.StartShutdown() 1092 util.Logger.Debug("start shutting down.") 1093 } else { 1094 sc.overlays.GetNotificationEngine().ClearNetworkError() 1095 } 1096 1097 // if connected and no response over TimeoutIfNoResp 1098 if sc.network.GetRemoteStateNum() != 0 && sinceLastResponse > frontend.TimeoutIfNoResp { 1099 // if no awaken 1100 if !sc.network.Awaken(now) { 1101 util.Logger.Warn("No server response over x seconds", "seconds", frontend.TimeoutIfNoResp) 1102 break 1103 } 1104 } 1105 } 1106 1107 // stop signal and network 1108 signal.Stop(sigChan) 1109 sc.network.Close() 1110 1111 // shutdown the goroutines: file reader and network reader 1112 select { 1113 case fileDownChan <- "done": 1114 default: 1115 } 1116 select { 1117 case networkDownChan <- "done": 1118 default: 1119 } 1120 1121 // consume last message to release reader if possible 1122 select { 1123 case <-fileChan: 1124 default: 1125 } 1126 select { 1127 case <-networkChan: 1128 default: 1129 } 1130 eg.Wait() 1131 1132 return nil 1133 } 1134 1135 func main() { 1136 // cpuf, err := os.Create("cpu.profile") 1137 // if err != nil { 1138 // fmt.Println(err) 1139 // return 1140 // } 1141 // pprof.StartCPUProfile(cpuf) 1142 // defer pprof.StopCPUProfile() 1143 1144 // For security, make sure we don't dump core 1145 encrypt.DisableDumpingCore() 1146 1147 conf, _, err := parseFlags(os.Args[0], os.Args[1:]) 1148 if errors.Is(err, flag.ErrHelp) { 1149 frontend.PrintUsage("", usage) 1150 return 1151 } else if err != nil { 1152 frontend.PrintUsage(err.Error()) 1153 return 1154 } else if hint, ok := conf.buildConfig(); !ok { 1155 frontend.PrintUsage(hint) 1156 return 1157 } 1158 1159 if conf.version { 1160 printVersion() 1161 return 1162 } 1163 1164 if conf.colors { 1165 printColors() 1166 return 1167 } 1168 1169 var logWriter io.Writer 1170 logWriter = os.Stderr 1171 1172 // https://rderik.com/blog/identify-if-output-goes-to-the-terminal-or-is-being-redirected-in-golang/ 1173 // 1174 // if stderr outputs to terminal, we redirect it to /dev/null. 1175 f2, _ := os.Stderr.Stat() 1176 if (f2.Mode() & os.ModeCharDevice) == os.ModeCharDevice { 1177 os.Stderr = os.NewFile(uintptr(syscall.Stderr), os.DevNull) 1178 logWriter = io.Discard 1179 } 1180 1181 // setup client log file 1182 switch conf.verbose { 1183 case util.DebugLevel: 1184 util.Logger.CreateLogger(logWriter, conf.addSource, slog.LevelDebug) 1185 case util.TraceLevel: 1186 util.Logger.CreateLogger(logWriter, conf.addSource, util.LevelTrace) 1187 default: 1188 util.Logger.CreateLogger(logWriter, conf.addSource, slog.LevelInfo) 1189 } 1190 1191 // https://earthly.dev/blog/golang-errors/ 1192 // https://gosamples.dev/check-error-type/ 1193 // https://www.digitalocean.com/community/tutorials/how-to-add-extra-information-to-errors-in-go 1194 // 1195 // ssh login to remote server and fetch the seesion key 1196 if err = conf.fetchKey(); err != nil { 1197 var dnsError *net.DNSError 1198 var opError *net.OpError 1199 var keyError *xknownhosts.KeyError 1200 var exitError *ssh.ExitError 1201 var hostkeyChangeError *hostkeyChangeError 1202 1203 if errors.As(err, &dnsError) { 1204 frontend.PrintUsage(fmt.Sprintf("No such host: %q", dnsError.Name)) 1205 } else if errors.As(err, &opError) && opError.Op == "dial" { 1206 frontend.PrintUsage(fmt.Sprintf("Failed to connect to: %s", opError.Addr)) 1207 } else if strings.Contains(err.Error(), "unable to authenticate") { 1208 // the error returned by ssh.NewClientConn() doen't naming error, 1209 // we have to check the error message directly. 1210 1211 // enable 'PubkeyAuthentication yes' line in sshd_config 1212 frontend.PrintUsage(fmt.Sprintf("Failed to authenticate user %q", conf.user)) 1213 fmt.Printf("%s\n", err) 1214 } else if errors.As(err, &keyError) { 1215 // } else if strings.Contains(err.Error(), "key is unknown") { 1216 // we already handle it 1217 } else if errors.Is(err, errNoResponse) { 1218 frontend.PrintUsage(err.Error()) 1219 } else if errors.As(err, &exitError) && exitError.Waitmsg.ExitStatus() == 127 { 1220 frontend.PrintUsage("Plase check aprilsh is installed on server.") 1221 } else if errors.As(err, &hostkeyChangeError) { 1222 frontend.PrintUsage(hostkeyChangeError.Error()) 1223 } else { 1224 // printUsage(fmt.Sprintf("%#v", err)) 1225 frontend.PrintUsage(err.Error()) 1226 } 1227 return 1228 } 1229 1230 // start client 1231 util.SetNativeLocale() 1232 client := newSTMClient(conf) 1233 if err := client.init(); err != nil { 1234 fmt.Printf("%s init error:%s\n", frontend.CommandClientName, err) 1235 return 1236 } 1237 client.main() 1238 client.shutdown() 1239 }