github.com/pf-qiu/concourse/v6@v6.7.3-0.20201207032516-1f455d73275f/tsa/tsacmd/server.go (about)

     1  package tsacmd
     2  
     3  import (
     4  	"context"
     5  	"flag"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"net/http"
    10  	"strings"
    11  	"sync"
    12  	"time"
    13  
    14  	"code.cloudfoundry.org/lager"
    15  	"code.cloudfoundry.org/lager/lagerctx"
    16  	"github.com/pf-qiu/concourse/v6/tsa"
    17  	"golang.org/x/crypto/ssh"
    18  )
    19  
// maxForwards caps how many "tcpip-forward" global requests a single SSH
// connection may have accepted; extra requests are refused. It is also the
// buffer size of the per-connection channel that hands accepted forwards to
// the channel handlers, so publishing a forward never blocks.
const maxForwards = 2
    21  
// server is the TSA's SSH server: it accepts SSH connections and dispatches
// the "exec" commands and "tcpip-forward" requests made over them.
//
// Within this file, logger is the parent of every per-connection logger,
// config is passed to ssh.NewServerConn for each accepted connection, and
// sessionTeam holds the team recorded for each SSH session (copied into
// ConnState.Team). parseRequest hands the server itself to every request
// value it builds, so the remaining fields are presumably consumed by those
// request handlers — NOTE(review): confirm in the files that define them.
type server struct {
	logger               lager.Logger
	atcEndpointPicker    tsa.EndpointPicker
	heartbeatInterval    time.Duration
	cprInterval          time.Duration
	gardenRequestTimeout time.Duration
	forwardHost          string
	config               *ssh.ServerConfig
	httpClient           *http.Client
	sessionTeam          *sessionTeam
}
    33  
    34  type sessionTeam struct {
    35  	sessionTeams map[string]string
    36  	lock         *sync.RWMutex
    37  }
    38  
    39  func (s *sessionTeam) AuthorizeTeam(sessionID, team string) {
    40  	s.lock.Lock()
    41  	defer s.lock.Unlock()
    42  
    43  	s.sessionTeams[sessionID] = team
    44  }
    45  
    46  func (s *sessionTeam) IsNotAuthorized(sessionID, team string) bool {
    47  	s.lock.RLock()
    48  	defer s.lock.RUnlock()
    49  
    50  	t, found := s.sessionTeams[sessionID]
    51  
    52  	return found && t != team
    53  }
    54  
    55  func (s *sessionTeam) AuthorizedTeamFor(sessionID string) string {
    56  	s.lock.RLock()
    57  	defer s.lock.RUnlock()
    58  
    59  	return s.sessionTeams[sessionID]
    60  }
    61  
// ConnState is the per-connection state handed to every channel (and every
// command run on it) of one SSH connection.
type ConnState struct {
	// Team is the team recorded in server.sessionTeam for this connection's
	// SSH session, or "" if none was recorded.
	Team string

	// ForwardedTCPIPs delivers one ForwardedTCPIP per accepted "tcpip-forward"
	// request on the connection (at most maxForwards of them).
	ForwardedTCPIPs <-chan ForwardedTCPIP
}
    67  
// ForwardedTCPIP describes one accepted "tcpip-forward" request: a TSA-side
// TCP listener whose connections are relayed back to the SSH client.
type ForwardedTCPIP struct {
	// Logger is the request-scoped logger of this forward.
	Logger lager.Logger

	// BindAddr is the host:port the client asked for. BoundPort is the port
	// the TSA is actually listening on, which is chosen by the OS and may
	// differ from the requested one.
	BindAddr  string
	BoundPort uint32

	// Drain stops the forward: a send on (or close of) it makes the listener
	// stop accepting. Connections already being relayed are not interrupted.
	Drain chan<- struct{}

	// wg counts the forward's accept loop plus each connection it is still
	// relaying; Wait blocks on it.
	wg *sync.WaitGroup
}
    78  
    79  func (forward ForwardedTCPIP) Wait() {
    80  	forward.Logger.Debug("draining")
    81  	forward.wg.Wait()
    82  	forward.Logger.Debug("drained")
    83  }
    84  
    85  func (server *server) Serve(listener net.Listener) {
    86  	for {
    87  		c, err := listener.Accept()
    88  		if err != nil {
    89  			if !strings.Contains(err.Error(), "use of closed network connection") {
    90  				server.logger.Error("failed-to-accept", err)
    91  			}
    92  
    93  			return
    94  		}
    95  
    96  		logger := server.logger.Session("connection", lager.Data{
    97  			"remote": c.RemoteAddr().String(),
    98  		})
    99  
   100  		go server.handshake(logger, c)
   101  	}
   102  }
   103  
   104  func (server *server) handshake(logger lager.Logger, netConn net.Conn) {
   105  	conn, chans, reqs, err := ssh.NewServerConn(netConn, server.config)
   106  	if err != nil {
   107  		logger.Info("handshake-failed", lager.Data{"error": err.Error()})
   108  		return
   109  	}
   110  
   111  	defer conn.Close()
   112  
   113  	ctx, cancel := context.WithCancel(lagerctx.NewContext(context.Background(), logger))
   114  	defer cancel()
   115  
   116  	sessionID := string(conn.SessionID())
   117  
   118  	forwardedTCPIPs := make(chan ForwardedTCPIP, maxForwards)
   119  	go server.handleForwardRequests(ctx, conn, reqs, forwardedTCPIPs)
   120  
   121  	state := ConnState{
   122  		Team: server.sessionTeam.AuthorizedTeamFor(sessionID),
   123  
   124  		ForwardedTCPIPs: forwardedTCPIPs,
   125  	}
   126  
   127  	chansGroup := new(sync.WaitGroup)
   128  
   129  	for newChannel := range chans {
   130  		if newChannel.ChannelType() != "session" {
   131  			logger.Info("rejecting-unknown-channel-type", lager.Data{
   132  				"type": newChannel.ChannelType(),
   133  			})
   134  
   135  			newChannel.Reject(ssh.UnknownChannelType, "unknown channel type")
   136  			continue
   137  		}
   138  
   139  		channel, requests, err := newChannel.Accept()
   140  		if err != nil {
   141  			logger.Error("failed-to-accept-channel", err)
   142  			return
   143  		}
   144  
   145  		chansGroup.Add(1)
   146  		go server.handleChannel(logger.Session("channel"), chansGroup, channel, requests, state)
   147  	}
   148  
   149  	chansGroup.Wait()
   150  }
   151  
// signalMsg is the payload of an SSH "signal" channel request, decoded with
// ssh.Unmarshal. ssh.Unmarshal decodes fields in declaration order, so the
// field order here is the wire layout.
type signalMsg struct {
	Signal string
}
   155  
   156  func (server *server) handleChannel(
   157  	logger lager.Logger,
   158  	chansGroup *sync.WaitGroup,
   159  	channel ssh.Channel,
   160  	requests <-chan *ssh.Request,
   161  	state ConnState,
   162  ) {
   163  	ctx, cancel := context.WithCancel(context.Background())
   164  	defer cancel()
   165  
   166  	defer chansGroup.Done()
   167  	defer channel.Close()
   168  
   169  	execExited := make(chan error, 1)
   170  
   171  	for {
   172  		select {
   173  		case req, ok := <-requests:
   174  			if !ok {
   175  				return
   176  			}
   177  
   178  			logger.Debug("channel-request", lager.Data{
   179  				"type": req.Type,
   180  			})
   181  
   182  			switch req.Type {
   183  			case "signal":
   184  				req.Reply(true, nil)
   185  
   186  				var sig signalMsg
   187  				err := ssh.Unmarshal(req.Payload, &sig)
   188  				if err != nil {
   189  					logger.Error("malformed-signal", err)
   190  					req.Reply(false, nil)
   191  					continue
   192  				}
   193  
   194  				logger.Debug("received-signal", lager.Data{
   195  					"signal": sig,
   196  				})
   197  
   198  				cancel()
   199  
   200  			case "exec":
   201  				var request execRequest
   202  				err := ssh.Unmarshal(req.Payload, &request)
   203  				if err != nil {
   204  					logger.Error("malformed-exec-request", err)
   205  					req.Reply(false, nil)
   206  					return
   207  				}
   208  
   209  				workerRequest, command, err := server.parseRequest(request.Command)
   210  				if err != nil {
   211  					fmt.Fprintf(channel, "invalid command: %s", err)
   212  					req.Reply(false, nil)
   213  					continue
   214  				}
   215  
   216  				req.Reply(true, nil)
   217  
   218  				cmdLogger := logger.Session("command", lager.Data{
   219  					"command": command,
   220  				})
   221  
   222  				go func() {
   223  					execExited <- workerRequest.Handle(lagerctx.NewContext(ctx, cmdLogger), state, channel)
   224  				}()
   225  
   226  			default:
   227  				logger.Info("rejecting")
   228  				req.Reply(false, nil)
   229  				continue
   230  			}
   231  
   232  		case err := <-execExited:
   233  			req := exitStatusRequest{0}
   234  
   235  			if err != nil {
   236  				logger.Error("exited-with-error", err)
   237  				req.ExitStatus = 1
   238  			} else {
   239  				logger.Debug("exited-successfully")
   240  			}
   241  
   242  			_, err = channel.SendRequest("exit-status", false, ssh.Marshal(req))
   243  			if err != nil {
   244  				logger.Error("failed-to-send-exit-status", err)
   245  			}
   246  
   247  			// RFC 4254: "The channel needs to be closed with SSH_MSG_CHANNEL_CLOSE after
   248  			// this message."
   249  			err = channel.Close()
   250  			if err != nil {
   251  				logger.Error("failed-to-close-channel", err)
   252  			} else {
   253  				logger.Debug("closed-channel")
   254  			}
   255  		}
   256  	}
   257  }
   258  
   259  func (server *server) handleForwardRequests(
   260  	ctx context.Context,
   261  	conn *ssh.ServerConn,
   262  	reqs <-chan *ssh.Request,
   263  	forwardedTCPIPs chan<- ForwardedTCPIP,
   264  ) {
   265  	logger := lagerctx.FromContext(ctx)
   266  
   267  	var forwardedThings int
   268  
   269  	for r := range reqs {
   270  		reqLog := logger.Session("request", lager.Data{
   271  			"type": r.Type,
   272  		})
   273  
   274  		switch r.Type {
   275  		case "tcpip-forward":
   276  			forwardedThings++
   277  
   278  			if forwardedThings > maxForwards {
   279  				reqLog.Info("rejecting-extra-forward-request")
   280  				r.Reply(false, nil)
   281  				continue
   282  			}
   283  
   284  			var req tcpipForwardRequest
   285  			err := ssh.Unmarshal(r.Payload, &req)
   286  			if err != nil {
   287  				reqLog.Error("malformed-tcpip-request", err)
   288  				r.Reply(false, nil)
   289  				continue
   290  			}
   291  
   292  			bindAddr := net.JoinHostPort(req.BindIP, fmt.Sprintf("%d", req.BindPort))
   293  
   294  			listener, err := net.Listen("tcp", "0.0.0.0:0")
   295  			if err != nil {
   296  				reqLog.Error("failed-to-listen", err)
   297  				r.Reply(false, nil)
   298  				continue
   299  			}
   300  
   301  			defer listener.Close()
   302  
   303  			_, port, err := net.SplitHostPort(listener.Addr().String())
   304  			if err != nil {
   305  				r.Reply(false, nil)
   306  				continue
   307  			}
   308  
   309  			var res tcpipForwardResponse
   310  			_, err = fmt.Sscanf(port, "%d", &res.BoundPort)
   311  			if err != nil {
   312  				r.Reply(false, nil)
   313  				continue
   314  			}
   315  
   316  			reqLog = reqLog.WithData(lager.Data{
   317  				"addr":           listener.Addr().String(),
   318  				"requested-addr": bindAddr,
   319  			})
   320  
   321  			reqLog.Debug("listening")
   322  
   323  			forPort := req.BindPort
   324  			if forPort == 0 {
   325  				forPort = res.BoundPort
   326  			}
   327  
   328  			drain := make(chan struct{})
   329  			wait := new(sync.WaitGroup)
   330  
   331  			wait.Add(1)
   332  			go server.forwardTCPIP(lagerctx.NewContext(ctx, reqLog), drain, wait, conn, listener, req.BindIP, forPort)
   333  
   334  			forwardedTCPIPs <- ForwardedTCPIP{
   335  				Logger: reqLog,
   336  
   337  				BindAddr:  bindAddr,
   338  				BoundPort: res.BoundPort,
   339  
   340  				Drain: drain,
   341  
   342  				wg: wait,
   343  			}
   344  
   345  			r.Reply(true, ssh.Marshal(res))
   346  
   347  		default:
   348  			// OpenSSH sends keepalive@openssh.com, but there may be other clients;
   349  			// just check for 'keepalive'
   350  			if strings.Contains(r.Type, "keepalive") {
   351  				reqLog.Debug("keepalive")
   352  				r.Reply(true, nil)
   353  			} else {
   354  				reqLog.Info("ignoring")
   355  				r.Reply(false, nil)
   356  			}
   357  		}
   358  	}
   359  }
   360  
   361  func (server *server) forwardTCPIP(
   362  	ctx context.Context,
   363  	drain <-chan struct{},
   364  	connsWg *sync.WaitGroup,
   365  	conn *ssh.ServerConn,
   366  	listener net.Listener,
   367  	forwardIP string,
   368  	forwardPort uint32,
   369  ) {
   370  	defer connsWg.Done()
   371  
   372  	logger := lagerctx.FromContext(ctx)
   373  
   374  	done := make(chan struct{})
   375  	defer close(done)
   376  
   377  	interrupted := false
   378  	go func() {
   379  		select {
   380  		case <-drain:
   381  			logger.Debug("draining")
   382  			interrupted = true
   383  			listener.Close()
   384  		case <-done:
   385  			logger.Debug("done")
   386  		}
   387  	}()
   388  
   389  	for {
   390  		localConn, err := listener.Accept()
   391  		if err != nil {
   392  			if !interrupted {
   393  				logger.Error("failed-to-accept", err)
   394  			}
   395  
   396  			break
   397  		}
   398  
   399  		connsWg.Add(1)
   400  
   401  		go func() {
   402  			defer connsWg.Done()
   403  
   404  			forwardLocalConn(
   405  				lagerctx.NewContext(ctx, logger.Session("forward-conn")),
   406  				localConn,
   407  				conn,
   408  				forwardIP,
   409  				forwardPort,
   410  			)
   411  		}()
   412  	}
   413  }
   414  
   415  func forwardLocalConn(ctx context.Context, localConn net.Conn, conn *ssh.ServerConn, forwardIP string, forwardPort uint32) {
   416  	logger := lagerctx.FromContext(ctx)
   417  
   418  	defer localConn.Close()
   419  
   420  	var req forwardTCPIPChannelRequest
   421  	req.ForwardIP = forwardIP
   422  	req.ForwardPort = forwardPort
   423  
   424  	host, port, err := net.SplitHostPort(localConn.RemoteAddr().String())
   425  	if err != nil {
   426  		logger.Error("failed-to-split-host-port", err)
   427  		return
   428  	}
   429  
   430  	req.OriginIP = host
   431  
   432  	_, err = fmt.Sscanf(port, "%d", &req.OriginPort)
   433  	if err != nil {
   434  		logger.Error("failed-to-parse-port", err)
   435  		return
   436  	}
   437  
   438  	channel, reqs, err := conn.OpenChannel("forwarded-tcpip", ssh.Marshal(req))
   439  	if err != nil {
   440  		logger.Error("failed-to-open-channel", err)
   441  		return
   442  	}
   443  
   444  	defer channel.Close()
   445  
   446  	go ssh.DiscardRequests(reqs)
   447  
   448  	numPipes := 2
   449  	wait := make(chan struct{}, numPipes)
   450  
   451  	pipe := func(to io.WriteCloser, from io.ReadCloser) {
   452  		// if either end breaks, close both ends to ensure they're both unblocked,
   453  		// otherwise io.Copy can block forever if e.g. reading after write end has
   454  		// gone away
   455  		defer to.Close()
   456  		defer from.Close()
   457  		defer func() {
   458  			wait <- struct{}{}
   459  		}()
   460  
   461  		io.Copy(to, from)
   462  	}
   463  
   464  	go pipe(localConn, channel)
   465  	go pipe(channel, localConn)
   466  
   467  	done := 0
   468  dance:
   469  	for {
   470  		select {
   471  		case <-wait:
   472  			done++
   473  			if done == numPipes {
   474  				break dance
   475  			}
   476  
   477  			logger.Debug("tcpip-io-complete")
   478  		case <-ctx.Done():
   479  			logger.Debug("tcpip-io-interrupted")
   480  			break dance
   481  		}
   482  	}
   483  }
   484  
   485  func (server *server) parseRequest(cli string) (request, string, error) {
   486  	argv := strings.Split(cli, " ")
   487  
   488  	command := argv[0]
   489  	args := argv[1:]
   490  
   491  	var req request
   492  	switch command {
   493  	case tsa.ForwardWorker:
   494  		var fs = flag.NewFlagSet(command, flag.ContinueOnError)
   495  
   496  		var garden = fs.String("garden", "", "garden address to forward")
   497  		var baggageclaim = fs.String("baggageclaim", "", "baggageclaim address to forward")
   498  
   499  		err := fs.Parse(args)
   500  		if err != nil {
   501  			return nil, "", err
   502  		}
   503  
   504  		req = forwardWorkerRequest{
   505  			server: server,
   506  
   507  			gardenAddr:       *garden,
   508  			baggageclaimAddr: *baggageclaim,
   509  		}
   510  	case tsa.LandWorker:
   511  		req = landWorkerRequest{
   512  			server: server,
   513  		}
   514  	case tsa.RetireWorker:
   515  		req = retireWorkerRequest{
   516  			server: server,
   517  		}
   518  	case tsa.DeleteWorker:
   519  		req = deleteWorkerRequest{
   520  			server: server,
   521  		}
   522  	case tsa.SweepContainers:
   523  		req = sweepContainersRequest{
   524  			server: server,
   525  		}
   526  	case tsa.ReportContainers:
   527  		req = reportContainersRequest{
   528  			server:           server,
   529  			containerHandles: args,
   530  		}
   531  	case tsa.SweepVolumes:
   532  		req = sweepVolumesRequest{
   533  			server: server,
   534  		}
   535  	case tsa.ReportVolumes:
   536  		req = reportVolumesRequest{
   537  			server:        server,
   538  			volumeHandles: args,
   539  		}
   540  	default:
   541  		return nil, "", fmt.Errorf("unknown command: %s", command)
   542  	}
   543  
   544  	return req, command, nil
   545  }