github.com/rsc/tmp@v0.0.0-20240517235954-6deaab19748b/ssh-namespace-agent/main.go (about)

     1  // Copyright 2017 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Ssh-namespace-agent tunnels the 9P name space over ssh-agent protocol.
     6  //
     7  // To use, add to your profile on both the local and remote systems:
     8  //
     9  //	eval $(ssh-namespace-agent)
    10  //
    11  package main
    12  
    13  import (
    14  	"bytes"
    15  	"encoding/binary"
    16  	"errors"
    17  	"flag"
    18  	"fmt"
    19  	"io"
    20  	"log"
    21  	"net"
    22  	"os"
    23  	"os/exec"
    24  	"path/filepath"
    25  	"strconv"
    26  	"strings"
    27  	"sync"
    28  	"sync/atomic"
    29  	"syscall"
    30  	"time"
    31  
    32  	plan9client "9fans.net/go/plan9/client"
    33  )
    34  
    35  var verbose = flag.Bool("v", false, "enable verbose debugging")
    36  
    37  func usage() {
    38  	fmt.Fprintf(os.Stderr, "usage: eval $(ssh-namespace-agent)\n")
    39  	os.Exit(2)
    40  }
    41  
    42  func main() {
    43  	log.SetPrefix("ssh-namespace-agent: ")
    44  	log.SetFlags(0)
    45  	if len(os.Args) == 2 && os.Args[1] == "--daemon--" {
    46  		daemon()
    47  		return
    48  	}
    49  
    50  	flag.Usage = usage
    51  	flag.Parse()
    52  	if flag.NArg() != 0 {
    53  		usage()
    54  	}
    55  
    56  	r1, w1, err := os.Pipe()
    57  	if err != nil {
    58  		log.Fatal(err)
    59  	}
    60  	r2, w2, err := os.Pipe()
    61  	if err != nil {
    62  		log.Fatal(err)
    63  	}
    64  	cmd := exec.Command(os.Args[0], "--daemon--")
    65  	cmd.Stdout = w1
    66  	cmd.Stderr = w2
    67  	err = cmd.Start()
    68  	if err != nil {
    69  		log.Fatalf("reexec: %v", err)
    70  	}
    71  	w1.Close()
    72  	w2.Close()
    73  
    74  	var stdout bytes.Buffer
    75  	var stderr bytes.Buffer
    76  	done := make(chan bool, 2)
    77  	go func() {
    78  		io.Copy(&stdout, r1)
    79  		done <- true
    80  	}()
    81  	go func() {
    82  		io.Copy(&stderr, r2)
    83  		done <- true
    84  	}()
    85  	<-done
    86  	<-done
    87  
    88  	out := stdout.Bytes()
    89  	ok := false
    90  	if bytes.HasSuffix(out, []byte("\nOK\n")) || bytes.Equal(out, []byte("OK\n")) {
    91  		out = out[:len(out)-len("OK\n")]
    92  		ok = true
    93  	}
    94  	if len(out)+stderr.Len() == 0 {
    95  		log.Print("no output")
    96  	}
    97  	os.Stdout.Write(out)
    98  	os.Stderr.Write(stderr.Bytes())
    99  	if !ok {
   100  		os.Exit(1)
   101  	}
   102  }
   103  
   104  func readMsg(c net.Conn) ([]byte, error) {
   105  	buf := make([]byte, 4)
   106  	n, err := io.ReadFull(c, buf)
   107  	if err != nil {
   108  		return buf[:n], err
   109  	}
   110  	nn := int(binary.BigEndian.Uint32(buf))
   111  	bbuf := make([]byte, nn)
   112  	copy(bbuf, buf)
   113  	_, err = io.ReadFull(c, bbuf)
   114  	if err != nil {
   115  		return nil, err
   116  	}
   117  	return bbuf, nil
   118  }
   119  
   120  func writeMsg(c net.Conn, body []byte) error {
   121  	buf := make([]byte, 4)
   122  	binary.BigEndian.PutUint32(buf, uint32(len(body)))
   123  	_, err := c.Write(buf)
   124  	if err != nil {
   125  		return err
   126  	}
   127  	_, err = c.Write(body)
   128  	if err != nil {
   129  		return err
   130  	}
   131  	return nil
   132  }
   133  
   134  const (
   135  	SSH_AGENT_FAILURE           = 5
   136  	SSH_AGENT_SUCCESS           = 6
   137  	SSH_AGENTC_EXTENSION        = 27
   138  	SSH_AGENT_EXTENSION_FAILURE = 28
   139  	extName                     = "sshns@9fans.net"
   140  )
   141  
   142  var (
   143  	extHeader = []byte("\x1b\x0fsshns@9fans.net")
   144  )
   145  
   146  func runExt(c net.Conn, req []byte) ([]byte, error, bool) {
   147  	msg := make([]byte, 4+len(extHeader))
   148  	binary.BigEndian.PutUint32(msg, uint32(len(extHeader)+len(req)))
   149  	copy(msg[4:], extHeader)
   150  	if _, err := c.Write(msg); err != nil {
   151  		return nil, err, false
   152  	}
   153  	if _, err := c.Write(req); err != nil {
   154  		return nil, err, false
   155  	}
   156  	m, err := readMsg(c)
   157  	if err != nil {
   158  		return nil, err, true
   159  	}
   160  	if !bytes.HasPrefix(m, extHeader) {
   161  		return nil, fmt.Errorf("unexpected response"), true
   162  	}
   163  	m = m[len(extHeader):]
   164  	if bytes.HasPrefix(m, []byte("ok\n")) {
   165  		return m[3:], nil, true
   166  	}
   167  	if bytes.HasPrefix(m, []byte("err\n")) {
   168  		return nil, errors.New(string(m[4:])), true
   169  	}
   170  	return nil, fmt.Errorf("unexpected response"), true
   171  }
   172  
   173  func writeExtReply(c net.Conn, data []byte) error {
   174  	return writeMsg(c, append(extHeader, data...))
   175  }
   176  
   177  func parseExtmsg(m []byte) (string, []byte) {
   178  	line := m
   179  	if i := bytes.IndexByte(line, '\n'); i >= 0 {
   180  		line, m = line[:i], m[i+1:]
   181  	} else {
   182  		line, m = m, nil
   183  	}
   184  	cmd := string(line)
   185  	return cmd, m
   186  }
   187  
   188  func daemon() {
   189  	if os.Getenv("SSH_CONNECTION") != "" {
   190  		server()
   191  		return
   192  	}
   193  	client()
   194  }
   195  
   196  // runs on ssh server side
   197  func server() {
   198  	// Maybe these should be quiet failures?
   199  	sock := os.Getenv("SSH_AUTH_SOCK")
   200  	if sock == "" {
   201  		log.Fatal("$SSH_AUTH_SOCK not set")
   202  	}
   203  
   204  	_, err := listRemote(sock)
   205  	if err != nil {
   206  		log.Fatal(err)
   207  	}
   208  
   209  	dir := filepath.Dir(sock)
   210  	plan9 := filepath.Join(dir, "plan9")
   211  	_, err = os.Stat(plan9)
   212  	if err == nil {
   213  		// Daemon already running.
   214  		fmt.Printf("export NAMESPACE=%s\n", plan9)
   215  		fmt.Printf("OK\n")
   216  		return
   217  	}
   218  	err = os.Mkdir(plan9, 0700)
   219  	if err != nil {
   220  		log.Fatal(err)
   221  	}
   222  
   223  	if err := createSockets(sock, plan9); err != nil {
   224  		log.Fatal(err)
   225  	}
   226  
   227  	fmt.Printf("export NAMESPACE=%s\n", plan9)
   228  	fmt.Printf("OK\n")
   229  	closeStdout()
   230  
   231  	for {
   232  		time.Sleep(1 * time.Minute)
   233  		createSockets(sock, plan9)
   234  	}
   235  }
   236  
   237  var connCache struct {
   238  	sync.Mutex
   239  	c []net.Conn
   240  }
   241  
   242  // TODO: Cache connections.
   243  func dialAndRunExt(sock string, msg []byte) ([]byte, error) {
   244  	connCache.Lock()
   245  	var c net.Conn
   246  	if len(connCache.c) > 0 {
   247  		c = connCache.c[len(connCache.c)-1]
   248  		connCache.c = connCache.c[:len(connCache.c)-1]
   249  	}
   250  	connCache.Unlock()
   251  	if c == nil {
   252  		var err error
   253  		log.Printf("redial %s", sock)
   254  		c, err = net.Dial("unix", sock)
   255  		if err != nil {
   256  			return nil, err
   257  		}
   258  	}
   259  	m, err, ok := runExt(c, msg)
   260  	if !ok {
   261  		c.Close()
   262  	} else {
   263  		connCache.Lock()
   264  		connCache.c = append(connCache.c, c)
   265  		connCache.Unlock()
   266  	}
   267  	return m, err
   268  }
   269  
   270  func listRemote(sock string) ([]string, error) {
   271  	data, err := dialAndRunExt(sock, []byte("list"))
   272  	if err != nil {
   273  		return nil, err
   274  	}
   275  	if len(data) == 0 {
   276  		return nil, nil
   277  	}
   278  	return strings.Split(string(data), "\x00"), nil
   279  }
   280  
   281  func closeStdout() {
   282  	fd, err := syscall.Open("/dev/null", syscall.O_RDWR, 0)
   283  	if err != nil {
   284  		log.Fatal(err)
   285  	}
   286  	syscall.Dup2(fd, 0)
   287  	if fd > 2 {
   288  		syscall.Close(fd)
   289  	}
   290  	fd, err = syscall.Open(os.Getenv("HOME")+"/.sshns.log", syscall.O_WRONLY|syscall.O_APPEND|syscall.O_CREAT, 0600)
   291  	if err != nil {
   292  		log.Fatal(err)
   293  	}
   294  	syscall.Dup2(fd, 1)
   295  	syscall.Dup2(fd, 2)
   296  	if fd > 2 {
   297  		syscall.Close(fd)
   298  	}
   299  	log.SetFlags(log.LstdFlags)
   300  }
   301  
   302  func reverseDial(sock, name string) (rc *remoteConn, err error) {
   303  	id, err := dialAndRunExt(sock, []byte("dial "+name))
   304  	if err != nil {
   305  		log.Printf("dial %s: %v", name, err)
   306  		return nil, err
   307  	}
   308  	log.Printf("dial %s -> %s\n", name, id)
   309  	r := &remoteConn{sock: sock, id: string(id)}
   310  	go r.lease()
   311  	return r, nil
   312  }
   313  
   314  type remoteConn struct {
   315  	id   string
   316  	sock string
   317  	dead uint32
   318  }
   319  
   320  const expireDelta = 10 * time.Minute
   321  
   322  func (r *remoteConn) lease() {
   323  	for atomic.LoadUint32(&r.dead) == 0 {
   324  		dialAndRunExt(r.sock, []byte("refresh "+r.id))
   325  		time.Sleep(expireDelta / 2)
   326  	}
   327  }
   328  
   329  func (r *remoteConn) Read(data []byte) (int, error) {
   330  	log.Printf("read %s %d\n", r.id, len(data))
   331  	d, err := dialAndRunExt(r.sock, []byte(fmt.Sprintf("read %d %s", len(data), r.id)))
   332  	if err != nil {
   333  		log.Printf("read %s %d: %v", r.id, len(data), err)
   334  		return 0, err
   335  	}
   336  	log.Printf("read %s %d: %d", r.id, len(data), len(d))
   337  	return copy(data, d), nil
   338  }
   339  
   340  func (r *remoteConn) Write(data []byte) (int, error) {
   341  	log.Printf("write %s %d\n", r.id, len(data))
   342  	var w int
   343  	for len(data) > 0 {
   344  		n := len(data)
   345  		if n > 10000 {
   346  			n = 10000
   347  		}
   348  		log.Printf("write1 %s %d\n", r.id, n)
   349  		_, err := dialAndRunExt(r.sock, append([]byte("write "+r.id+"\n"), data[:n]...))
   350  		if err != nil {
   351  			return w, err
   352  		}
   353  		w += n
   354  		data = data[n:]
   355  	}
   356  	return w, nil
   357  }
   358  
   359  func (r *remoteConn) Close() error {
   360  	log.Printf("close %s\n", r.id)
   361  	atomic.StoreUint32(&r.dead, 1)
   362  	_, err := dialAndRunExt(r.sock, []byte("close "+r.id))
   363  	return err
   364  }
   365  
   366  var created = map[string]bool{}
   367  
   368  func createSockets(sock, plan9 string) error {
   369  	names, err := listRemote(sock)
   370  	if err != nil {
   371  		log.Fatal(err) // probably client is gone
   372  	}
   373  	for _, name := range names {
   374  		if !created[name] {
   375  			created[name] = true
   376  			go proxySocket(sock, plan9, name)
   377  		}
   378  	}
   379  	return nil
   380  }
   381  
   382  func proxySocket(sock, plan9, name string) {
   383  	l, err := net.Listen("unix", filepath.Join(plan9, name))
   384  	if err != nil {
   385  		log.Printf("post %s: %v", name, err)
   386  		return
   387  	}
   388  
   389  	for {
   390  		c, err := l.Accept()
   391  		if err != nil {
   392  			time.Sleep(1 * time.Minute)
   393  			continue
   394  		}
   395  		c1, err := reverseDial(sock, name)
   396  		if err != nil {
   397  			c.Close()
   398  			log.Printf("reverseDial %s: %v", name, err)
   399  			continue
   400  		}
   401  		go proxy(c, c1)
   402  	}
   403  }
   404  
   405  func proxy(c, c1 io.ReadWriteCloser) {
   406  	done := make(chan bool, 2)
   407  	go func() {
   408  		io.Copy(c, c1)
   409  		c.Close()
   410  		done <- true
   411  	}()
   412  	go func() {
   413  		io.Copy(c1, c)
   414  		c1.Close()
   415  		done <- true
   416  	}()
   417  	<-done
   418  	<-done
   419  }
   420  
   421  // runs on ssh client side
   422  func client() {
   423  	// Maybe these should be quiet failures?
   424  	oldSock := os.Getenv("SSH_AUTH_SOCK")
   425  	if oldSock == "" {
   426  		if *verbose {
   427  			log.Fatal("$SSH_AUTH_SOCK not set")
   428  		}
   429  		return
   430  	}
   431  	if strings.HasSuffix(oldSock, "/sshns.socket") {
   432  		if *verbose {
   433  			log.Fatal("$SSH_AUTH_SOCK is already an ssh-namespace-agent")
   434  		}
   435  		return
   436  	}
   437  
   438  	ns := plan9client.Namespace()
   439  	if ns == "" {
   440  		log.Fatal("no plan9 namespace")
   441  	}
   442  	if err := os.MkdirAll(ns, 0700); err != nil {
   443  		log.Fatal(err)
   444  	}
   445  
   446  	// NOTE(rsc): Tried to use ssh-namespace-agent.socket,
   447  	// but combined with my Mac's current default $(namespace)
   448  	// of /tmp/ns.rsc._private_tmp_com.apple.launchd.7VN9hyV2B7_org.macosforge.xquartz:0/
   449  	// that name just barely exceeds the 104-byte limit.
   450  	// Probably the default namespace needs to be shortened,
   451  	// but to avoid requiring that, we use a shorter name.
   452  	newSock := filepath.Join(ns, "sshns.socket")
   453  	l, err := net.Listen("unix", newSock)
   454  	if err != nil {
   455  		// Maybe already running?
   456  		c, err := net.Dial("unix", newSock)
   457  		if err == nil {
   458  			c.Close()
   459  			fmt.Printf("export SSH_AUTH_SOCK=%s\n", newSock)
   460  			fmt.Printf("OK\n")
   461  			return
   462  		}
   463  		os.Remove(newSock)
   464  		l, err = net.Listen("unix", newSock)
   465  		if err != nil {
   466  			log.Fatal(err)
   467  		}
   468  	}
   469  
   470  	fmt.Printf("export SSH_AUTH_SOCK=%s\n", newSock)
   471  	fmt.Printf("OK\n")
   472  	closeStdout()
   473  
   474  	for {
   475  		c, err := l.Accept()
   476  		if err != nil {
   477  			log.Fatal(err)
   478  		}
   479  		go serve(c, oldSock, ns)
   480  	}
   481  }
   482  
   483  func serve(c net.Conn, oldSock, ns string) {
   484  	log.Printf("serving on client\n")
   485  	var c1 net.Conn
   486  	defer c.Close()
   487  	for {
   488  		m, err := readMsg(c)
   489  		if err != nil {
   490  			log.Printf("serving socket: readMsg: %v", err)
   491  			return
   492  		}
   493  		log.Printf("serve %d %d", len(m), m[0])
   494  		if !bytes.HasPrefix(m, extHeader) {
   495  			// pass message to underlying agent
   496  			if c1 == nil {
   497  				c1, err = net.Dial("unix", oldSock)
   498  				if err != nil {
   499  					log.Printf("proxying message: dial: %v", err)
   500  					return
   501  				}
   502  				defer c1.Close()
   503  			}
   504  			if err := writeMsg(c1, m); err != nil {
   505  				log.Printf("proxying message: write: %v", err)
   506  				return
   507  			}
   508  			m, err = readMsg(c1)
   509  			if err != nil {
   510  				log.Printf("proxying message: read: %v", err)
   511  				return
   512  			}
   513  			if err := writeMsg(c, m); err != nil {
   514  				log.Printf("proxying message: write back: %v", err)
   515  				return
   516  			}
   517  			continue
   518  		}
   519  		cmd, m := parseExtmsg(m[len(extHeader):])
   520  		f := strings.Fields(cmd)
   521  		if len(f) > 0 {
   522  			switch f[0] {
   523  			case "list":
   524  				handleList(c, ns)
   525  				continue
   526  			case "dial":
   527  				if len(f) == 2 {
   528  					handleDial(c, ns, f[1])
   529  					continue
   530  				}
   531  			case "close":
   532  				if len(f) == 2 {
   533  					handleClose(c, f[1])
   534  					continue
   535  				}
   536  			case "write":
   537  				if len(f) == 2 {
   538  					handleWrite(c, f[1], m)
   539  					continue
   540  				}
   541  			case "read":
   542  				if len(f) == 3 {
   543  					n, err := strconv.Atoi(f[1])
   544  					if err == nil {
   545  						handleRead(c, n, f[2])
   546  						continue
   547  					}
   548  				}
   549  			case "refresh":
   550  				if len(f) == 2 {
   551  					handleRefresh(c, f[1])
   552  					continue
   553  				}
   554  			}
   555  		}
   556  		writeExtReply(c, []byte(fmt.Sprintf("err\nunknown command %q", cmd)))
   557  	}
   558  }
   559  
   560  func handleList(c net.Conn, ns string) {
   561  	names, _ := filepath.Glob(filepath.Join(ns, "*"))
   562  	var out []string
   563  	for _, name := range names {
   564  		name = filepath.Base(name)
   565  		if !strings.HasSuffix(name, ".socket") {
   566  			out = append(out, name)
   567  		}
   568  	}
   569  	reply := []byte("ok\n" + strings.Join(out, "\x00"))
   570  	writeExtReply(c, reply)
   571  }
   572  
   573  type conn struct {
   574  	c      net.Conn
   575  	expire time.Time
   576  }
   577  
   578  var conns struct {
   579  	sync.Mutex
   580  	m map[string]*conn
   581  	n int
   582  }
   583  
   584  func init() {
   585  	go func() {
   586  		for {
   587  			time.Sleep(expireDelta)
   588  			conns.Lock()
   589  			var dead []*conn
   590  			for k, cc := range conns.m {
   591  				if time.Now().After(cc.expire) {
   592  					dead = append(dead, cc)
   593  					delete(conns.m, k)
   594  				}
   595  			}
   596  			conns.Unlock()
   597  			for _, cc := range dead {
   598  				cc.c.Close()
   599  			}
   600  		}
   601  	}()
   602  }
   603  
   604  func handleDial(c net.Conn, ns string, name string) {
   605  	c1, err := net.Dial("unix", filepath.Join(ns, name))
   606  	if err != nil {
   607  		writeExtReply(c, []byte("err\n"+err.Error()))
   608  		return
   609  	}
   610  	conns.Lock()
   611  	conns.n++
   612  	id := fmt.Sprint(conns.n)
   613  	if conns.m == nil {
   614  		conns.m = map[string]*conn{}
   615  	}
   616  	conns.m[id] = &conn{c: c1, expire: time.Now().Add(expireDelta)}
   617  	conns.Unlock()
   618  	writeExtReply(c, []byte("ok\n"+id))
   619  }
   620  
   621  func handleClose(c net.Conn, id string) {
   622  	conns.Lock()
   623  	cc := conns.m[id]
   624  	if cc != nil {
   625  		delete(conns.m, id)
   626  	}
   627  	conns.Unlock()
   628  
   629  	if cc == nil {
   630  		writeExtReply(c, []byte("err\nunknown conn"))
   631  		return
   632  	}
   633  
   634  	cc.c.Close()
   635  	writeExtReply(c, []byte("ok\n"))
   636  }
   637  
   638  func handleRead(c net.Conn, n int, id string) {
   639  	conns.Lock()
   640  	cc := conns.m[id]
   641  	if cc != nil {
   642  		cc.expire = time.Now().Add(expireDelta)
   643  	}
   644  	conns.Unlock()
   645  
   646  	if cc == nil {
   647  		writeExtReply(c, []byte("err\nunknown conn"))
   648  		return
   649  	}
   650  
   651  	log.Printf("handleRead %s %d", id, n)
   652  	buf := make([]byte, 3+n)
   653  	n, err := cc.c.Read(buf[3:])
   654  	if n > 0 {
   655  		err = nil
   656  	}
   657  	if err != nil {
   658  		writeExtReply(c, []byte("err\n"+err.Error()))
   659  		return
   660  	}
   661  	copy(buf[0:], "ok\n")
   662  	writeExtReply(c, buf[:3+n])
   663  }
   664  
   665  func handleWrite(c net.Conn, id string, data []byte) {
   666  	conns.Lock()
   667  	cc := conns.m[id]
   668  	if cc != nil {
   669  		cc.expire = time.Now().Add(expireDelta)
   670  	}
   671  	conns.Unlock()
   672  
   673  	if cc == nil {
   674  		writeExtReply(c, []byte("err\nunknown conn"))
   675  		return
   676  	}
   677  
   678  	log.Printf("handleWrite %s %d", id, len(data))
   679  	_, err := cc.c.Write(data)
   680  	if err != nil {
   681  		writeExtReply(c, []byte("err\n"+err.Error()))
   682  		return
   683  	}
   684  	writeExtReply(c, []byte("ok\n"))
   685  }
   686  
   687  func handleRefresh(c net.Conn, id string) {
   688  	conns.Lock()
   689  	cc := conns.m[id]
   690  	if cc != nil {
   691  		cc.expire = time.Now().Add(expireDelta)
   692  	}
   693  	conns.Unlock()
   694  	if cc == nil {
   695  		writeExtReply(c, []byte("err\nunknown conn"))
   696  		return
   697  	}
   698  	writeExtReply(c, []byte("ok\n"))
   699  }