github.com/pf-qiu/concourse/v6@v6.7.3-0.20201207032516-1f455d73275f/fly/commands/internal/hijacker/hijacker.go (about)

     1  package hijacker
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"fmt"
     7  	"io"
     8  	"net/http"
     9  	"os"
    10  	"time"
    11  
    12  	"github.com/pf-qiu/concourse/v6/atc"
    13  	"github.com/pf-qiu/concourse/v6/fly/pty"
    14  	"github.com/pf-qiu/concourse/v6/fly/rc"
    15  	"github.com/pf-qiu/concourse/v6/fly/ui"
    16  	"github.com/gorilla/websocket"
    17  	"github.com/mgutz/ansi"
    18  	"github.com/tedsuo/rata"
    19  )
    20  
    21  type ProcessIO struct {
    22  	In  chan atc.HijackInput
    23  	Out io.Writer
    24  	Err io.Writer
    25  }
    26  
    27  type Hijacker struct {
    28  	tlsConfig        *tls.Config
    29  	requestGenerator *rata.RequestGenerator
    30  	token            *rc.TargetToken
    31  	interval         time.Duration
    32  }
    33  
    34  func New(tlsConfig *tls.Config, requestGenerator *rata.RequestGenerator, token *rc.TargetToken) *Hijacker {
    35  	return &Hijacker{
    36  		tlsConfig:        tlsConfig,
    37  		requestGenerator: requestGenerator,
    38  		token:            token,
    39  		interval:         10 * time.Second,
    40  	}
    41  }
    42  
    43  func (h *Hijacker) SetHeartbeatInterval(interval time.Duration) {
    44  	h.interval = interval
    45  }
    46  
    47  func (h *Hijacker) Hijack(ctx context.Context, teamName, handle string, spec atc.HijackProcessSpec, pio ProcessIO) (int, bool, error) {
    48  	url, header, err := h.hijackRequestParts(teamName, handle)
    49  	if err != nil {
    50  		return -1, false, err
    51  	}
    52  
    53  	dialer := websocket.Dialer{
    54  		TLSClientConfig: h.tlsConfig,
    55  		Proxy:           http.ProxyFromEnvironment,
    56  	}
    57  	conn, response, err := dialer.Dial(url, header)
    58  	if err != nil {
    59  		return -1, false, fmt.Errorf("%s %w", response.Status, err)
    60  	}
    61  
    62  	defer conn.Close()
    63  
    64  	err = conn.WriteJSON(spec)
    65  	if err != nil {
    66  		return -1, false, err
    67  	}
    68  
    69  	ctx, cancel := context.WithCancel(ctx)
    70  	defer cancel()
    71  
    72  	go h.monitorTTYSize(ctx, pio.In)
    73  	go h.handleInput(ctx, conn, pio.In)
    74  
    75  	exitStatus, exeNotFound := h.handleOutput(conn, pio)
    76  
    77  	return exitStatus, exeNotFound, nil
    78  }
    79  
    80  func (h *Hijacker) hijackRequestParts(teamName, handle string) (string, http.Header, error) {
    81  	hijackReq, err := h.requestGenerator.CreateRequest(
    82  		atc.HijackContainer,
    83  		rata.Params{"id": handle, "team_name": teamName},
    84  		nil,
    85  	)
    86  
    87  	if err != nil {
    88  		panic(err)
    89  	}
    90  
    91  	if h.token != nil {
    92  		hijackReq.Header.Add("Authorization", h.token.Type+" "+h.token.Value)
    93  	}
    94  
    95  	wsUrl := hijackReq.URL
    96  
    97  	var found bool
    98  	wsUrl.Scheme, found = websocketSchemeMap[wsUrl.Scheme]
    99  	if !found {
   100  		return "", nil, fmt.Errorf("unknown target scheme: %s", wsUrl.Scheme)
   101  	}
   102  
   103  	return wsUrl.String(), hijackReq.Header, nil
   104  }
   105  
   106  func (h *Hijacker) handleOutput(conn *websocket.Conn, pio ProcessIO) (int, bool) {
   107  	var exitStatus int
   108  	var exeNotFound bool
   109  	for {
   110  		var output atc.HijackOutput
   111  		err := conn.ReadJSON(&output)
   112  		if err != nil {
   113  			if !websocket.IsCloseError(err) && !websocket.IsUnexpectedCloseError(err) {
   114  				fmt.Println(err)
   115  			}
   116  			break
   117  		}
   118  
   119  		if output.ExitStatus != nil {
   120  			exitStatus = *output.ExitStatus
   121  		} else if output.ExecutableNotFound {
   122  			exeNotFound = true
   123  		} else if len(output.Error) > 0 {
   124  			fmt.Fprintf(ui.Stderr, "%s\n", ansi.Color(output.Error, "red+b"))
   125  			exitStatus = 255
   126  		} else if len(output.Stdout) > 0 {
   127  			pio.Out.Write(output.Stdout)
   128  		} else if len(output.Stderr) > 0 {
   129  			pio.Err.Write(output.Stderr)
   130  		}
   131  	}
   132  
   133  	return exitStatus, exeNotFound
   134  }
   135  
   136  func (h *Hijacker) handleInput(ctx context.Context, conn *websocket.Conn, inputs <-chan atc.HijackInput) {
   137  	ticker := time.NewTicker(h.interval)
   138  	defer ticker.Stop()
   139  
   140  	for {
   141  		select {
   142  		case input := <-inputs:
   143  			err := conn.WriteJSON(input)
   144  			if err != nil {
   145  				fmt.Fprintf(ui.Stderr, "failed to send input: %s", err.Error())
   146  				return
   147  			}
   148  		case t := <-ticker.C:
   149  			err := conn.WriteControl(websocket.PingMessage, []byte(t.String()), time.Now().Add(time.Second))
   150  			if err != nil {
   151  				fmt.Fprintf(ui.Stderr, "failed to send heartbeat: %s", err.Error())
   152  			}
   153  		case <-ctx.Done():
   154  			return
   155  		}
   156  	}
   157  }
   158  
   159  func (h *Hijacker) monitorTTYSize(ctx context.Context, inputs chan<- atc.HijackInput) {
   160  	resized := pty.ResizeNotifier()
   161  
   162  	for {
   163  		select {
   164  		case <-resized:
   165  			rows, cols, err := pty.Getsize(os.Stdin)
   166  			if err == nil {
   167  				inputs <- atc.HijackInput{
   168  					TTYSpec: &atc.HijackTTYSpec{
   169  						WindowSize: atc.HijackWindowSize{
   170  							Columns: cols,
   171  							Rows:    rows,
   172  						},
   173  					},
   174  				}
   175  			}
   176  		case <-ctx.Done():
   177  			return
   178  		}
   179  	}
   180  }
   181  
   182  var websocketSchemeMap = map[string]string{
   183  	"http":  "ws",
   184  	"https": "wss",
   185  }