github.com/rsc/tmp@v0.0.0-20240517235954-6deaab19748b/ssh-namespace-agent/main.go (about) 1 // Copyright 2017 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Ssh-namespace-agent tunnels the 9P name space over ssh-agent protocol. 6 // 7 // To use, add to your profile on both the local and remote systems: 8 // 9 // eval $(ssh-namespace-agent) 10 // 11 package main 12 13 import ( 14 "bytes" 15 "encoding/binary" 16 "errors" 17 "flag" 18 "fmt" 19 "io" 20 "log" 21 "net" 22 "os" 23 "os/exec" 24 "path/filepath" 25 "strconv" 26 "strings" 27 "sync" 28 "sync/atomic" 29 "syscall" 30 "time" 31 32 plan9client "9fans.net/go/plan9/client" 33 ) 34 35 var verbose = flag.Bool("v", false, "enable verbose debugging") 36 37 func usage() { 38 fmt.Fprintf(os.Stderr, "usage: eval $(ssh-namespace-agent)\n") 39 os.Exit(2) 40 } 41 42 func main() { 43 log.SetPrefix("ssh-namespace-agent: ") 44 log.SetFlags(0) 45 if len(os.Args) == 2 && os.Args[1] == "--daemon--" { 46 daemon() 47 return 48 } 49 50 flag.Usage = usage 51 flag.Parse() 52 if flag.NArg() != 0 { 53 usage() 54 } 55 56 r1, w1, err := os.Pipe() 57 if err != nil { 58 log.Fatal(err) 59 } 60 r2, w2, err := os.Pipe() 61 if err != nil { 62 log.Fatal(err) 63 } 64 cmd := exec.Command(os.Args[0], "--daemon--") 65 cmd.Stdout = w1 66 cmd.Stderr = w2 67 err = cmd.Start() 68 if err != nil { 69 log.Fatalf("reexec: %v", err) 70 } 71 w1.Close() 72 w2.Close() 73 74 var stdout bytes.Buffer 75 var stderr bytes.Buffer 76 done := make(chan bool, 2) 77 go func() { 78 io.Copy(&stdout, r1) 79 done <- true 80 }() 81 go func() { 82 io.Copy(&stderr, r2) 83 done <- true 84 }() 85 <-done 86 <-done 87 88 out := stdout.Bytes() 89 ok := false 90 if bytes.HasSuffix(out, []byte("\nOK\n")) || bytes.Equal(out, []byte("OK\n")) { 91 out = out[:len(out)-len("OK\n")] 92 ok = true 93 } 94 if len(out)+stderr.Len() == 0 { 95 log.Print("no output") 96 } 97 os.Stdout.Write(out) 98 os.Stderr.Write(stderr.Bytes()) 99 if !ok { 100 os.Exit(1) 101 } 102 } 103 104 func readMsg(c net.Conn) ([]byte, error) { 105 buf := make([]byte, 4) 106 n, err := io.ReadFull(c, buf) 107 if err != nil { 108 return buf[:n], err 109 } 110 nn := int(binary.BigEndian.Uint32(buf)) 111 bbuf := make([]byte, nn) 112 copy(bbuf, buf) 113 _, err = io.ReadFull(c, bbuf) 114 if err != nil { 115 return nil, err 116 } 117 return bbuf, nil 118 } 119 120 func writeMsg(c net.Conn, body []byte) error { 121 buf := make([]byte, 4) 122 binary.BigEndian.PutUint32(buf, uint32(len(body))) 123 _, err := c.Write(buf) 124 if err != nil { 125 return err 126 } 127 _, err = c.Write(body) 128 if err != nil { 129 return err 130 } 131 return nil 132 } 133 134 const ( 135 SSH_AGENT_FAILURE = 5 136 SSH_AGENT_SUCCESS = 6 137 SSH_AGENTC_EXTENSION = 27 138 SSH_AGENT_EXTENSION_FAILURE = 28 139 extName = "sshns@9fans.net" 140 ) 141 142 var ( 143 extHeader = []byte("\x1b\x0fsshns@9fans.net") 144 ) 145 146 func runExt(c net.Conn, req []byte) ([]byte, error, bool) { 147 msg := make([]byte, 4+len(extHeader)) 148 binary.BigEndian.PutUint32(msg, uint32(len(extHeader)+len(req))) 149 copy(msg[4:], extHeader) 150 if _, err := c.Write(msg); err != nil { 151 return nil, err, false 152 } 153 if _, err := c.Write(req); err != nil { 154 return nil, err, false 155 } 156 m, err := readMsg(c) 157 if err != nil { 158 return nil, err, true 159 } 160 if !bytes.HasPrefix(m, extHeader) { 161 return nil, fmt.Errorf("unexpected response"), true 162 } 163 m = m[len(extHeader):] 164 if bytes.HasPrefix(m, []byte("ok\n")) { 165 return m[3:], nil, true 166 } 167 if bytes.HasPrefix(m, []byte("err\n")) { 168 return nil, errors.New(string(m[4:])), true 169 } 170 return nil, fmt.Errorf("unexpected response"), true 171 } 172 173 func writeExtReply(c net.Conn, data []byte) error { 174 return writeMsg(c, append(extHeader, data...)) 175 } 176 177 func parseExtmsg(m []byte) (string, []byte) { 178 line := m 179 if i := bytes.IndexByte(line, '\n'); i >= 0 { 180 line, m = line[:i], m[i+1:] 181 } else { 182 line, m = m, nil 183 } 184 cmd := string(line) 185 return cmd, m 186 } 187 188 func daemon() { 189 if os.Getenv("SSH_CONNECTION") != "" { 190 server() 191 return 192 } 193 client() 194 } 195 196 // runs on ssh server side 197 func server() { 198 // Maybe these should be quiet failures? 199 sock := os.Getenv("SSH_AUTH_SOCK") 200 if sock == "" { 201 log.Fatal("$SSH_AUTH_SOCK not set") 202 } 203 204 _, err := listRemote(sock) 205 if err != nil { 206 log.Fatal(err) 207 } 208 209 dir := filepath.Dir(sock) 210 plan9 := filepath.Join(dir, "plan9") 211 _, err = os.Stat(plan9) 212 if err == nil { 213 // Daemon already running. 214 fmt.Printf("export NAMESPACE=%s\n", plan9) 215 fmt.Printf("OK\n") 216 return 217 } 218 err = os.Mkdir(plan9, 0700) 219 if err != nil { 220 log.Fatal(err) 221 } 222 223 if err := createSockets(sock, plan9); err != nil { 224 log.Fatal(err) 225 } 226 227 fmt.Printf("export NAMESPACE=%s\n", plan9) 228 fmt.Printf("OK\n") 229 closeStdout() 230 231 for { 232 time.Sleep(1 * time.Minute) 233 createSockets(sock, plan9) 234 } 235 } 236 237 var connCache struct { 238 sync.Mutex 239 c []net.Conn 240 } 241 242 // TODO: Cache connections. 243 func dialAndRunExt(sock string, msg []byte) ([]byte, error) { 244 connCache.Lock() 245 var c net.Conn 246 if len(connCache.c) > 0 { 247 c = connCache.c[len(connCache.c)-1] 248 connCache.c = connCache.c[:len(connCache.c)-1] 249 } 250 connCache.Unlock() 251 if c == nil { 252 var err error 253 log.Printf("redial %s", sock) 254 c, err = net.Dial("unix", sock) 255 if err != nil { 256 return nil, err 257 } 258 } 259 m, err, ok := runExt(c, msg) 260 if !ok { 261 c.Close() 262 } else { 263 connCache.Lock() 264 connCache.c = append(connCache.c, c) 265 connCache.Unlock() 266 } 267 return m, err 268 } 269 270 func listRemote(sock string) ([]string, error) { 271 data, err := dialAndRunExt(sock, []byte("list")) 272 if err != nil { 273 return nil, err 274 } 275 if len(data) == 0 { 276 return nil, nil 277 } 278 return strings.Split(string(data), "\x00"), nil 279 } 280 281 func closeStdout() { 282 fd, err := syscall.Open("/dev/null", syscall.O_RDWR, 0) 283 if err != nil { 284 log.Fatal(err) 285 } 286 syscall.Dup2(fd, 0) 287 if fd > 2 { 288 syscall.Close(fd) 289 } 290 fd, err = syscall.Open(os.Getenv("HOME")+"/.sshns.log", syscall.O_WRONLY|syscall.O_APPEND|syscall.O_CREAT, 0600) 291 if err != nil { 292 log.Fatal(err) 293 } 294 syscall.Dup2(fd, 1) 295 syscall.Dup2(fd, 2) 296 if fd > 2 { 297 syscall.Close(fd) 298 } 299 log.SetFlags(log.LstdFlags) 300 } 301 302 func reverseDial(sock, name string) (rc *remoteConn, err error) { 303 id, err := dialAndRunExt(sock, []byte("dial "+name)) 304 if err != nil { 305 log.Printf("dial %s: %v", name, err) 306 return nil, err 307 } 308 log.Printf("dial %s -> %s\n", name, id) 309 r := &remoteConn{sock: sock, id: string(id)} 310 go r.lease() 311 return r, nil 312 } 313 314 type remoteConn struct { 315 id string 316 sock string 317 dead uint32 318 } 319 320 const expireDelta = 10 * time.Minute 321 322 func (r *remoteConn) lease() { 323 for atomic.LoadUint32(&r.dead) == 0 { 324 dialAndRunExt(r.sock, []byte("refresh "+r.id)) 325 time.Sleep(expireDelta / 2) 326 } 327 } 328 329 func (r *remoteConn) Read(data []byte) (int, error) { 330 log.Printf("read %s %d\n", r.id, len(data)) 331 d, err := dialAndRunExt(r.sock, []byte(fmt.Sprintf("read %d %s", len(data), r.id))) 332 if err != nil { 333 log.Printf("read %s %d: %v", r.id, len(data), err) 334 return 0, err 335 } 336 log.Printf("read %s %d: %d", r.id, len(data), len(d)) 337 return copy(data, d), nil 338 } 339 340 func (r *remoteConn) Write(data []byte) (int, error) { 341 log.Printf("write %s %d\n", r.id, len(data)) 342 var w int 343 for len(data) > 0 { 344 n := len(data) 345 if n > 10000 { 346 n = 10000 347 } 348 log.Printf("write1 %s %d\n", r.id, n) 349 _, err := dialAndRunExt(r.sock, append([]byte("write "+r.id+"\n"), data[:n]...)) 350 if err != nil { 351 return w, err 352 } 353 w += n 354 data = data[n:] 355 } 356 return w, nil 357 } 358 359 func (r *remoteConn) Close() error { 360 log.Printf("close %s\n", r.id) 361 atomic.StoreUint32(&r.dead, 1) 362 _, err := dialAndRunExt(r.sock, []byte("close "+r.id)) 363 return err 364 } 365 366 var created = map[string]bool{} 367 368 func createSockets(sock, plan9 string) error { 369 names, err := listRemote(sock) 370 if err != nil { 371 log.Fatal(err) // probably client is gone 372 } 373 for _, name := range names { 374 if !created[name] { 375 created[name] = true 376 go proxySocket(sock, plan9, name) 377 } 378 } 379 return nil 380 } 381 382 func proxySocket(sock, plan9, name string) { 383 l, err := net.Listen("unix", filepath.Join(plan9, name)) 384 if err != nil { 385 log.Printf("post %s: %v", name, err) 386 return 387 } 388 389 for { 390 c, err := l.Accept() 391 if err != nil { 392 time.Sleep(1 * time.Minute) 393 continue 394 } 395 c1, err := reverseDial(sock, name) 396 if err != nil { 397 c.Close() 398 log.Printf("reverseDial %s: %v", name, err) 399 continue 400 } 401 go proxy(c, c1) 402 } 403 } 404 405 func proxy(c, c1 io.ReadWriteCloser) { 406 done := make(chan bool, 2) 407 go func() { 408 io.Copy(c, c1) 409 c.Close() 410 done <- true 411 }() 412 go func() { 413 io.Copy(c1, c) 414 c1.Close() 415 done <- true 416 }() 417 <-done 418 <-done 419 } 420 421 // runs on ssh client side 422 func client() { 423 // Maybe these should be quiet failures? 424 oldSock := os.Getenv("SSH_AUTH_SOCK") 425 if oldSock == "" { 426 if *verbose { 427 log.Fatal("$SSH_AUTH_SOCK not set") 428 } 429 return 430 } 431 if strings.HasSuffix(oldSock, "/sshns.socket") { 432 if *verbose { 433 log.Fatal("$SSH_AUTH_SOCK is already an ssh-namespace-agent") 434 } 435 return 436 } 437 438 ns := plan9client.Namespace() 439 if ns == "" { 440 log.Fatal("no plan9 namespace") 441 } 442 if err := os.MkdirAll(ns, 0700); err != nil { 443 log.Fatal(err) 444 } 445 446 // NOTE(rsc): Tried to use ssh-namespace-agent.socket, 447 // but combined with my Mac's current default $(namespace) 448 // of /tmp/ns.rsc._private_tmp_com.apple.launchd.7VN9hyV2B7_org.macosforge.xquartz:0/ 449 // that name just barely exceeds the 104-byte limit. 450 // Probably the default namespace needs to be shortened, 451 // but to avoid requiring that, we use a shorter name. 452 newSock := filepath.Join(ns, "sshns.socket") 453 l, err := net.Listen("unix", newSock) 454 if err != nil { 455 // Maybe already running? 456 c, err := net.Dial("unix", newSock) 457 if err == nil { 458 c.Close() 459 fmt.Printf("export SSH_AUTH_SOCK=%s\n", newSock) 460 fmt.Printf("OK\n") 461 return 462 } 463 os.Remove(newSock) 464 l, err = net.Listen("unix", newSock) 465 if err != nil { 466 log.Fatal(err) 467 } 468 } 469 470 fmt.Printf("export SSH_AUTH_SOCK=%s\n", newSock) 471 fmt.Printf("OK\n") 472 closeStdout() 473 474 for { 475 c, err := l.Accept() 476 if err != nil { 477 log.Fatal(err) 478 } 479 go serve(c, oldSock, ns) 480 } 481 } 482 483 func serve(c net.Conn, oldSock, ns string) { 484 log.Printf("serving on client\n") 485 var c1 net.Conn 486 defer c.Close() 487 for { 488 m, err := readMsg(c) 489 if err != nil { 490 log.Printf("serving socket: readMsg: %v", err) 491 return 492 } 493 log.Printf("serve %d %d", len(m), m[0]) 494 if !bytes.HasPrefix(m, extHeader) { 495 // pass message to underlying agent 496 if c1 == nil { 497 c1, err = net.Dial("unix", oldSock) 498 if err != nil { 499 log.Printf("proxying message: dial: %v", err) 500 return 501 } 502 defer c1.Close() 503 } 504 if err := writeMsg(c1, m); err != nil { 505 log.Printf("proxying message: write: %v", err) 506 return 507 } 508 m, err = readMsg(c1) 509 if err != nil { 510 log.Printf("proxying message: read: %v", err) 511 return 512 } 513 if err := writeMsg(c, m); err != nil { 514 log.Printf("proxying message: write back: %v", err) 515 return 516 } 517 continue 518 } 519 cmd, m := parseExtmsg(m[len(extHeader):]) 520 f := strings.Fields(cmd) 521 if len(f) > 0 { 522 switch f[0] { 523 case "list": 524 handleList(c, ns) 525 continue 526 case "dial": 527 if len(f) == 2 { 528 handleDial(c, ns, f[1]) 529 continue 530 } 531 case "close": 532 if len(f) == 2 { 533 handleClose(c, f[1]) 534 continue 535 } 536 case "write": 537 if len(f) == 2 { 538 handleWrite(c, f[1], m) 539 continue 540 } 541 case "read": 542 if len(f) == 3 { 543 n, err := strconv.Atoi(f[1]) 544 if err == nil { 545 handleRead(c, n, f[2]) 546 continue 547 } 548 } 549 case "refresh": 550 if len(f) == 2 { 551 handleRefresh(c, f[1]) 552 continue 553 } 554 } 555 } 556 writeExtReply(c, []byte(fmt.Sprintf("err\nunknown command %q", cmd))) 557 } 558 } 559 560 func handleList(c net.Conn, ns string) { 561 names, _ := filepath.Glob(filepath.Join(ns, "*")) 562 var out []string 563 for _, name := range names { 564 name = filepath.Base(name) 565 if !strings.HasSuffix(name, ".socket") { 566 out = append(out, name) 567 } 568 } 569 reply := []byte("ok\n" + strings.Join(out, "\x00")) 570 writeExtReply(c, reply) 571 } 572 573 type conn struct { 574 c net.Conn 575 expire time.Time 576 } 577 578 var conns struct { 579 sync.Mutex 580 m map[string]*conn 581 n int 582 } 583 584 func init() { 585 go func() { 586 for { 587 time.Sleep(expireDelta) 588 conns.Lock() 589 var dead []*conn 590 for k, cc := range conns.m { 591 if time.Now().After(cc.expire) { 592 dead = append(dead, cc) 593 delete(conns.m, k) 594 } 595 } 596 conns.Unlock() 597 for _, cc := range dead { 598 cc.c.Close() 599 } 600 } 601 }() 602 } 603 604 func handleDial(c net.Conn, ns string, name string) { 605 c1, err := net.Dial("unix", filepath.Join(ns, name)) 606 if err != nil { 607 writeExtReply(c, []byte("err\n"+err.Error())) 608 return 609 } 610 conns.Lock() 611 conns.n++ 612 id := fmt.Sprint(conns.n) 613 if conns.m == nil { 614 conns.m = map[string]*conn{} 615 } 616 conns.m[id] = &conn{c: c1, expire: time.Now().Add(expireDelta)} 617 conns.Unlock() 618 writeExtReply(c, []byte("ok\n"+id)) 619 } 620 621 func handleClose(c net.Conn, id string) { 622 conns.Lock() 623 cc := conns.m[id] 624 if cc != nil { 625 delete(conns.m, id) 626 } 627 conns.Unlock() 628 629 if cc == nil { 630 writeExtReply(c, []byte("err\nunknown conn")) 631 return 632 } 633 634 cc.c.Close() 635 writeExtReply(c, []byte("ok\n")) 636 } 637 638 func handleRead(c net.Conn, n int, id string) { 639 conns.Lock() 640 cc := conns.m[id] 641 if cc != nil { 642 cc.expire = time.Now().Add(expireDelta) 643 } 644 conns.Unlock() 645 646 if cc == nil { 647 writeExtReply(c, []byte("err\nunknown conn")) 648 return 649 } 650 651 log.Printf("handleRead %s %d", id, n) 652 buf := make([]byte, 3+n) 653 n, err := cc.c.Read(buf[3:]) 654 if n > 0 { 655 err = nil 656 } 657 if err != nil { 658 writeExtReply(c, []byte("err\n"+err.Error())) 659 return 660 } 661 copy(buf[0:], "ok\n") 662 writeExtReply(c, buf[:3+n]) 663 } 664 665 func handleWrite(c net.Conn, id string, data []byte) { 666 conns.Lock() 667 cc := conns.m[id] 668 if cc != nil { 669 cc.expire = time.Now().Add(expireDelta) 670 } 671 conns.Unlock() 672 673 if cc == nil { 674 writeExtReply(c, []byte("err\nunknown conn")) 675 return 676 } 677 678 log.Printf("handleWrite %s %d", id, len(data)) 679 _, err := cc.c.Write(data) 680 if err != nil { 681 writeExtReply(c, []byte("err\n"+err.Error())) 682 return 683 } 684 writeExtReply(c, []byte("ok\n")) 685 } 686 687 func handleRefresh(c net.Conn, id string) { 688 conns.Lock() 689 cc := conns.m[id] 690 if cc != nil { 691 cc.expire = time.Now().Add(expireDelta) 692 } 693 conns.Unlock() 694 if cc == nil { 695 writeExtReply(c, []byte("err\nunknown conn")) 696 return 697 } 698 writeExtReply(c, []byte("ok\n")) 699 }