github.com/hashicorp/nomad/api@v0.0.0-20240306165712-3193ac204f65/allocations_exec.go (about) 1 // Copyright (c) HashiCorp, Inc. 2 // SPDX-License-Identifier: MPL-2.0 3 4 package api 5 6 import ( 7 "context" 8 "encoding/json" 9 "errors" 10 "fmt" 11 "io" 12 "net/url" 13 "strconv" 14 "sync" 15 "time" 16 17 "github.com/gorilla/websocket" 18 ) 19 20 const ( 21 // heartbeatInterval is the amount of time to wait between sending heartbeats 22 // during an exec streaming operation 23 heartbeatInterval = 10 * time.Second 24 ) 25 26 type execSession struct { 27 client *Client 28 alloc *Allocation 29 job string 30 task string 31 tty bool 32 command []string 33 action string 34 35 stdin io.Reader 36 stdout io.Writer 37 stderr io.Writer 38 39 terminalSizeCh <-chan TerminalSize 40 41 q *QueryOptions 42 } 43 44 func (s *execSession) run(ctx context.Context) (exitCode int, err error) { 45 ctx, cancelFn := context.WithCancel(ctx) 46 defer cancelFn() 47 48 conn, err := s.startConnection() 49 if err != nil { 50 return -2, err 51 } 52 defer conn.Close() 53 54 sendErrCh := s.startTransmit(ctx, conn) 55 exitCh, recvErrCh := s.startReceiving(ctx, conn) 56 57 for { 58 select { 59 case <-ctx.Done(): 60 return -2, ctx.Err() 61 case exitCode := <-exitCh: 62 return exitCode, nil 63 case recvErr := <-recvErrCh: 64 // drop websocket code, not relevant to user 65 if wsErr, ok := recvErr.(*websocket.CloseError); ok && wsErr.Text != "" { 66 return -2, errors.New(wsErr.Text) 67 } 68 69 return -2, recvErr 70 case sendErr := <-sendErrCh: 71 return -2, fmt.Errorf("failed to send input: %w", sendErr) 72 } 73 } 74 } 75 76 func (s *execSession) startConnection() (*websocket.Conn, error) { 77 // First, attempt to connect to the node directly, but may fail due to network isolation 78 // and network errors. Fallback to using server-side forwarding instead. 79 nodeClient, err := s.client.GetNodeClientWithTimeout(s.alloc.NodeID, ClientConnTimeout, s.q) 80 if err == NodeDownErr { 81 return nil, NodeDownErr 82 } 83 84 q := s.q 85 if q == nil { 86 q = &QueryOptions{} 87 } 88 if q.Params == nil { 89 q.Params = make(map[string]string) 90 } 91 92 commandBytes, err := json.Marshal(s.command) 93 if err != nil { 94 return nil, fmt.Errorf("failed to marshal command: %W", err) 95 } 96 97 q.Params["tty"] = strconv.FormatBool(s.tty) 98 q.Params["task"] = s.task 99 q.Params["command"] = string(commandBytes) 100 reqPath := fmt.Sprintf("/v1/client/allocation/%s/exec", s.alloc.ID) 101 102 if s.action != "" { 103 q.Params["action"] = s.action 104 q.Params["allocID"] = s.alloc.ID 105 q.Params["group"] = s.alloc.TaskGroup 106 reqPath = fmt.Sprintf("/v1/job/%s/action", url.PathEscape(s.job)) 107 } 108 109 var conn *websocket.Conn 110 111 if nodeClient != nil { 112 conn, _, _ = nodeClient.websocket(reqPath, q) //nolint:bodyclose // gorilla/websocket Dialer.DialContext() does not require the body to be closed. 113 } 114 115 if conn == nil { 116 conn, _, err = s.client.websocket(reqPath, q) //nolint:bodyclose // gorilla/websocket Dialer.DialContext() does not require the body to be closed. 117 if err != nil { 118 return nil, err 119 } 120 } 121 122 return conn, nil 123 } 124 125 func (s *execSession) startTransmit(ctx context.Context, conn *websocket.Conn) <-chan error { 126 127 // FIXME: Handle websocket send errors. 128 // Currently, websocket write failures are dropped. As sending and 129 // receiving are running concurrently, it's expected that some send 130 // requests may fail with connection errors when connection closes. 131 // Connection errors should surface in the receive paths already, 132 // but I'm unsure about one-sided communication errors. 133 var sendLock sync.Mutex 134 send := func(v *ExecStreamingInput) { 135 sendLock.Lock() 136 defer sendLock.Unlock() 137 138 conn.WriteJSON(v) 139 } 140 141 errCh := make(chan error, 4) 142 143 // propagate stdin 144 go func() { 145 146 bytes := make([]byte, 2048) 147 for { 148 if ctx.Err() != nil { 149 return 150 } 151 152 input := ExecStreamingInput{Stdin: &ExecStreamingIOOperation{}} 153 154 n, err := s.stdin.Read(bytes) 155 156 // always send data if we read some 157 if n != 0 { 158 input.Stdin.Data = bytes[:n] 159 send(&input) 160 } 161 162 // then handle error 163 if err == io.EOF { 164 // if n != 0, send data and we'll get n = 0 on next read 165 if n == 0 { 166 input.Stdin.Close = true 167 send(&input) 168 return 169 } 170 } else if err != nil { 171 errCh <- err 172 return 173 } 174 } 175 }() 176 177 // propagate terminal sizing updates 178 go func() { 179 for { 180 resizeInput := ExecStreamingInput{} 181 182 select { 183 case <-ctx.Done(): 184 return 185 case size, ok := <-s.terminalSizeCh: 186 if !ok { 187 return 188 } 189 resizeInput.TTYSize = &size 190 send(&resizeInput) 191 } 192 193 } 194 }() 195 196 // send a heartbeat every 10 seconds 197 go func() { 198 t := time.NewTimer(heartbeatInterval) 199 defer t.Stop() 200 201 for { 202 t.Reset(heartbeatInterval) 203 204 select { 205 case <-ctx.Done(): 206 return 207 case <-t.C: 208 // heartbeat message 209 send(&execStreamingInputHeartbeat) 210 } 211 } 212 }() 213 214 return errCh 215 } 216 217 func (s *execSession) startReceiving(ctx context.Context, conn *websocket.Conn) (<-chan int, <-chan error) { 218 exitCodeCh := make(chan int, 1) 219 errCh := make(chan error, 1) 220 221 go func() { 222 for ctx.Err() == nil { 223 224 // Decode the next frame 225 var frame ExecStreamingOutput 226 err := conn.ReadJSON(&frame) 227 if websocket.IsCloseError(err, websocket.CloseNormalClosure) { 228 errCh <- fmt.Errorf("websocket closed before receiving exit code: %w", err) 229 return 230 } else if err != nil { 231 errCh <- err 232 return 233 } 234 235 switch { 236 case frame.Stdout != nil: 237 if len(frame.Stdout.Data) != 0 { 238 s.stdout.Write(frame.Stdout.Data) 239 } 240 // don't really do anything if stdout is closing 241 case frame.Stderr != nil: 242 if len(frame.Stderr.Data) != 0 { 243 s.stderr.Write(frame.Stderr.Data) 244 } 245 // don't really do anything if stderr is closing 246 case frame.Exited && frame.Result != nil: 247 exitCodeCh <- frame.Result.ExitCode 248 return 249 default: 250 // noop - heartbeat 251 } 252 253 } 254 255 }() 256 257 return exitCodeCh, errCh 258 }