github.com/hashicorp/nomad/api@v0.0.0-20240306165712-3193ac204f65/allocations_exec.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package api
     5  
     6  import (
     7  	"context"
     8  	"encoding/json"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"net/url"
    13  	"strconv"
    14  	"sync"
    15  	"time"
    16  
    17  	"github.com/gorilla/websocket"
    18  )
    19  
    20  const (
    21  	// heartbeatInterval is the amount of time to wait between sending heartbeats
    22  	// during an exec streaming operation
    23  	heartbeatInterval = 10 * time.Second
    24  )
    25  
// execSession bundles the state needed to run a single exec (or job action)
// streaming session against a task: where to connect, what to run, and the
// local streams that are wired to the remote process.
type execSession struct {
	// client is used to look up a client for the allocation's node and serves
	// as the fallback connection when the node cannot be reached directly.
	client  *Client
	// alloc is the target allocation; its ID, NodeID and TaskGroup are read
	// when building the request.
	alloc   *Allocation
	// job is path-escaped into the request path when an action is requested.
	job     string
	// task is sent as the "task" query parameter.
	task    string
	// tty is sent as the "tty" query parameter.
	tty     bool
	// command is JSON-encoded and sent as the "command" query parameter.
	command []string
	// action, when non-empty, switches the request to the job action endpoint
	// and is sent as the "action" query parameter.
	action  string

	// stdin is read in chunks and forwarded to the remote end; EOF is
	// forwarded as a close message.
	stdin  io.Reader
	// stdout receives the data of stdout frames sent by the remote end.
	stdout io.Writer
	// stderr receives the data of stderr frames sent by the remote end.
	stderr io.Writer

	// terminalSizeCh delivers terminal size updates that are forwarded to the
	// remote end; closing it stops the forwarding.
	terminalSizeCh <-chan TerminalSize

	// q holds the query options for the request. It may be nil. When it is
	// non-nil its Params map is populated in place by startConnection.
	q *QueryOptions
}
    43  
    44  func (s *execSession) run(ctx context.Context) (exitCode int, err error) {
    45  	ctx, cancelFn := context.WithCancel(ctx)
    46  	defer cancelFn()
    47  
    48  	conn, err := s.startConnection()
    49  	if err != nil {
    50  		return -2, err
    51  	}
    52  	defer conn.Close()
    53  
    54  	sendErrCh := s.startTransmit(ctx, conn)
    55  	exitCh, recvErrCh := s.startReceiving(ctx, conn)
    56  
    57  	for {
    58  		select {
    59  		case <-ctx.Done():
    60  			return -2, ctx.Err()
    61  		case exitCode := <-exitCh:
    62  			return exitCode, nil
    63  		case recvErr := <-recvErrCh:
    64  			// drop websocket code, not relevant to user
    65  			if wsErr, ok := recvErr.(*websocket.CloseError); ok && wsErr.Text != "" {
    66  				return -2, errors.New(wsErr.Text)
    67  			}
    68  
    69  			return -2, recvErr
    70  		case sendErr := <-sendErrCh:
    71  			return -2, fmt.Errorf("failed to send input: %w", sendErr)
    72  		}
    73  	}
    74  }
    75  
    76  func (s *execSession) startConnection() (*websocket.Conn, error) {
    77  	// First, attempt to connect to the node directly, but may fail due to network isolation
    78  	// and network errors.  Fallback to using server-side forwarding instead.
    79  	nodeClient, err := s.client.GetNodeClientWithTimeout(s.alloc.NodeID, ClientConnTimeout, s.q)
    80  	if err == NodeDownErr {
    81  		return nil, NodeDownErr
    82  	}
    83  
    84  	q := s.q
    85  	if q == nil {
    86  		q = &QueryOptions{}
    87  	}
    88  	if q.Params == nil {
    89  		q.Params = make(map[string]string)
    90  	}
    91  
    92  	commandBytes, err := json.Marshal(s.command)
    93  	if err != nil {
    94  		return nil, fmt.Errorf("failed to marshal command: %W", err)
    95  	}
    96  
    97  	q.Params["tty"] = strconv.FormatBool(s.tty)
    98  	q.Params["task"] = s.task
    99  	q.Params["command"] = string(commandBytes)
   100  	reqPath := fmt.Sprintf("/v1/client/allocation/%s/exec", s.alloc.ID)
   101  
   102  	if s.action != "" {
   103  		q.Params["action"] = s.action
   104  		q.Params["allocID"] = s.alloc.ID
   105  		q.Params["group"] = s.alloc.TaskGroup
   106  		reqPath = fmt.Sprintf("/v1/job/%s/action", url.PathEscape(s.job))
   107  	}
   108  
   109  	var conn *websocket.Conn
   110  
   111  	if nodeClient != nil {
   112  		conn, _, _ = nodeClient.websocket(reqPath, q) //nolint:bodyclose // gorilla/websocket Dialer.DialContext() does not require the body to be closed.
   113  	}
   114  
   115  	if conn == nil {
   116  		conn, _, err = s.client.websocket(reqPath, q) //nolint:bodyclose // gorilla/websocket Dialer.DialContext() does not require the body to be closed.
   117  		if err != nil {
   118  			return nil, err
   119  		}
   120  	}
   121  
   122  	return conn, nil
   123  }
   124  
   125  func (s *execSession) startTransmit(ctx context.Context, conn *websocket.Conn) <-chan error {
   126  
   127  	// FIXME: Handle websocket send errors.
   128  	// Currently, websocket write failures are dropped. As sending and
   129  	// receiving are running concurrently, it's expected that some send
   130  	// requests may fail with connection errors when connection closes.
   131  	// Connection errors should surface in the receive paths already,
   132  	// but I'm unsure about one-sided communication errors.
   133  	var sendLock sync.Mutex
   134  	send := func(v *ExecStreamingInput) {
   135  		sendLock.Lock()
   136  		defer sendLock.Unlock()
   137  
   138  		conn.WriteJSON(v)
   139  	}
   140  
   141  	errCh := make(chan error, 4)
   142  
   143  	// propagate stdin
   144  	go func() {
   145  
   146  		bytes := make([]byte, 2048)
   147  		for {
   148  			if ctx.Err() != nil {
   149  				return
   150  			}
   151  
   152  			input := ExecStreamingInput{Stdin: &ExecStreamingIOOperation{}}
   153  
   154  			n, err := s.stdin.Read(bytes)
   155  
   156  			// always send data if we read some
   157  			if n != 0 {
   158  				input.Stdin.Data = bytes[:n]
   159  				send(&input)
   160  			}
   161  
   162  			// then handle error
   163  			if err == io.EOF {
   164  				// if n != 0, send data and we'll get n = 0 on next read
   165  				if n == 0 {
   166  					input.Stdin.Close = true
   167  					send(&input)
   168  					return
   169  				}
   170  			} else if err != nil {
   171  				errCh <- err
   172  				return
   173  			}
   174  		}
   175  	}()
   176  
   177  	// propagate terminal sizing updates
   178  	go func() {
   179  		for {
   180  			resizeInput := ExecStreamingInput{}
   181  
   182  			select {
   183  			case <-ctx.Done():
   184  				return
   185  			case size, ok := <-s.terminalSizeCh:
   186  				if !ok {
   187  					return
   188  				}
   189  				resizeInput.TTYSize = &size
   190  				send(&resizeInput)
   191  			}
   192  
   193  		}
   194  	}()
   195  
   196  	// send a heartbeat every 10 seconds
   197  	go func() {
   198  		t := time.NewTimer(heartbeatInterval)
   199  		defer t.Stop()
   200  
   201  		for {
   202  			t.Reset(heartbeatInterval)
   203  
   204  			select {
   205  			case <-ctx.Done():
   206  				return
   207  			case <-t.C:
   208  				// heartbeat message
   209  				send(&execStreamingInputHeartbeat)
   210  			}
   211  		}
   212  	}()
   213  
   214  	return errCh
   215  }
   216  
   217  func (s *execSession) startReceiving(ctx context.Context, conn *websocket.Conn) (<-chan int, <-chan error) {
   218  	exitCodeCh := make(chan int, 1)
   219  	errCh := make(chan error, 1)
   220  
   221  	go func() {
   222  		for ctx.Err() == nil {
   223  
   224  			// Decode the next frame
   225  			var frame ExecStreamingOutput
   226  			err := conn.ReadJSON(&frame)
   227  			if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
   228  				errCh <- fmt.Errorf("websocket closed before receiving exit code: %w", err)
   229  				return
   230  			} else if err != nil {
   231  				errCh <- err
   232  				return
   233  			}
   234  
   235  			switch {
   236  			case frame.Stdout != nil:
   237  				if len(frame.Stdout.Data) != 0 {
   238  					s.stdout.Write(frame.Stdout.Data)
   239  				}
   240  				// don't really do anything if stdout is closing
   241  			case frame.Stderr != nil:
   242  				if len(frame.Stderr.Data) != 0 {
   243  					s.stderr.Write(frame.Stderr.Data)
   244  				}
   245  				// don't really do anything if stderr is closing
   246  			case frame.Exited && frame.Result != nil:
   247  				exitCodeCh <- frame.Result.ExitCode
   248  				return
   249  			default:
   250  				// noop - heartbeat
   251  			}
   252  
   253  		}
   254  
   255  	}()
   256  
   257  	return exitCodeCh, errCh
   258  }