github.com/pf-qiu/concourse/v6@v6.7.3-0.20201207032516-1f455d73275f/tsa/tsacmd/server.go (about) 1 package tsacmd 2 3 import ( 4 "context" 5 "flag" 6 "fmt" 7 "io" 8 "net" 9 "net/http" 10 "strings" 11 "sync" 12 "time" 13 14 "code.cloudfoundry.org/lager" 15 "code.cloudfoundry.org/lager/lagerctx" 16 "github.com/pf-qiu/concourse/v6/tsa" 17 "golang.org/x/crypto/ssh" 18 ) 19 20 const maxForwards = 2 21 22 type server struct { 23 logger lager.Logger 24 atcEndpointPicker tsa.EndpointPicker 25 heartbeatInterval time.Duration 26 cprInterval time.Duration 27 gardenRequestTimeout time.Duration 28 forwardHost string 29 config *ssh.ServerConfig 30 httpClient *http.Client 31 sessionTeam *sessionTeam 32 } 33 34 type sessionTeam struct { 35 sessionTeams map[string]string 36 lock *sync.RWMutex 37 } 38 39 func (s *sessionTeam) AuthorizeTeam(sessionID, team string) { 40 s.lock.Lock() 41 defer s.lock.Unlock() 42 43 s.sessionTeams[sessionID] = team 44 } 45 46 func (s *sessionTeam) IsNotAuthorized(sessionID, team string) bool { 47 s.lock.RLock() 48 defer s.lock.RUnlock() 49 50 t, found := s.sessionTeams[sessionID] 51 52 return found && t != team 53 } 54 55 func (s *sessionTeam) AuthorizedTeamFor(sessionID string) string { 56 s.lock.RLock() 57 defer s.lock.RUnlock() 58 59 return s.sessionTeams[sessionID] 60 } 61 62 type ConnState struct { 63 Team string 64 65 ForwardedTCPIPs <-chan ForwardedTCPIP 66 } 67 68 type ForwardedTCPIP struct { 69 Logger lager.Logger 70 71 BindAddr string 72 BoundPort uint32 73 74 Drain chan<- struct{} 75 76 wg *sync.WaitGroup 77 } 78 79 func (forward ForwardedTCPIP) Wait() { 80 forward.Logger.Debug("draining") 81 forward.wg.Wait() 82 forward.Logger.Debug("drained") 83 } 84 85 func (server *server) Serve(listener net.Listener) { 86 for { 87 c, err := listener.Accept() 88 if err != nil { 89 if !strings.Contains(err.Error(), "use of closed network connection") { 90 server.logger.Error("failed-to-accept", err) 91 } 92 93 return 94 } 95 96 logger := server.logger.Session("connection", lager.Data{ 97 "remote": c.RemoteAddr().String(), 98 }) 99 100 go server.handshake(logger, c) 101 } 102 } 103 104 func (server *server) handshake(logger lager.Logger, netConn net.Conn) { 105 conn, chans, reqs, err := ssh.NewServerConn(netConn, server.config) 106 if err != nil { 107 logger.Info("handshake-failed", lager.Data{"error": err.Error()}) 108 return 109 } 110 111 defer conn.Close() 112 113 ctx, cancel := context.WithCancel(lagerctx.NewContext(context.Background(), logger)) 114 defer cancel() 115 116 sessionID := string(conn.SessionID()) 117 118 forwardedTCPIPs := make(chan ForwardedTCPIP, maxForwards) 119 go server.handleForwardRequests(ctx, conn, reqs, forwardedTCPIPs) 120 121 state := ConnState{ 122 Team: server.sessionTeam.AuthorizedTeamFor(sessionID), 123 124 ForwardedTCPIPs: forwardedTCPIPs, 125 } 126 127 chansGroup := new(sync.WaitGroup) 128 129 for newChannel := range chans { 130 if newChannel.ChannelType() != "session" { 131 logger.Info("rejecting-unknown-channel-type", lager.Data{ 132 "type": newChannel.ChannelType(), 133 }) 134 135 newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") 136 continue 137 } 138 139 channel, requests, err := newChannel.Accept() 140 if err != nil { 141 logger.Error("failed-to-accept-channel", err) 142 return 143 } 144 145 chansGroup.Add(1) 146 go server.handleChannel(logger.Session("channel"), chansGroup, channel, requests, state) 147 } 148 149 chansGroup.Wait() 150 } 151 152 type signalMsg struct { 153 Signal string 154 } 155 156 func (server *server) handleChannel( 157 logger lager.Logger, 158 chansGroup *sync.WaitGroup, 159 channel ssh.Channel, 160 requests <-chan *ssh.Request, 161 state ConnState, 162 ) { 163 ctx, cancel := context.WithCancel(context.Background()) 164 defer cancel() 165 166 defer chansGroup.Done() 167 defer channel.Close() 168 169 execExited := make(chan error, 1) 170 171 for { 172 select { 173 case req, ok := <-requests: 174 if !ok { 175 return 176 } 177 178 logger.Debug("channel-request", lager.Data{ 179 "type": req.Type, 180 }) 181 182 switch req.Type { 183 case "signal": 184 req.Reply(true, nil) 185 186 var sig signalMsg 187 err := ssh.Unmarshal(req.Payload, &sig) 188 if err != nil { 189 logger.Error("malformed-signal", err) 190 req.Reply(false, nil) 191 continue 192 } 193 194 logger.Debug("received-signal", lager.Data{ 195 "signal": sig, 196 }) 197 198 cancel() 199 200 case "exec": 201 var request execRequest 202 err := ssh.Unmarshal(req.Payload, &request) 203 if err != nil { 204 logger.Error("malformed-exec-request", err) 205 req.Reply(false, nil) 206 return 207 } 208 209 workerRequest, command, err := server.parseRequest(request.Command) 210 if err != nil { 211 fmt.Fprintf(channel, "invalid command: %s", err) 212 req.Reply(false, nil) 213 continue 214 } 215 216 req.Reply(true, nil) 217 218 cmdLogger := logger.Session("command", lager.Data{ 219 "command": command, 220 }) 221 222 go func() { 223 execExited <- workerRequest.Handle(lagerctx.NewContext(ctx, cmdLogger), state, channel) 224 }() 225 226 default: 227 logger.Info("rejecting") 228 req.Reply(false, nil) 229 continue 230 } 231 232 case err := <-execExited: 233 req := exitStatusRequest{0} 234 235 if err != nil { 236 logger.Error("exited-with-error", err) 237 req.ExitStatus = 1 238 } else { 239 logger.Debug("exited-successfully") 240 } 241 242 _, err = channel.SendRequest("exit-status", false, ssh.Marshal(req)) 243 if err != nil { 244 logger.Error("failed-to-send-exit-status", err) 245 } 246 247 // RFC 4254: "The channel needs to be closed with SSH_MSG_CHANNEL_CLOSE after 248 // this message." 249 err = channel.Close() 250 if err != nil { 251 logger.Error("failed-to-close-channel", err) 252 } else { 253 logger.Debug("closed-channel") 254 } 255 } 256 } 257 } 258 259 func (server *server) handleForwardRequests( 260 ctx context.Context, 261 conn *ssh.ServerConn, 262 reqs <-chan *ssh.Request, 263 forwardedTCPIPs chan<- ForwardedTCPIP, 264 ) { 265 logger := lagerctx.FromContext(ctx) 266 267 var forwardedThings int 268 269 for r := range reqs { 270 reqLog := logger.Session("request", lager.Data{ 271 "type": r.Type, 272 }) 273 274 switch r.Type { 275 case "tcpip-forward": 276 forwardedThings++ 277 278 if forwardedThings > maxForwards { 279 reqLog.Info("rejecting-extra-forward-request") 280 r.Reply(false, nil) 281 continue 282 } 283 284 var req tcpipForwardRequest 285 err := ssh.Unmarshal(r.Payload, &req) 286 if err != nil { 287 reqLog.Error("malformed-tcpip-request", err) 288 r.Reply(false, nil) 289 continue 290 } 291 292 bindAddr := net.JoinHostPort(req.BindIP, fmt.Sprintf("%d", req.BindPort)) 293 294 listener, err := net.Listen("tcp", "0.0.0.0:0") 295 if err != nil { 296 reqLog.Error("failed-to-listen", err) 297 r.Reply(false, nil) 298 continue 299 } 300 301 defer listener.Close() 302 303 _, port, err := net.SplitHostPort(listener.Addr().String()) 304 if err != nil { 305 r.Reply(false, nil) 306 continue 307 } 308 309 var res tcpipForwardResponse 310 _, err = fmt.Sscanf(port, "%d", &res.BoundPort) 311 if err != nil { 312 r.Reply(false, nil) 313 continue 314 } 315 316 reqLog = reqLog.WithData(lager.Data{ 317 "addr": listener.Addr().String(), 318 "requested-addr": bindAddr, 319 }) 320 321 reqLog.Debug("listening") 322 323 forPort := req.BindPort 324 if forPort == 0 { 325 forPort = res.BoundPort 326 } 327 328 drain := make(chan struct{}) 329 wait := new(sync.WaitGroup) 330 331 wait.Add(1) 332 go server.forwardTCPIP(lagerctx.NewContext(ctx, reqLog), drain, wait, conn, listener, req.BindIP, forPort) 333 334 forwardedTCPIPs <- ForwardedTCPIP{ 335 Logger: reqLog, 336 337 BindAddr: bindAddr, 338 BoundPort: res.BoundPort, 339 340 Drain: drain, 341 342 wg: wait, 343 } 344 345 r.Reply(true, ssh.Marshal(res)) 346 347 default: 348 // OpenSSH sends keepalive@openssh.com, but there may be other clients; 349 // just check for 'keepalive' 350 if strings.Contains(r.Type, "keepalive") { 351 reqLog.Debug("keepalive") 352 r.Reply(true, nil) 353 } else { 354 reqLog.Info("ignoring") 355 r.Reply(false, nil) 356 } 357 } 358 } 359 } 360 361 func (server *server) forwardTCPIP( 362 ctx context.Context, 363 drain <-chan struct{}, 364 connsWg *sync.WaitGroup, 365 conn *ssh.ServerConn, 366 listener net.Listener, 367 forwardIP string, 368 forwardPort uint32, 369 ) { 370 defer connsWg.Done() 371 372 logger := lagerctx.FromContext(ctx) 373 374 done := make(chan struct{}) 375 defer close(done) 376 377 interrupted := false 378 go func() { 379 select { 380 case <-drain: 381 logger.Debug("draining") 382 interrupted = true 383 listener.Close() 384 case <-done: 385 logger.Debug("done") 386 } 387 }() 388 389 for { 390 localConn, err := listener.Accept() 391 if err != nil { 392 if !interrupted { 393 logger.Error("failed-to-accept", err) 394 } 395 396 break 397 } 398 399 connsWg.Add(1) 400 401 go func() { 402 defer connsWg.Done() 403 404 forwardLocalConn( 405 lagerctx.NewContext(ctx, logger.Session("forward-conn")), 406 localConn, 407 conn, 408 forwardIP, 409 forwardPort, 410 ) 411 }() 412 } 413 } 414 415 func forwardLocalConn(ctx context.Context, localConn net.Conn, conn *ssh.ServerConn, forwardIP string, forwardPort uint32) { 416 logger := lagerctx.FromContext(ctx) 417 418 defer localConn.Close() 419 420 var req forwardTCPIPChannelRequest 421 req.ForwardIP = forwardIP 422 req.ForwardPort = forwardPort 423 424 host, port, err := net.SplitHostPort(localConn.RemoteAddr().String()) 425 if err != nil { 426 logger.Error("failed-to-split-host-port", err) 427 return 428 } 429 430 req.OriginIP = host 431 432 _, err = fmt.Sscanf(port, "%d", &req.OriginPort) 433 if err != nil { 434 logger.Error("failed-to-parse-port", err) 435 return 436 } 437 438 channel, reqs, err := conn.OpenChannel("forwarded-tcpip", ssh.Marshal(req)) 439 if err != nil { 440 logger.Error("failed-to-open-channel", err) 441 return 442 } 443 444 defer channel.Close() 445 446 go ssh.DiscardRequests(reqs) 447 448 numPipes := 2 449 wait := make(chan struct{}, numPipes) 450 451 pipe := func(to io.WriteCloser, from io.ReadCloser) { 452 // if either end breaks, close both ends to ensure they're both unblocked, 453 // otherwise io.Copy can block forever if e.g. reading after write end has 454 // gone away 455 defer to.Close() 456 defer from.Close() 457 defer func() { 458 wait <- struct{}{} 459 }() 460 461 io.Copy(to, from) 462 } 463 464 go pipe(localConn, channel) 465 go pipe(channel, localConn) 466 467 done := 0 468 dance: 469 for { 470 select { 471 case <-wait: 472 done++ 473 if done == numPipes { 474 break dance 475 } 476 477 logger.Debug("tcpip-io-complete") 478 case <-ctx.Done(): 479 logger.Debug("tcpip-io-interrupted") 480 break dance 481 } 482 } 483 } 484 485 func (server *server) parseRequest(cli string) (request, string, error) { 486 argv := strings.Split(cli, " ") 487 488 command := argv[0] 489 args := argv[1:] 490 491 var req request 492 switch command { 493 case tsa.ForwardWorker: 494 var fs = flag.NewFlagSet(command, flag.ContinueOnError) 495 496 var garden = fs.String("garden", "", "garden address to forward") 497 var baggageclaim = fs.String("baggageclaim", "", "baggageclaim address to forward") 498 499 err := fs.Parse(args) 500 if err != nil { 501 return nil, "", err 502 } 503 504 req = forwardWorkerRequest{ 505 server: server, 506 507 gardenAddr: *garden, 508 baggageclaimAddr: *baggageclaim, 509 } 510 case tsa.LandWorker: 511 req = landWorkerRequest{ 512 server: server, 513 } 514 case tsa.RetireWorker: 515 req = retireWorkerRequest{ 516 server: server, 517 } 518 case tsa.DeleteWorker: 519 req = deleteWorkerRequest{ 520 server: server, 521 } 522 case tsa.SweepContainers: 523 req = sweepContainersRequest{ 524 server: server, 525 } 526 case tsa.ReportContainers: 527 req = reportContainersRequest{ 528 server: server, 529 containerHandles: args, 530 } 531 case tsa.SweepVolumes: 532 req = sweepVolumesRequest{ 533 server: server, 534 } 535 case tsa.ReportVolumes: 536 req = reportVolumesRequest{ 537 server: server, 538 volumeHandles: args, 539 } 540 default: 541 return nil, "", fmt.Errorf("unknown command: %s", command) 542 } 543 544 return req, command, nil 545 }