github.com/devops-filetransfer/sshego@v7.0.4+incompatible/tri.go (about) 1 package sshego 2 3 import ( 4 "context" 5 "fmt" 6 "io" 7 "net" 8 "strings" 9 "time" 10 11 ssh "github.com/glycerine/sshego/xendor/github.com/glycerine/xcryptossh" 12 ) 13 14 var ErrShutdown = fmt.Errorf("shutting down") 15 16 // Tricorder records (holds) three key objects: 17 // an *ssh.Client, the underlyign net.Conn, and a 18 // set of ssh.Channel(s). 19 // 20 // Tricorder supports auto reconnect when disconnected. 21 // 22 // There should be exactly one Tricorder per (username, sshdHost, sshdPort) triple. 23 // 24 type Tricorder struct { 25 Name string 26 27 // shuts down everything, include the cli 28 Halt *ssh.Halter 29 30 // shared with cfg 31 ClientReconnectNeededTower *UHPTower 32 33 // optional, parent can provide us 34 // a Halter, and we will ParentHalt.AddDownstream(self.ChannelHalt) 35 parentHalt *ssh.Halter 36 37 // should only reflect close of the internal sshChannels, not cli nor nc. 38 // This is not public because we may replace it internally during run. 39 channelsHalt *ssh.Halter 40 41 dc *DialConfig 42 cfg *SshegoConfig 43 44 sshdHostPort string 45 46 cli *ssh.Client 47 nc io.Closer 48 uhp *UHP 49 sshChannels map[net.Conn]context.CancelFunc 50 51 getChannelCh chan *getChannelTicket 52 getCliCh chan *ssh.Client 53 getNcCh chan io.Closer 54 reconnectNeededCh chan *UHP 55 56 tofu bool 57 58 retries int // example: 10 59 pauseBetweenRetries time.Duration // example: 1000 * time.Millisecond 60 61 lastConnectTime time.Time 62 } 63 64 /* 65 NewTricorder has got to wait to allocate 66 ssh.Channel until requested. Otherwise we 67 make too many, and get them mixed up. 68 */ 69 func NewTricorder(dc *DialConfig, halt *ssh.Halter, name string) (tri *Tricorder, err error) { 70 71 cfg, err := dc.DeriveNewConfig() 72 if err != nil { 73 return nil, err 74 } 75 sshdHostPort := fmt.Sprintf("%v:%v", dc.Sshdhost, dc.Sshdport) 76 77 tri = &Tricorder{ 78 Name: name, 79 dc: dc, 80 cfg: cfg, 81 sshdHostPort: sshdHostPort, 82 parentHalt: halt, 83 Halt: ssh.NewHalter(), 84 channelsHalt: ssh.NewHalter(), 85 86 sshChannels: make(map[net.Conn]context.CancelFunc), 87 88 reconnectNeededCh: make(chan *UHP, 1), 89 getChannelCh: make(chan *getChannelTicket), 90 getCliCh: make(chan *ssh.Client), 91 getNcCh: make(chan io.Closer), 92 tofu: dc.TofuAddIfNotKnown, 93 retries: 10, 94 pauseBetweenRetries: 1000 * time.Millisecond, 95 } 96 tri.uhp = &UHP{ 97 User: tri.dc.Mylogin, 98 HostPort: tri.sshdHostPort, 99 Nickname: tri.dc.DestNickname, 100 } 101 102 if tri.parentHalt != nil { 103 tri.parentHalt.AddDownstream(tri.Halt) 104 } 105 tri.Halt.AddDownstream(tri.channelsHalt) 106 cfg.ClientReconnectNeededTower.Subscribe(tri.reconnectNeededCh) 107 tri.ClientReconnectNeededTower = cfg.ClientReconnectNeededTower 108 109 err = tri.startReconnectLoop() 110 if err != nil { 111 return nil, err 112 } 113 return tri, nil 114 } 115 116 // CustomInprocStreamChanName is how sshego/reptile specific 117 // channels are named. 118 //const CustomInprocStreamChanName = "custom-inproc-stream" 119 const CustomInprocStreamChanName = "direct-tcpip" 120 121 func (t *Tricorder) closeChannels() { 122 if len(t.sshChannels) > 0 { 123 for ch, cancel := range t.sshChannels { 124 ch.Close() 125 if cancel != nil { 126 cancel() 127 } 128 } 129 } 130 t.sshChannels = make(map[net.Conn]context.CancelFunc) 131 } 132 133 func (t *Tricorder) startReconnectLoop() error { 134 135 // do the initial connect. 136 err := t.helperNewClientConnect(context.Background()) 137 if err != nil { 138 return err 139 } 140 141 go func() { 142 defer func() { 143 t.channelsHalt.RequestStop() 144 t.channelsHalt.MarkDone() 145 t.Halt.RequestStop() 146 t.Halt.MarkDone() 147 if t.parentHalt != nil { 148 t.parentHalt.RemoveDownstream(t.Halt) 149 } 150 t.closeChannels() 151 }() 152 for { 153 select { 154 case <-t.Halt.ReqStopChan(): 155 return 156 case uhp := <-t.reconnectNeededCh: 157 pp("%s Tricorder sees reconnectNeeded to '%#v'!!", uhp, t.Name) 158 159 if uhp.User != t.uhp.User { 160 panic(fmt.Sprintf("%s yikes, bad! uhp from reconnectNeededChan asks for change of user: '%v' != '%v' previous", t.Name, uhp.User, t.uhp.User)) 161 } 162 if uhp.HostPort != t.uhp.HostPort { 163 panic(fmt.Sprintf("%s yikes, bad! uhp from reconnectNeededChan asks for change of hostport: '%v' != '%v' previous", t.Name, uhp.HostPort, t.uhp.HostPort)) 164 } 165 now := time.Now() 166 if now.Sub(t.lastConnectTime) < time.Second { 167 pp("%s Tricorder ignoring reconnectNeeded within "+ 168 "1 second of successful connection.", t.Name) 169 continue 170 } 171 t.uhp = uhp 172 t.closeChannels() 173 174 t.channelsHalt.RequestStop() 175 t.channelsHalt.MarkDone() 176 177 t.Halt.RemoveDownstream(t.channelsHalt) 178 t.channelsHalt = ssh.NewHalter() 179 t.Halt.AddDownstream(t.channelsHalt) 180 181 t.cli = nil 182 t.nc = nil 183 // need to reconnect! 184 ctx := context.Background() 185 err := t.helperNewClientConnect(ctx) 186 if err == ErrShutdown { 187 return 188 } 189 panicOn(err) 190 191 // provide current state 192 case t.getCliCh <- t.cli: 193 case t.getNcCh <- t.nc: 194 pp("%s tri sent t.nc='%#v'", t.Name, t.nc) 195 196 // bring up a new channel 197 case tk := <-t.getChannelCh: 198 t.helperGetChannel(tk) 199 } 200 } 201 }() 202 return nil 203 } 204 205 // only reconnect, don't open any new channels! 206 func (t *Tricorder) helperNewClientConnect(ctx context.Context) (err error) { 207 208 pp("%s Tricorder.helperNewClientConnect starting! t.uhp='%#v'.", t.Name, t.uhp) 209 210 defer func() { 211 if err != nil { 212 t.lastConnectTime = time.Now() 213 } 214 }() 215 216 destHost, port, err := SplitHostPort(t.uhp.HostPort) 217 _, _ = destHost, port 218 if err != nil { 219 return err 220 } 221 222 // TODO: pw & totpUrl currently required in the test... change this. 223 //pw := t.dc.Pw 224 //totpUrl := t.dc.TotpUrl 225 226 //t.cfg.AddIfNotKnown = false 227 var sshcli *ssh.Client 228 tries := t.retries 229 pause := t.pauseBetweenRetries 230 if t.cfg.KnownHosts == nil { 231 panic("problem! t.cfg.KnownHosts is nil") 232 } 233 if t.cfg.PrivateKeyPath == "" { 234 panic("problem! t.cfg.PrivateKeyPath is empty") 235 } 236 237 var okCtx context.Context 238 239 for i := 0; i < tries; i++ { 240 pp("%s Tricorder.helperNewClientConnect() calling t.dc.Dial(), i=%v", t.Name, i) 241 242 // check for shutdown request 243 select { 244 case <-t.Halt.ReqStopChan(): 245 return ErrShutdown 246 default: 247 } 248 249 ctxChild, cancelChildCtx := context.WithCancel(ctx) 250 251 //t.cfg.AddIfNotKnown = t.tofu 252 //t.dc.TofuAddIfNotKnown = t.tofu 253 254 _, sshcli, _, err = t.dc.Dial(ctxChild, t.cfg, true) 255 if err == nil { 256 t.tofu = false 257 t.cfg.AddIfNotKnown = false 258 okCtx = ctxChild 259 260 if sshcli == nil { 261 panic("err must not be nil if sshcli is nil, back from cfg.SSHConnect") 262 } 263 break 264 } else { 265 cancelChildCtx() 266 errs := err.Error() 267 if strings.Contains(errs, "Re-run without -new") { 268 if t.tofu { 269 p("auto-handling tofu b/c t.tofu is true") 270 t.tofu = false 271 t.dc.TofuAddIfNotKnown = false 272 t.cfg.AddIfNotKnown = false 273 continue 274 } 275 return err 276 } 277 if strings.Contains(errs, "getsockopt: connection refused") { 278 pp("%s Tricorder.helperNewClientConnect: ignoring 'connection refused' and retrying after %v. connecting to '%#v'", t.Name, pause, t.uhp) 279 time.Sleep(pause) 280 continue 281 } 282 pp("%s Tricorder: err = '%v'. retrying after %v", t.Name, err, pause) 283 time.Sleep(pause) 284 continue 285 } 286 } // end i over tries 287 288 if sshcli != nil && okCtx != nil { 289 sshcli.TmpCtx = okCtx 290 } 291 if err != nil { 292 return err 293 } 294 pp("good: %s Tricorder.helperNewClientConnect succeeded to '%#v'.", t.Name, t.uhp) 295 t.cli = sshcli 296 if t.cli != nil { 297 t.nc = t.cli.NcCloser() 298 } else { 299 panic("why no NcCloser()???") 300 } 301 return nil 302 } 303 304 func (t *Tricorder) helperGetChannel(tk *getChannelTicket) { 305 306 pp("%s Tricorder.helperGetChannel starting! t.uhp='%#v'", t.Name, t.uhp) 307 308 var ch ssh.Channel 309 var in <-chan *ssh.Request 310 var err error 311 if t.cli == nil { 312 pp("%s Tricorder.helperGetChannel: saw nil cli, so making new client", t.Name) 313 err = t.helperNewClientConnect(tk.ctx) 314 if err != nil { 315 tk.err = err 316 close(tk.done) 317 return 318 } 319 } 320 321 pp("%s Tricorder.helperGetChannel: had cli already, so calling t.cli.Dial()", t.Name) 322 discardCtx, discardCtxCancel := context.WithCancel(tk.ctx) 323 324 if tk.typ == "direct-tcpip" { 325 hp := strings.Trim(tk.targetHostPort, "\n\r\t ") 326 327 pp("%s Tricorder.helperGetChannel dialing hp='%v'", t.Name, hp) 328 ch, err = t.cli.DialWithContext(discardCtx, "tcp", hp) 329 330 } else { 331 332 ch, in, err = t.cli.OpenChannel(tk.ctx, tk.typ, nil, t.channelsHalt) 333 if err == nil { 334 go DiscardRequestsExceptKeepalives(discardCtx, in, t.channelsHalt.ReqStopChan()) 335 } 336 } 337 if ch != nil { 338 t.sshChannels[ch] = discardCtxCancel 339 340 if t.cfg.IdleTimeoutDur > 0 { 341 sshChan, ok := ch.(ssh.Channel) 342 if ok { 343 sshChan.SetIdleTimeout(t.cfg.IdleTimeoutDur) 344 } 345 } 346 } 347 348 tk.sshChannel = ch 349 tk.err = err 350 351 close(tk.done) 352 } 353 354 type getChannelTicket struct { 355 done chan struct{} 356 sshChannel ssh.Channel 357 targetHostPort string // leave empty for "custom-inproc-stream", else downstream addr 358 typ string // "direct-tcpip" or "custom-inproc-stream" 359 err error 360 ctx context.Context 361 } 362 363 func newGetChannelTicket(ctx context.Context) *getChannelTicket { 364 return &getChannelTicket{ 365 done: make(chan struct{}), 366 ctx: ctx, 367 } 368 } 369 370 // typ can be "direct-tcpip" (specify destHostPort), or "custom-inproc-stream" 371 // in which case leave destHostPort as the empty string. 372 func (t *Tricorder) SSHChannel(ctx context.Context, typ, targetHostPort string) (ssh.Channel, error) { 373 tk := newGetChannelTicket(ctx) 374 tk.typ = typ 375 tk.targetHostPort = targetHostPort 376 t.getChannelCh <- tk 377 <-tk.done 378 return tk.sshChannel, tk.err 379 } 380 381 func (t *Tricorder) Cli() (cli *ssh.Client, err error) { 382 select { 383 case cli = <-t.getCliCh: 384 case <-t.Halt.ReqStopChan(): 385 err = ErrShutdown 386 } 387 return 388 } 389 390 func (t *Tricorder) Nc() (nc io.Closer, err error) { 391 select { 392 case nc = <-t.getNcCh: 393 case <-t.Halt.ReqStopChan(): 394 err = ErrShutdown 395 } 396 return 397 }