github.com/pf-qiu/concourse/v6@v6.7.3-0.20201207032516-1f455d73275f/tsa/tsacmd/remote_cli.go (about)

     1  package tsacmd
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"net/http"
     8  	"time"
     9  
    10  	"code.cloudfoundry.org/clock"
    11  	"code.cloudfoundry.org/lager"
    12  	"code.cloudfoundry.org/lager/lagerctx"
    13  	bclient "github.com/concourse/baggageclaim/client"
    14  	"github.com/pf-qiu/concourse/v6/atc"
    15  	"github.com/pf-qiu/concourse/v6/atc/worker/gclient"
    16  	"github.com/pf-qiu/concourse/v6/tsa"
    17  	"golang.org/x/crypto/ssh"
    18  )
    19  
    20  type request interface {
    21  	Handle(context.Context, ConnState, ssh.Channel) error
    22  }
    23  
    24  type forwardWorkerRequest struct {
    25  	server *server
    26  
    27  	gardenAddr       string
    28  	baggageclaimAddr string
    29  }
    30  
    31  func (req forwardWorkerRequest) Handle(ctx context.Context, state ConnState, channel ssh.Channel) error {
    32  	logger := lagerctx.FromContext(ctx)
    33  
    34  	var worker atc.Worker
    35  	err := json.NewDecoder(channel).Decode(&worker)
    36  	if err != nil {
    37  		return err
    38  	}
    39  
    40  	if err := checkTeam(state, worker); err != nil {
    41  		return err
    42  	}
    43  
    44  	forwards := map[string]ForwardedTCPIP{}
    45  	for i := 0; i < 2; i++ {
    46  		select {
    47  		case forwarded := <-state.ForwardedTCPIPs:
    48  			logger.Info("forwarded-tcpip", lager.Data{
    49  				"bind-addr":  forwarded.BindAddr,
    50  				"bound-port": forwarded.BoundPort,
    51  			})
    52  
    53  			forwards[forwarded.BindAddr] = forwarded
    54  
    55  		case <-time.After(10 * time.Second):
    56  			logger.Info("never-forwarded-tcpip")
    57  		}
    58  	}
    59  
    60  	gardenForward, found := forwards[req.gardenAddr]
    61  	if !found {
    62  		return fmt.Errorf("garden address (%s) not forwarded", req.gardenAddr)
    63  	}
    64  
    65  	baggageclaimForward, found := forwards[req.baggageclaimAddr]
    66  	if !found {
    67  		return fmt.Errorf("baggageclaim address (%s) not forwarded", req.baggageclaimAddr)
    68  	}
    69  
    70  	worker.GardenAddr = fmt.Sprintf("%s:%d", req.server.forwardHost, gardenForward.BoundPort)
    71  	worker.BaggageclaimURL = fmt.Sprintf("http://%s:%d", req.server.forwardHost, baggageclaimForward.BoundPort)
    72  
    73  	heartbeater := tsa.NewHeartbeater(
    74  		clock.NewClock(),
    75  		req.server.heartbeatInterval,
    76  		req.server.cprInterval,
    77  		gclient.BasicGardenClientWithRequestTimeout(
    78  			lagerctx.WithSession(ctx, "garden-connection"),
    79  			req.server.gardenRequestTimeout,
    80  			gardenURL(worker.GardenAddr),
    81  		),
    82  		bclient.NewWithHTTPClient(worker.BaggageclaimURL, &http.Client{
    83  			Transport: &http.Transport{
    84  				DisableKeepAlives:     true,
    85  				ResponseHeaderTimeout: 1 * time.Minute,
    86  			},
    87  		}),
    88  		req.server.atcEndpointPicker,
    89  		req.server.httpClient,
    90  		worker,
    91  		tsa.NewEventWriter(channel),
    92  	)
    93  
    94  	err = heartbeater.Heartbeat(ctx)
    95  	if err != nil {
    96  		logger.Error("failed-to-heartbeat", err)
    97  		return err
    98  	}
    99  
   100  	for _, forward := range forwards {
   101  		// prevent new connections from being accepted
   102  		close(forward.Drain)
   103  	}
   104  
   105  	// only drain if heartbeating was interrupted; otherwise the worker landed or
   106  	// retired, so it's time to go away
   107  	if ctx.Err() != nil {
   108  		logger.Info("draining-forwarded-connections")
   109  
   110  		for _, forward := range forwards {
   111  			// wait for connections to drain
   112  			forward.Wait()
   113  
   114  			logger.Info("forward-process-exited", lager.Data{
   115  				"bind-addr":  forward.BindAddr,
   116  				"bound-port": forward.BoundPort,
   117  			})
   118  		}
   119  	}
   120  
   121  	return nil
   122  }
   123  
   124  func (r forwardWorkerRequest) expectedForwards() int {
   125  	expected := 0
   126  
   127  	// Garden should always be forwarded;
   128  	// if not explicitly given, the only given forward is used
   129  	expected++
   130  
   131  	if r.baggageclaimAddr != "" {
   132  		expected++
   133  	}
   134  
   135  	return expected
   136  }
   137  
   138  type landWorkerRequest struct {
   139  	server *server
   140  }
   141  
   142  func checkTeam(state ConnState, worker atc.Worker) error {
   143  	if state.Team == "" {
   144  		// global keys can be used for all teams
   145  		return nil
   146  	}
   147  
   148  	if worker.Team == "" && state.Team != "" {
   149  		return fmt.Errorf("key is authorized for team %s, but worker is global", state.Team)
   150  	}
   151  
   152  	if worker.Team != state.Team {
   153  		return fmt.Errorf("key is authorized for team %s, but worker belongs to team %s", state.Team, worker.Team)
   154  	}
   155  
   156  	return nil
   157  }
   158  
   159  func (req landWorkerRequest) Handle(ctx context.Context, state ConnState, channel ssh.Channel) error {
   160  	var worker atc.Worker
   161  	err := json.NewDecoder(channel).Decode(&worker)
   162  	if err != nil {
   163  		return err
   164  	}
   165  
   166  	if err := checkTeam(state, worker); err != nil {
   167  		return err
   168  	}
   169  
   170  	return (&tsa.Lander{
   171  		ATCEndpoint: req.server.atcEndpointPicker.Pick(),
   172  		HTTPClient:  req.server.httpClient,
   173  	}).Land(ctx, worker)
   174  }
   175  
   176  type retireWorkerRequest struct {
   177  	server *server
   178  }
   179  
   180  func (req retireWorkerRequest) Handle(ctx context.Context, state ConnState, channel ssh.Channel) error {
   181  	var worker atc.Worker
   182  	err := json.NewDecoder(channel).Decode(&worker)
   183  	if err != nil {
   184  		return err
   185  	}
   186  
   187  	if err := checkTeam(state, worker); err != nil {
   188  		return err
   189  	}
   190  
   191  	return (&tsa.Retirer{
   192  		ATCEndpoint: req.server.atcEndpointPicker.Pick(),
   193  		HTTPClient:  req.server.httpClient,
   194  	}).Retire(ctx, worker)
   195  }
   196  
   197  type deleteWorkerRequest struct {
   198  	server *server
   199  }
   200  
   201  func (req deleteWorkerRequest) Handle(ctx context.Context, state ConnState, channel ssh.Channel) error {
   202  	var worker atc.Worker
   203  	err := json.NewDecoder(channel).Decode(&worker)
   204  	if err != nil {
   205  		return err
   206  	}
   207  
   208  	if err := checkTeam(state, worker); err != nil {
   209  		return err
   210  	}
   211  
   212  	return (&tsa.Deleter{
   213  		ATCEndpoint: req.server.atcEndpointPicker.Pick(),
   214  		HTTPClient:  req.server.httpClient,
   215  	}).Delete(ctx, worker)
   216  }
   217  
   218  type sweepContainersRequest struct {
   219  	server *server
   220  }
   221  
   222  func (req sweepContainersRequest) Handle(ctx context.Context, state ConnState, channel ssh.Channel) error {
   223  	var worker atc.Worker
   224  	err := json.NewDecoder(channel).Decode(&worker)
   225  	if err != nil {
   226  		return err
   227  	}
   228  
   229  	if err := checkTeam(state, worker); err != nil {
   230  		return err
   231  	}
   232  
   233  	sweeper := &tsa.Sweeper{
   234  		ATCEndpoint: req.server.atcEndpointPicker.Pick(),
   235  		HTTPClient:  req.server.httpClient,
   236  	}
   237  
   238  	handles, err := sweeper.Sweep(ctx, worker, tsa.SweepContainers)
   239  	if err != nil {
   240  		return err
   241  	}
   242  
   243  	_, err = channel.Write(handles)
   244  	if err != nil {
   245  		return err
   246  	}
   247  
   248  	return nil
   249  }
   250  
   251  type reportContainersRequest struct {
   252  	server           *server
   253  	containerHandles []string
   254  }
   255  
   256  func (req reportContainersRequest) Handle(ctx context.Context, state ConnState, channel ssh.Channel) error {
   257  	var worker atc.Worker
   258  	err := json.NewDecoder(channel).Decode(&worker)
   259  	if err != nil {
   260  		return err
   261  	}
   262  
   263  	if err := checkTeam(state, worker); err != nil {
   264  		return err
   265  	}
   266  
   267  	return (&tsa.WorkerStatus{
   268  		ATCEndpoint:      req.server.atcEndpointPicker.Pick(),
   269  		HTTPClient:       req.server.httpClient,
   270  		ContainerHandles: req.containerHandles,
   271  	}).WorkerStatus(ctx, worker, tsa.ReportContainers)
   272  }
   273  
   274  type sweepVolumesRequest struct {
   275  	server *server
   276  }
   277  
   278  func (req sweepVolumesRequest) Handle(ctx context.Context, state ConnState, channel ssh.Channel) error {
   279  	var worker atc.Worker
   280  	err := json.NewDecoder(channel).Decode(&worker)
   281  	if err != nil {
   282  		return err
   283  	}
   284  
   285  	if err := checkTeam(state, worker); err != nil {
   286  		return err
   287  	}
   288  
   289  	sweeper := &tsa.Sweeper{
   290  		ATCEndpoint: req.server.atcEndpointPicker.Pick(),
   291  		HTTPClient:  req.server.httpClient,
   292  	}
   293  
   294  	handles, err := sweeper.Sweep(ctx, worker, tsa.SweepVolumes)
   295  	if err != nil {
   296  		return err
   297  	}
   298  
   299  	_, err = channel.Write(handles)
   300  	if err != nil {
   301  		return err
   302  	}
   303  
   304  	return nil
   305  }
   306  
   307  type reportVolumesRequest struct {
   308  	server        *server
   309  	volumeHandles []string
   310  }
   311  
   312  func (req reportVolumesRequest) Handle(ctx context.Context, state ConnState, channel ssh.Channel) error {
   313  	var worker atc.Worker
   314  	err := json.NewDecoder(channel).Decode(&worker)
   315  	if err != nil {
   316  		return err
   317  	}
   318  
   319  	if err := checkTeam(state, worker); err != nil {
   320  		return err
   321  	}
   322  
   323  	return (&tsa.WorkerStatus{
   324  		ATCEndpoint:   req.server.atcEndpointPicker.Pick(),
   325  		HTTPClient:    req.server.httpClient,
   326  		VolumeHandles: req.volumeHandles,
   327  	}).WorkerStatus(ctx, worker, tsa.ReportVolumes)
   328  }
   329  
   330  func gardenURL(addr string) string {
   331  	return fmt.Sprintf("http://%s", addr)
   332  }