github.com/devops-filetransfer/sshego@v7.0.4+incompatible/cli.go (about)

     1  package sshego
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"log"
     7  	"net"
     8  	"strings"
     9  	"time"
    10  
    11  	ssh "github.com/glycerine/sshego/xendor/github.com/glycerine/xcryptossh"
    12  )
    13  
    14  //go:generate greenpack
    15  
    16  //msgp:ignore DialConfig
    17  
// DialConfig provides Dial() with what
// it needs in order to establish an encrypted
// and authenticated ssh connection.
type DialConfig struct {

	// ClientKnownHostsPath is the path to the file
	// on client's disk that holds the known server keys.
	ClientKnownHostsPath string

	// KnownHosts is cached to avoid a disk read; we only read
	// from ClientKnownHostsPath if KnownHosts is nil.
	// Users of DialConfig can leave this nil and
	// simply provide ClientKnownHostsPath. It is
	// exposed in case you need to invalidate the
	// cache and start again.
	KnownHosts *KnownHosts

	// Mylogin is the username to login under.
	Mylogin string

	// RsaPath is the path on the local file system (client side)
	// from which to read the client's RSA private key.
	RsaPath string

	// TotpUrl is the time-based one-time password configuration.
	TotpUrl string

	// Pw is the passphrase.
	Pw string

	// Sshdhost and Sshdport say which sshd to connect to,
	// host and port.
	Sshdhost string
	Sshdport int64

	// DownstreamHostPort is the host:port string of
	// the tcp address to which the sshd should forward
	// our connection.
	DownstreamHostPort string

	// TofuAddIfNotKnown, for maximum security,
	// should always be left false, and
	// the host key database should be configured
	// manually. If true, the client trusts the server's
	// provided key and stores it, which creates
	// vulnerability to a MITM attack.
	//
	// TOFU stands for Trust-On-First-Use.
	//
	// If set to true, Dial() will stop
	// after storing a new key, or error
	// out if the key is already known.
	// In either case, a 2nd attempt at
	// Dial is required, in which
	// TofuAddIfNotKnown is set to false.
	TofuAddIfNotKnown bool

	// DoNotUpdateSshKnownHosts prevents writing
	// to the file given by ClientKnownHostsPath, if true.
	DoNotUpdateSshKnownHosts bool

	// Verbose is copied to the Debug field of the
	// SshegoConfig made by DeriveNewConfig.
	Verbose bool

	// TestAllowOneshotConnect is for tests only; see SshegoConfig.
	TestAllowOneshotConnect bool

	// SkipKeepAlive defaults to false, in which case we send
	// a keepalive every so often.
	SkipKeepAlive bool

	// KeepAliveEvery is the interval between keepalives.
	// Values <= 0 mean the default of 1 second.
	KeepAliveEvery time.Duration // default 1 second

	// LocalNickname identifies who is calling.
	LocalNickname string

	// DestNickname names the remote destination for sshdhost.
	DestNickname string
}
    97  
// Dial is a convenience method for contacting an sshd
// over tcp and creating a direct-tcpip encrypted stream.
// It is a simple two-step sequence of calling
// cfg.SSHConnect() and then calling Dial() on the
// returned *ssh.Client.
//
// PRE: dc.KnownHosts should already be instantiated.
// To prevent MITM attacks, the host we contact at
// hostport must already have its server key
// in the KnownHosts.
//
// dc.RsaPath is the path to our (the client's) rsa
// private key file.
//
// dc.DownstreamHostPort is the host:port tcp address string
// to which the sshd should forward our connection after successful
// authentication.
//
   116  
   117  func (dc *DialConfig) DeriveNewConfig() (cfg *SshegoConfig, err error) {
   118  
   119  	cfg = NewSshegoConfig()
   120  	cfg.Nickname = dc.LocalNickname
   121  	cfg.BitLenRSAkeys = 4096
   122  	cfg.DirectTcp = true
   123  	cfg.AddIfNotKnown = dc.TofuAddIfNotKnown
   124  	cfg.Debug = dc.Verbose
   125  	cfg.TestAllowOneshotConnect = dc.TestAllowOneshotConnect
   126  	cfg.IdleTimeoutDur = 5 * time.Second
   127  	if !dc.SkipKeepAlive {
   128  		if dc.KeepAliveEvery <= 0 {
   129  			cfg.KeepAliveEvery = time.Second // default to 1 sec.
   130  		} else {
   131  			cfg.KeepAliveEvery = dc.KeepAliveEvery
   132  		}
   133  	}
   134  
   135  	p("DialConfig.Dial: dc= %#v\n", dc)
   136  	if dc.KnownHosts == nil {
   137  		dc.KnownHosts, err = NewKnownHosts(dc.ClientKnownHostsPath, KHSsh)
   138  		if err != nil {
   139  			return nil, err
   140  		}
   141  		p("after NewKnownHosts: DialConfig.Dial: dc.KnownHosts = %#v\n", dc.KnownHosts)
   142  		dc.KnownHosts.NoSave = dc.DoNotUpdateSshKnownHosts
   143  	}
   144  	cfg.KnownHosts = dc.KnownHosts
   145  	cfg.PrivateKeyPath = dc.RsaPath
   146  	return cfg, nil
   147  }
   148  
   149  // cfg0 can be nil, in which case we will make
   150  // a new SshegoConfig and return it in cfg. If
   151  // cfg0 is not nil, then we use it and return
   152  // it in cfg.
   153  func (dc *DialConfig) Dial(parCtx context.Context, cfg0 *SshegoConfig, skipDownstream bool) (nc net.Conn, sshClient *ssh.Client, cfg *SshegoConfig, err error) {
   154  
   155  	if cfg0 == nil {
   156  		cfg, err = dc.DeriveNewConfig()
   157  		if err != nil {
   158  			return
   159  		}
   160  	} else {
   161  		cfg = cfg0
   162  	}
   163  	p("about to SSHConnect to dc.Sshdhost='%s'", dc.Sshdhost)
   164  	p("  ...and SSHConnect called on cfg = '%#v'\n", cfg)
   165  
   166  	// connection refused errors are common enough
   167  	// that we do a simple retry logic after a brief pause here.
   168  	retryCount := 3
   169  	try := 0
   170  	var okCtx context.Context
   171  	var okHalt *ssh.Halter
   172  
   173  	for ; try < retryCount; try++ {
   174  		ctx, cancelctx := context.WithCancel(parCtx)
   175  		childHalt := ssh.NewHalter()
   176  		// the 2nd argument is the underlying most-basic
   177  		// TCP net.Conn. We don't need to retrieve here since
   178  		// ctx or cfg.Halt will close it for us if need be.
   179  		sshClient, _, err = cfg.SSHConnect(ctx, dc.KnownHosts,
   180  			dc.Mylogin, dc.RsaPath, dc.Sshdhost, dc.Sshdport,
   181  			dc.Pw, dc.TotpUrl, childHalt)
   182  		if err == nil {
   183  			// tie ctx and childHalt together
   184  			go ssh.MAD(ctx, cancelctx, childHalt)
   185  			okCtx = ctx
   186  			okHalt = childHalt
   187  			break
   188  		} else {
   189  			cancelctx()
   190  			childHalt.RequestStop()
   191  			childHalt.MarkDone()
   192  			if strings.Contains(err.Error(), "getsockopt: connection refused") {
   193  				// simple connection error, just try again in a bit
   194  				time.Sleep(10 * time.Millisecond)
   195  				continue
   196  			}
   197  			break
   198  		}
   199  	}
   200  	if err != nil {
   201  		return nil, nil, nil, err
   202  	}
   203  	// enforce safe known-hosts hygene
   204  	//cfg.TestAllowOneshotConnect = false
   205  	//cfg.AddIfNotKnown = false
   206  	//dc.TofuAddIfNotKnown = false
   207  
   208  	if skipDownstream {
   209  		return nil, sshClient, cfg, err
   210  	}
   211  
   212  	// Here is how to dial over an encrypted ssh channel.
   213  	// This produces direct-tcpip forwarding -- in other
   214  	// words we talk to the server at dest via the sshd,
   215  	// but no other port is opened and so we have
   216  	// exclusive access. This locally prevents other users and
   217  	// their processes on this localhost from also
   218  	// using the ssh connection (i.e. without authenticating).
   219  	// The local end of a simple tunnel is vulnerable to
   220  	// such issues.
   221  
   222  	hp := strings.Trim(dc.DownstreamHostPort, "\n\r\t ")
   223  	tryUnixDomain := false
   224  	var host string
   225  	if strings.HasSuffix(hp, ":-2") {
   226  		tryUnixDomain = true
   227  		host = hp[:len(hp)-3]
   228  	} else {
   229  		host, _, err = net.SplitHostPort(hp)
   230  	}
   231  	if err != nil {
   232  		if strings.Contains(err.Error(), "missing port in address") {
   233  			// probably unix-domain
   234  			tryUnixDomain = true
   235  			host = hp
   236  		} else {
   237  			log.Printf("error from net.SplitHostPort on '%s': '%v'",
   238  				hp, err)
   239  			return nil, nil, nil, fmt.Errorf("error from net.SplitHostPort "+
   240  				"on '%s': '%v'", hp, err)
   241  		}
   242  	}
   243  	if tryUnixDomain || (len(host) > 0 && host[0] == '/') {
   244  		// a unix-domain socket request
   245  		nc, err = DialRemoteUnixDomain(okCtx, sshClient, host, okHalt)
   246  		p("DialRemoteUnixDomain had error '%v'", err)
   247  		return nc, sshClient, cfg, err
   248  	}
   249  	sshClient.TmpCtx = okCtx
   250  	nc, err = sshClient.Dial("tcp", hp)
   251  
   252  	return nc, sshClient, cfg, err
   253  }
   254  
// KeepAlivePing is the payload carried by the
// "keepalive@sshego.glycerine.github.com" global request
// and by its reply.
//
// NOTE(review): the zid tags are presumably field ids for the
// greenpack-generated MarshalMsg/UnmarshalMsg methods -- confirm
// before renumbering.
type KeepAlivePing struct {
	// Sent is stamped by the sender just before the ping goes out.
	Sent time.Time `zid:"0"`
	// Replied is stamped by customHandleGlobalRequests before it
	// echoes the ping back.
	Replied time.Time `zid:"1"`
	// Serial is the sender's ping counter.
	Serial int64 `zid:"2"`
}
   260  
   261  // startKeepalives starts a background goroutine
   262  // that will send a keepalive on sshClientConn
   263  // every dur (default every second).
   264  //
   265  func (cfg *SshegoConfig) startKeepalives(ctx context.Context, dur time.Duration, sshClientConn *ssh.Client, uhp *UHP) error {
   266  	if dur <= 0 {
   267  		panic(fmt.Sprintf("cannot call startKeepalives with dur <= 0: dur=%v", dur))
   268  	}
   269  
   270  	serial := int64(0)
   271  	var ping KeepAlivePing
   272  	ping.Sent = time.Now()
   273  	pingBy, err := ping.MarshalMsg(nil)
   274  	panicOn(err)
   275  	serial++
   276  
   277  	responseStatus, responsePayload, err := sshClientConn.SendRequest(ctx, "keepalive@sshego.glycerine.github.com", true, pingBy)
   278  	if err != nil {
   279  		return err
   280  	}
   281  	//pp("startKeepalives: have responseStatus: '%v'", responseStatus)
   282  
   283  	if responseStatus {
   284  		n := len(responsePayload)
   285  		if n > 0 {
   286  			var ping2 KeepAlivePing
   287  			_, err := ping2.UnmarshalMsg(responsePayload)
   288  			if err == nil {
   289  				//pp("startKeepalives: have responsePayload.Replied: '%v'/serial=%v. at now='%v'", ping2.Replied, ping2.Serial, time.Now())
   290  			}
   291  		}
   292  	}
   293  	go func() {
   294  		for {
   295  			select {
   296  			case <-time.After(dur):
   297  				ping.Sent = time.Now()
   298  				ping.Serial = serial
   299  				serial++
   300  				pingBy, err := ping.MarshalMsg(nil)
   301  				panicOn(err)
   302  
   303  				responseStatus, responsePayload, err := sshClientConn.SendRequest(
   304  					ctx, "keepalive@sshego.glycerine.github.com", true, pingBy)
   305  				if err != nil {
   306  					log.Printf("%s startKeepalives: keepalive send error: '%v', notifying reconnect needed to '%#v'", cfg.Nickname, err, uhp)
   307  					// notify here
   308  					cfg.ClientReconnectNeededTower.Broadcast(uhp)
   309  					//pp("SshegoConfig.startKeepalives() goroutine exiting!")
   310  					return
   311  				}
   312  				//pp("startKeepalives: have responseStatus: '%v'", responseStatus)
   313  
   314  				if responseStatus {
   315  					n := len(responsePayload)
   316  					if n > 0 {
   317  						var ping3 KeepAlivePing
   318  						_, err := ping3.UnmarshalMsg(responsePayload)
   319  						if err == nil {
   320  							//p("startKeepalives: have "+
   321  							//	"responsePayload.Replied: '%v'/serial=%v. at now='%v'",
   322  							//	ping3.Replied, ping3.Serial, time.Now())
   323  						}
   324  					}
   325  				} else {
   326  					// !responseStatus
   327  				}
   328  
   329  			case <-sshClientConn.Halt.ReqStopChan():
   330  				return
   331  			}
   332  		}
   333  	}()
   334  	return nil
   335  }
   336  
   337  // derived from ssh.NewClient: NewSSHClient creates a Client on top of the given connection.
   338  func (cfg *SshegoConfig) NewSSHClient(ctx context.Context, c ssh.Conn, chans <-chan ssh.NewChannel, reqs <-chan *ssh.Request, halt *ssh.Halter) *ssh.Client {
   339  	conn := &ssh.Client{
   340  		Conn:            c,
   341  		ChannelHandlers: make(map[string]chan ssh.NewChannel, 1),
   342  		Halt:            halt,
   343  	}
   344  
   345  	// replace conn.HandleGlobalRequests with custom handler.
   346  	//go conn.HandleGlobalRequests(ctx, reqs)
   347  	go customHandleGlobalRequests(ctx, conn, reqs)
   348  
   349  	go conn.HandleChannelOpens(ctx, chans)
   350  	go func() {
   351  		conn.Wait()
   352  		conn.Forwards.CloseAll()
   353  	}()
   354  	go conn.Forwards.HandleChannels(ctx, conn.HandleChannelOpen("forwarded-tcpip"), c)
   355  	go conn.Forwards.HandleChannels(ctx, conn.HandleChannelOpen("forwarded-streamlocal@openssh.com"), c)
   356  
   357  	// custom-inproc-stream is how reptile replication requests are sent,
   358  	// originating from the server and sent to the client.
   359  	if len(cfg.CustomChannelHandlers) > 0 && cfg.CustomChannelHandlers["custom-inproc-stream"] != nil {
   360  		var ca *ConnectionAlert
   361  		// or ???
   362  		//		ca := &ConnectionAlert{
   363  		//			PortOne:  make(chan ssh.Channel),
   364  		//			ShutDown: cfg.Halt.ReqStopChan(),
   365  		//		}
   366  
   367  		newChanChan := conn.HandleChannelOpen("custom-inproc-stream")
   368  		if newChanChan != nil {
   369  			go cfg.handleChannels(ctx, newChanChan, c, ca)
   370  		}
   371  	}
   372  
   373  	return conn
   374  }
   375  
   376  func customHandleGlobalRequests(ctx context.Context, sshCli *ssh.Client, incoming <-chan *ssh.Request) {
   377  
   378  	for {
   379  		select {
   380  		case r := <-incoming:
   381  			if r == nil {
   382  				continue
   383  			}
   384  			log.Printf("customHandleGlobalRequests sees request r='%#v'", r)
   385  			if r.Type != "keepalive@sshego.glycerine.github.com" || len(r.Payload) == 0 {
   386  				// This handles keepalive messages and matches
   387  				// the behaviour of OpenSSH.
   388  				r.Reply(false, nil)
   389  				continue
   390  			}
   391  
   392  			var ping KeepAlivePing
   393  			_, err := ping.UnmarshalMsg(r.Payload)
   394  			if err != nil {
   395  				r.Reply(false, nil)
   396  				continue
   397  			}
   398  
   399  			now := time.Now()
   400  			log.Printf("customHandleGlobalRequests sees keepalive! ping: '%#v'. setting replied to now='%v'", ping, now)
   401  
   402  			ping.Replied = now
   403  			pingReplyBy, err := ping.MarshalMsg(nil)
   404  			panicOn(err)
   405  			r.Reply(true, pingReplyBy)
   406  
   407  		case <-sshCli.Halt.ReqStopChan():
   408  			return
   409  		case <-sshCli.Conn.Done():
   410  			return
   411  		case <-ctx.Done():
   412  			return
   413  		}
   414  	}
   415  }