github.com/devops-filetransfer/sshego@v7.0.4+incompatible/tri.go (about)

     1  package sshego
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"net"
     8  	"strings"
     9  	"time"
    10  
    11  	ssh "github.com/glycerine/sshego/xendor/github.com/glycerine/xcryptossh"
    12  )
    13  
// ErrShutdown is the sentinel error returned when an operation is
// abandoned because a stop has been requested on the Tricorder's Halt.
var ErrShutdown = fmt.Errorf("shutting down")
    15  
// Tricorder records (holds) three key objects:
//   an *ssh.Client, the underlying net.Conn, and a
//   set of ssh.Channel(s).
//
// Tricorder supports auto reconnect when disconnected.
//
// There should be exactly one Tricorder per (username, sshdHost, sshdPort) triple.
//
// Once the goroutine started by startReconnectLoop is running, it is the
// one that reads and writes cli, nc, uhp and sshChannels; other
// goroutines obtain the current values through the get*Ch channels
// (see Cli, Nc and SSHChannel).
//
type Tricorder struct {
	// Name is a label used as a prefix in this Tricorder's log messages.
	Name string

	// shuts down everything, include the cli
	Halt *ssh.Halter

	// shared with cfg
	ClientReconnectNeededTower *UHPTower

	// optional, parent can provide us
	// a Halter, and we will ParentHalt.AddDownstream(self.ChannelHalt)
	parentHalt *ssh.Halter

	// should only reflect close of the internal sshChannels, not cli nor nc.
	// This is not public because we may replace it internally during run.
	channelsHalt *ssh.Halter

	// dc is the dial configuration supplied to NewTricorder; cfg is the
	// config derived from it via dc.DeriveNewConfig().
	dc  *DialConfig
	cfg *SshegoConfig

	// sshdHostPort is "host:port", formatted from dc.Sshdhost and dc.Sshdport.
	sshdHostPort string

	// Current connection state: the ssh client, the closer obtained from
	// cli.NcCloser(), the user/host identity we are connected as, and the
	// open channels, each mapped to the cancel func of its context.
	cli         *ssh.Client
	nc          io.Closer
	uhp         *UHP
	sshChannels map[net.Conn]context.CancelFunc

	// Channels serviced by the reconnect loop: requests for a new ssh
	// channel, hand-outs of the current cli and nc, and reconnect
	// notifications delivered via ClientReconnectNeededTower.
	getChannelCh      chan *getChannelTicket
	getCliCh          chan *ssh.Client
	getNcCh           chan io.Closer
	reconnectNeededCh chan *UHP

	// tofu starts as dc.TofuAddIfNotKnown and is cleared by
	// helperNewClientConnect once a dial succeeds or the tofu case has
	// been auto-handled.
	tofu bool

	retries             int           // example: 10
	pauseBetweenRetries time.Duration // example: 1000 * time.Millisecond

	// lastConnectTime is a timestamp recorded by helperNewClientConnect;
	// the reconnect loop ignores reconnect requests that arrive less than
	// one second after it.
	lastConnectTime time.Time
}
    63  
    64  /*
    65  NewTricorder has got to wait to allocate
    66  ssh.Channel until requested. Otherwise we
    67  make too many, and get them mixed up.
    68  */
    69  func NewTricorder(dc *DialConfig, halt *ssh.Halter, name string) (tri *Tricorder, err error) {
    70  
    71  	cfg, err := dc.DeriveNewConfig()
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  	sshdHostPort := fmt.Sprintf("%v:%v", dc.Sshdhost, dc.Sshdport)
    76  
    77  	tri = &Tricorder{
    78  		Name:         name,
    79  		dc:           dc,
    80  		cfg:          cfg,
    81  		sshdHostPort: sshdHostPort,
    82  		parentHalt:   halt,
    83  		Halt:         ssh.NewHalter(),
    84  		channelsHalt: ssh.NewHalter(),
    85  
    86  		sshChannels: make(map[net.Conn]context.CancelFunc),
    87  
    88  		reconnectNeededCh:   make(chan *UHP, 1),
    89  		getChannelCh:        make(chan *getChannelTicket),
    90  		getCliCh:            make(chan *ssh.Client),
    91  		getNcCh:             make(chan io.Closer),
    92  		tofu:                dc.TofuAddIfNotKnown,
    93  		retries:             10,
    94  		pauseBetweenRetries: 1000 * time.Millisecond,
    95  	}
    96  	tri.uhp = &UHP{
    97  		User:     tri.dc.Mylogin,
    98  		HostPort: tri.sshdHostPort,
    99  		Nickname: tri.dc.DestNickname,
   100  	}
   101  
   102  	if tri.parentHalt != nil {
   103  		tri.parentHalt.AddDownstream(tri.Halt)
   104  	}
   105  	tri.Halt.AddDownstream(tri.channelsHalt)
   106  	cfg.ClientReconnectNeededTower.Subscribe(tri.reconnectNeededCh)
   107  	tri.ClientReconnectNeededTower = cfg.ClientReconnectNeededTower
   108  
   109  	err = tri.startReconnectLoop()
   110  	if err != nil {
   111  		return nil, err
   112  	}
   113  	return tri, nil
   114  }
   115  
// CustomInprocStreamChanName is how sshego/reptile specific
// channels are named. It is currently set to "direct-tcpip", the
// same type string that helperGetChannel treats as a TCP dial; an
// alternative value, "custom-inproc-stream", is left commented out
// below.
//const CustomInprocStreamChanName = "custom-inproc-stream"
const CustomInprocStreamChanName = "direct-tcpip"
   120  
   121  func (t *Tricorder) closeChannels() {
   122  	if len(t.sshChannels) > 0 {
   123  		for ch, cancel := range t.sshChannels {
   124  			ch.Close()
   125  			if cancel != nil {
   126  				cancel()
   127  			}
   128  		}
   129  	}
   130  	t.sshChannels = make(map[net.Conn]context.CancelFunc)
   131  }
   132  
   133  func (t *Tricorder) startReconnectLoop() error {
   134  
   135  	// do the initial connect.
   136  	err := t.helperNewClientConnect(context.Background())
   137  	if err != nil {
   138  		return err
   139  	}
   140  
   141  	go func() {
   142  		defer func() {
   143  			t.channelsHalt.RequestStop()
   144  			t.channelsHalt.MarkDone()
   145  			t.Halt.RequestStop()
   146  			t.Halt.MarkDone()
   147  			if t.parentHalt != nil {
   148  				t.parentHalt.RemoveDownstream(t.Halt)
   149  			}
   150  			t.closeChannels()
   151  		}()
   152  		for {
   153  			select {
   154  			case <-t.Halt.ReqStopChan():
   155  				return
   156  			case uhp := <-t.reconnectNeededCh:
   157  				pp("%s Tricorder sees reconnectNeeded to '%#v'!!", uhp, t.Name)
   158  
   159  				if uhp.User != t.uhp.User {
   160  					panic(fmt.Sprintf("%s yikes, bad! uhp from reconnectNeededChan asks for change of user: '%v' != '%v' previous", t.Name, uhp.User, t.uhp.User))
   161  				}
   162  				if uhp.HostPort != t.uhp.HostPort {
   163  					panic(fmt.Sprintf("%s yikes, bad! uhp from reconnectNeededChan asks for change of hostport: '%v' != '%v' previous", t.Name, uhp.HostPort, t.uhp.HostPort))
   164  				}
   165  				now := time.Now()
   166  				if now.Sub(t.lastConnectTime) < time.Second {
   167  					pp("%s Tricorder ignoring reconnectNeeded within "+
   168  						"1 second of successful connection.", t.Name)
   169  					continue
   170  				}
   171  				t.uhp = uhp
   172  				t.closeChannels()
   173  
   174  				t.channelsHalt.RequestStop()
   175  				t.channelsHalt.MarkDone()
   176  
   177  				t.Halt.RemoveDownstream(t.channelsHalt)
   178  				t.channelsHalt = ssh.NewHalter()
   179  				t.Halt.AddDownstream(t.channelsHalt)
   180  
   181  				t.cli = nil
   182  				t.nc = nil
   183  				// need to reconnect!
   184  				ctx := context.Background()
   185  				err := t.helperNewClientConnect(ctx)
   186  				if err == ErrShutdown {
   187  					return
   188  				}
   189  				panicOn(err)
   190  
   191  				// provide current state
   192  			case t.getCliCh <- t.cli:
   193  			case t.getNcCh <- t.nc:
   194  				pp("%s tri sent t.nc='%#v'", t.Name, t.nc)
   195  
   196  				// bring up a new channel
   197  			case tk := <-t.getChannelCh:
   198  				t.helperGetChannel(tk)
   199  			}
   200  		}
   201  	}()
   202  	return nil
   203  }
   204  
   205  // only reconnect, don't open any new channels!
   206  func (t *Tricorder) helperNewClientConnect(ctx context.Context) (err error) {
   207  
   208  	pp("%s Tricorder.helperNewClientConnect starting! t.uhp='%#v'.", t.Name, t.uhp)
   209  
   210  	defer func() {
   211  		if err != nil {
   212  			t.lastConnectTime = time.Now()
   213  		}
   214  	}()
   215  
   216  	destHost, port, err := SplitHostPort(t.uhp.HostPort)
   217  	_, _ = destHost, port
   218  	if err != nil {
   219  		return err
   220  	}
   221  
   222  	// TODO: pw & totpUrl currently required in the test... change this.
   223  	//pw := t.dc.Pw
   224  	//totpUrl := t.dc.TotpUrl
   225  
   226  	//t.cfg.AddIfNotKnown = false
   227  	var sshcli *ssh.Client
   228  	tries := t.retries
   229  	pause := t.pauseBetweenRetries
   230  	if t.cfg.KnownHosts == nil {
   231  		panic("problem! t.cfg.KnownHosts is nil")
   232  	}
   233  	if t.cfg.PrivateKeyPath == "" {
   234  		panic("problem! t.cfg.PrivateKeyPath is empty")
   235  	}
   236  
   237  	var okCtx context.Context
   238  
   239  	for i := 0; i < tries; i++ {
   240  		pp("%s Tricorder.helperNewClientConnect() calling t.dc.Dial(), i=%v", t.Name, i)
   241  
   242  		// check for shutdown request
   243  		select {
   244  		case <-t.Halt.ReqStopChan():
   245  			return ErrShutdown
   246  		default:
   247  		}
   248  
   249  		ctxChild, cancelChildCtx := context.WithCancel(ctx)
   250  
   251  		//t.cfg.AddIfNotKnown = t.tofu
   252  		//t.dc.TofuAddIfNotKnown = t.tofu
   253  
   254  		_, sshcli, _, err = t.dc.Dial(ctxChild, t.cfg, true)
   255  		if err == nil {
   256  			t.tofu = false
   257  			t.cfg.AddIfNotKnown = false
   258  			okCtx = ctxChild
   259  
   260  			if sshcli == nil {
   261  				panic("err must not be nil if sshcli is nil, back from cfg.SSHConnect")
   262  			}
   263  			break
   264  		} else {
   265  			cancelChildCtx()
   266  			errs := err.Error()
   267  			if strings.Contains(errs, "Re-run without -new") {
   268  				if t.tofu {
   269  					p("auto-handling tofu b/c t.tofu is true")
   270  					t.tofu = false
   271  					t.dc.TofuAddIfNotKnown = false
   272  					t.cfg.AddIfNotKnown = false
   273  					continue
   274  				}
   275  				return err
   276  			}
   277  			if strings.Contains(errs, "getsockopt: connection refused") {
   278  				pp("%s Tricorder.helperNewClientConnect: ignoring 'connection refused' and retrying after %v. connecting to '%#v'", t.Name, pause, t.uhp)
   279  				time.Sleep(pause)
   280  				continue
   281  			}
   282  			pp("%s Tricorder: err = '%v'. retrying after %v", t.Name, err, pause)
   283  			time.Sleep(pause)
   284  			continue
   285  		}
   286  	} // end i over tries
   287  
   288  	if sshcli != nil && okCtx != nil {
   289  		sshcli.TmpCtx = okCtx
   290  	}
   291  	if err != nil {
   292  		return err
   293  	}
   294  	pp("good: %s Tricorder.helperNewClientConnect succeeded to '%#v'.", t.Name, t.uhp)
   295  	t.cli = sshcli
   296  	if t.cli != nil {
   297  		t.nc = t.cli.NcCloser()
   298  	} else {
   299  		panic("why no NcCloser()???")
   300  	}
   301  	return nil
   302  }
   303  
   304  func (t *Tricorder) helperGetChannel(tk *getChannelTicket) {
   305  
   306  	pp("%s Tricorder.helperGetChannel starting! t.uhp='%#v'", t.Name, t.uhp)
   307  
   308  	var ch ssh.Channel
   309  	var in <-chan *ssh.Request
   310  	var err error
   311  	if t.cli == nil {
   312  		pp("%s Tricorder.helperGetChannel: saw nil cli, so making new client", t.Name)
   313  		err = t.helperNewClientConnect(tk.ctx)
   314  		if err != nil {
   315  			tk.err = err
   316  			close(tk.done)
   317  			return
   318  		}
   319  	}
   320  
   321  	pp("%s Tricorder.helperGetChannel: had cli already, so calling t.cli.Dial()", t.Name)
   322  	discardCtx, discardCtxCancel := context.WithCancel(tk.ctx)
   323  
   324  	if tk.typ == "direct-tcpip" {
   325  		hp := strings.Trim(tk.targetHostPort, "\n\r\t ")
   326  
   327  		pp("%s Tricorder.helperGetChannel dialing hp='%v'", t.Name, hp)
   328  		ch, err = t.cli.DialWithContext(discardCtx, "tcp", hp)
   329  
   330  	} else {
   331  
   332  		ch, in, err = t.cli.OpenChannel(tk.ctx, tk.typ, nil, t.channelsHalt)
   333  		if err == nil {
   334  			go DiscardRequestsExceptKeepalives(discardCtx, in, t.channelsHalt.ReqStopChan())
   335  		}
   336  	}
   337  	if ch != nil {
   338  		t.sshChannels[ch] = discardCtxCancel
   339  
   340  		if t.cfg.IdleTimeoutDur > 0 {
   341  			sshChan, ok := ch.(ssh.Channel)
   342  			if ok {
   343  				sshChan.SetIdleTimeout(t.cfg.IdleTimeoutDur)
   344  			}
   345  		}
   346  	}
   347  
   348  	tk.sshChannel = ch
   349  	tk.err = err
   350  
   351  	close(tk.done)
   352  }
   353  
// getChannelTicket is a request for a new ssh channel, handed to the
// reconnect loop over getChannelCh. The requester fills in typ,
// targetHostPort and ctx; helperGetChannel fills in sshChannel and err
// and then closes done to signal that the result is ready.
type getChannelTicket struct {
	done           chan struct{}
	sshChannel     ssh.Channel
	targetHostPort string // leave empty for "custom-inproc-stream", else downstream addr
	typ            string // "direct-tcpip" or "custom-inproc-stream"
	err            error
	ctx            context.Context
}
   362  
   363  func newGetChannelTicket(ctx context.Context) *getChannelTicket {
   364  	return &getChannelTicket{
   365  		done: make(chan struct{}),
   366  		ctx:  ctx,
   367  	}
   368  }
   369  
   370  // typ can be "direct-tcpip" (specify destHostPort), or "custom-inproc-stream"
   371  // in which case leave destHostPort as the empty string.
   372  func (t *Tricorder) SSHChannel(ctx context.Context, typ, targetHostPort string) (ssh.Channel, error) {
   373  	tk := newGetChannelTicket(ctx)
   374  	tk.typ = typ
   375  	tk.targetHostPort = targetHostPort
   376  	t.getChannelCh <- tk
   377  	<-tk.done
   378  	return tk.sshChannel, tk.err
   379  }
   380  
   381  func (t *Tricorder) Cli() (cli *ssh.Client, err error) {
   382  	select {
   383  	case cli = <-t.getCliCh:
   384  	case <-t.Halt.ReqStopChan():
   385  		err = ErrShutdown
   386  	}
   387  	return
   388  }
   389  
   390  func (t *Tricorder) Nc() (nc io.Closer, err error) {
   391  	select {
   392  	case nc = <-t.getNcCh:
   393  	case <-t.Halt.ReqStopChan():
   394  		err = ErrShutdown
   395  	}
   396  	return
   397  }