github.com/pf-qiu/concourse/v6@v6.7.3-0.20201207032516-1f455d73275f/fly/commands/internal/hijacker/hijacker.go (about) 1 package hijacker 2 3 import ( 4 "context" 5 "crypto/tls" 6 "fmt" 7 "io" 8 "net/http" 9 "os" 10 "time" 11 12 "github.com/pf-qiu/concourse/v6/atc" 13 "github.com/pf-qiu/concourse/v6/fly/pty" 14 "github.com/pf-qiu/concourse/v6/fly/rc" 15 "github.com/pf-qiu/concourse/v6/fly/ui" 16 "github.com/gorilla/websocket" 17 "github.com/mgutz/ansi" 18 "github.com/tedsuo/rata" 19 ) 20 21 type ProcessIO struct { 22 In chan atc.HijackInput 23 Out io.Writer 24 Err io.Writer 25 } 26 27 type Hijacker struct { 28 tlsConfig *tls.Config 29 requestGenerator *rata.RequestGenerator 30 token *rc.TargetToken 31 interval time.Duration 32 } 33 34 func New(tlsConfig *tls.Config, requestGenerator *rata.RequestGenerator, token *rc.TargetToken) *Hijacker { 35 return &Hijacker{ 36 tlsConfig: tlsConfig, 37 requestGenerator: requestGenerator, 38 token: token, 39 interval: 10 * time.Second, 40 } 41 } 42 43 func (h *Hijacker) SetHeartbeatInterval(interval time.Duration) { 44 h.interval = interval 45 } 46 47 func (h *Hijacker) Hijack(ctx context.Context, teamName, handle string, spec atc.HijackProcessSpec, pio ProcessIO) (int, bool, error) { 48 url, header, err := h.hijackRequestParts(teamName, handle) 49 if err != nil { 50 return -1, false, err 51 } 52 53 dialer := websocket.Dialer{ 54 TLSClientConfig: h.tlsConfig, 55 Proxy: http.ProxyFromEnvironment, 56 } 57 conn, response, err := dialer.Dial(url, header) 58 if err != nil { 59 return -1, false, fmt.Errorf("%s %w", response.Status, err) 60 } 61 62 defer conn.Close() 63 64 err = conn.WriteJSON(spec) 65 if err != nil { 66 return -1, false, err 67 } 68 69 ctx, cancel := context.WithCancel(ctx) 70 defer cancel() 71 72 go h.monitorTTYSize(ctx, pio.In) 73 go h.handleInput(ctx, conn, pio.In) 74 75 exitStatus, exeNotFound := h.handleOutput(conn, pio) 76 77 return exitStatus, exeNotFound, nil 78 } 79 80 func (h *Hijacker) hijackRequestParts(teamName, handle string) (string, http.Header, error) { 81 hijackReq, err := h.requestGenerator.CreateRequest( 82 atc.HijackContainer, 83 rata.Params{"id": handle, "team_name": teamName}, 84 nil, 85 ) 86 87 if err != nil { 88 panic(err) 89 } 90 91 if h.token != nil { 92 hijackReq.Header.Add("Authorization", h.token.Type+" "+h.token.Value) 93 } 94 95 wsUrl := hijackReq.URL 96 97 var found bool 98 wsUrl.Scheme, found = websocketSchemeMap[wsUrl.Scheme] 99 if !found { 100 return "", nil, fmt.Errorf("unknown target scheme: %s", wsUrl.Scheme) 101 } 102 103 return wsUrl.String(), hijackReq.Header, nil 104 } 105 106 func (h *Hijacker) handleOutput(conn *websocket.Conn, pio ProcessIO) (int, bool) { 107 var exitStatus int 108 var exeNotFound bool 109 for { 110 var output atc.HijackOutput 111 err := conn.ReadJSON(&output) 112 if err != nil { 113 if !websocket.IsCloseError(err) && !websocket.IsUnexpectedCloseError(err) { 114 fmt.Println(err) 115 } 116 break 117 } 118 119 if output.ExitStatus != nil { 120 exitStatus = *output.ExitStatus 121 } else if output.ExecutableNotFound { 122 exeNotFound = true 123 } else if len(output.Error) > 0 { 124 fmt.Fprintf(ui.Stderr, "%s\n", ansi.Color(output.Error, "red+b")) 125 exitStatus = 255 126 } else if len(output.Stdout) > 0 { 127 pio.Out.Write(output.Stdout) 128 } else if len(output.Stderr) > 0 { 129 pio.Err.Write(output.Stderr) 130 } 131 } 132 133 return exitStatus, exeNotFound 134 } 135 136 func (h *Hijacker) handleInput(ctx context.Context, conn *websocket.Conn, inputs <-chan atc.HijackInput) { 137 ticker := time.NewTicker(h.interval) 138 defer ticker.Stop() 139 140 for { 141 select { 142 case input := <-inputs: 143 err := conn.WriteJSON(input) 144 if err != nil { 145 fmt.Fprintf(ui.Stderr, "failed to send input: %s", err.Error()) 146 return 147 } 148 case t := <-ticker.C: 149 err := conn.WriteControl(websocket.PingMessage, []byte(t.String()), time.Now().Add(time.Second)) 150 if err != nil { 151 fmt.Fprintf(ui.Stderr, "failed to send heartbeat: %s", err.Error()) 152 } 153 case <-ctx.Done(): 154 return 155 } 156 } 157 } 158 159 func (h *Hijacker) monitorTTYSize(ctx context.Context, inputs chan<- atc.HijackInput) { 160 resized := pty.ResizeNotifier() 161 162 for { 163 select { 164 case <-resized: 165 rows, cols, err := pty.Getsize(os.Stdin) 166 if err == nil { 167 inputs <- atc.HijackInput{ 168 TTYSpec: &atc.HijackTTYSpec{ 169 WindowSize: atc.HijackWindowSize{ 170 Columns: cols, 171 Rows: rows, 172 }, 173 }, 174 } 175 } 176 case <-ctx.Done(): 177 return 178 } 179 } 180 } 181 182 var websocketSchemeMap = map[string]string{ 183 "http": "ws", 184 "https": "wss", 185 }