github.com/chenbh/concourse/v6@v6.4.2/fly/commands/internal/hijacker/hijacker.go (about)

     1  package hijacker
     2  
     3  import (
     4  	"crypto/tls"
     5  	"fmt"
     6  	"io"
     7  	"net/http"
     8  	"os"
     9  	"time"
    10  
    11  	"github.com/chenbh/concourse/v6/atc"
    12  	"github.com/chenbh/concourse/v6/fly/pty"
    13  	"github.com/chenbh/concourse/v6/fly/rc"
    14  	"github.com/chenbh/concourse/v6/fly/ui"
    15  	"github.com/gorilla/websocket"
    16  	"github.com/mgutz/ansi"
    17  	"github.com/tedsuo/rata"
    18  )
    19  
    20  type ProcessIO struct {
    21  	In  io.Reader
    22  	Out io.Writer
    23  	Err io.Writer
    24  }
    25  
    26  type Hijacker struct {
    27  	tlsConfig        *tls.Config
    28  	requestGenerator *rata.RequestGenerator
    29  	token            *rc.TargetToken
    30  	interval         time.Duration
    31  }
    32  
    33  func New(tlsConfig *tls.Config, requestGenerator *rata.RequestGenerator, token *rc.TargetToken) *Hijacker {
    34  	return &Hijacker{
    35  		tlsConfig:        tlsConfig,
    36  		requestGenerator: requestGenerator,
    37  		token:            token,
    38  		interval:         10 * time.Second,
    39  	}
    40  }
    41  
    42  func (h *Hijacker) SetHeartbeatInterval(interval time.Duration) {
    43  	h.interval = interval
    44  }
    45  
    46  func (h *Hijacker) Hijack(teamName, handle string, spec atc.HijackProcessSpec, pio ProcessIO) (int, error) {
    47  	url, header, err := h.hijackRequestParts(teamName, handle)
    48  	if err != nil {
    49  		return -1, err
    50  	}
    51  
    52  	dialer := websocket.Dialer{
    53  		TLSClientConfig: h.tlsConfig,
    54  		Proxy:           http.ProxyFromEnvironment,
    55  	}
    56  	conn, response, err := dialer.Dial(url, header)
    57  	if err != nil {
    58  		return -1, fmt.Errorf("%s %w", response.Status, err)
    59  	}
    60  
    61  	defer conn.Close()
    62  
    63  	err = conn.WriteJSON(spec)
    64  	if err != nil {
    65  		return -1, err
    66  	}
    67  
    68  	inputs := make(chan atc.HijackInput, 1)
    69  	finished := make(chan struct{}, 1)
    70  
    71  	go h.monitorTTYSize(inputs, finished)
    72  	go func() {
    73  		io.Copy(&stdinWriter{inputs}, pio.In)
    74  		inputs <- atc.HijackInput{Closed: true}
    75  	}()
    76  	go h.handleInput(conn, inputs, finished)
    77  
    78  	exitStatus := h.handleOutput(conn, pio)
    79  
    80  	close(finished)
    81  
    82  	return exitStatus, nil
    83  }
    84  
    85  func (h *Hijacker) hijackRequestParts(teamName, handle string) (string, http.Header, error) {
    86  	hijackReq, err := h.requestGenerator.CreateRequest(
    87  		atc.HijackContainer,
    88  		rata.Params{"id": handle, "team_name": teamName},
    89  		nil,
    90  	)
    91  
    92  	if err != nil {
    93  		panic(err)
    94  	}
    95  
    96  	if h.token != nil {
    97  		hijackReq.Header.Add("Authorization", h.token.Type+" "+h.token.Value)
    98  	}
    99  
   100  	wsUrl := hijackReq.URL
   101  
   102  	var found bool
   103  	wsUrl.Scheme, found = websocketSchemeMap[wsUrl.Scheme]
   104  	if !found {
   105  		return "", nil, fmt.Errorf("unknown target scheme: %s", wsUrl.Scheme)
   106  	}
   107  
   108  	return wsUrl.String(), hijackReq.Header, nil
   109  }
   110  
   111  func (h *Hijacker) handleOutput(conn *websocket.Conn, pio ProcessIO) int {
   112  	var exitStatus int
   113  	for {
   114  		var output atc.HijackOutput
   115  		err := conn.ReadJSON(&output)
   116  		if err != nil {
   117  			if !websocket.IsCloseError(err) && !websocket.IsUnexpectedCloseError(err) {
   118  				fmt.Println(err)
   119  			}
   120  			break
   121  		}
   122  
   123  		if output.ExitStatus != nil {
   124  			exitStatus = *output.ExitStatus
   125  		} else if len(output.Error) > 0 {
   126  			fmt.Fprintf(ui.Stderr, "%s\n", ansi.Color(output.Error, "red+b"))
   127  			exitStatus = 255
   128  		} else if len(output.Stdout) > 0 {
   129  			pio.Out.Write(output.Stdout)
   130  		} else if len(output.Stderr) > 0 {
   131  			pio.Err.Write(output.Stderr)
   132  		}
   133  	}
   134  
   135  	return exitStatus
   136  }
   137  
   138  func (h *Hijacker) handleInput(conn *websocket.Conn, inputs <-chan atc.HijackInput, finished chan struct{}) {
   139  	ticker := time.NewTicker(h.interval)
   140  	defer ticker.Stop()
   141  
   142  	for {
   143  		select {
   144  		case input := <-inputs:
   145  			err := conn.WriteJSON(input)
   146  			if err != nil {
   147  				fmt.Fprintf(ui.Stderr, "failed to send input: %s", err.Error())
   148  				return
   149  			}
   150  		case t := <-ticker.C:
   151  			err := conn.WriteControl(websocket.PingMessage, []byte(t.String()), time.Now().Add(time.Second))
   152  			if err != nil {
   153  				fmt.Fprintf(ui.Stderr, "failed to send heartbeat: %s", err.Error())
   154  			}
   155  		case <-finished:
   156  			return
   157  		}
   158  	}
   159  }
   160  
   161  func (h *Hijacker) monitorTTYSize(inputs chan<- atc.HijackInput, finished chan struct{}) {
   162  	resized := pty.ResizeNotifier()
   163  
   164  	for {
   165  		select {
   166  		case <-resized:
   167  			rows, cols, err := pty.Getsize(os.Stdin)
   168  			if err == nil {
   169  				inputs <- atc.HijackInput{
   170  					TTYSpec: &atc.HijackTTYSpec{
   171  						WindowSize: atc.HijackWindowSize{
   172  							Columns: cols,
   173  							Rows:    rows,
   174  						},
   175  					},
   176  				}
   177  			}
   178  		case <-finished:
   179  			return
   180  		}
   181  	}
   182  }
   183  
   184  type stdinWriter struct {
   185  	inputs chan<- atc.HijackInput
   186  }
   187  
   188  func (w *stdinWriter) Write(d []byte) (int, error) {
   189  	w.inputs <- atc.HijackInput{
   190  		Stdin: d,
   191  	}
   192  
   193  	return len(d), nil
   194  }
   195  
   196  var websocketSchemeMap = map[string]string{
   197  	"http":  "ws",
   198  	"https": "wss",
   199  }