github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/apiserver/embeddedcli.go (about)

     1  // Copyright 2020 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package apiserver
     5  
     6  import (
     7  	"fmt"
     8  	"io"
     9  	"net/http"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/go-macaroon-bakery/macaroon-bakery/v3/bakery"
    14  	gorillaws "github.com/gorilla/websocket"
    15  	"github.com/juju/cmd/v3"
    16  	"github.com/juju/errors"
    17  	"github.com/juju/featureflag"
    18  	"github.com/juju/loggo"
    19  	"github.com/juju/names/v5"
    20  	"github.com/mitchellh/go-linereader"
    21  
    22  	apiservererrors "github.com/juju/juju/apiserver/errors"
    23  	"github.com/juju/juju/apiserver/httpcontext"
    24  	"github.com/juju/juju/apiserver/websocket"
    25  	"github.com/juju/juju/core/model"
    26  	"github.com/juju/juju/feature"
    27  	"github.com/juju/juju/jujuclient"
    28  	"github.com/juju/juju/rpc/params"
    29  	"github.com/juju/juju/state"
    30  )
    31  
    32  func newEmbeddedCLIHandler(
    33  	ctxt httpContext,
    34  ) http.Handler {
    35  	return &embeddedCLIHandler{
    36  		ctxt:   ctxt,
    37  		logger: loggo.GetLogger("juju.apiserver.embeddedcli"),
    38  	}
    39  }
    40  
    41  // embeddedCLIHandler handles requests to run Juju CLi commands directly on the controller.
    42  type embeddedCLIHandler struct {
    43  	ctxt   httpContext
    44  	logger loggo.Logger
    45  }
    46  
    47  // ServeHTTP implements the http.Handler interface.
    48  func (h *embeddedCLIHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
    49  	handler := func(socket *websocket.Conn) {
    50  		h.logger.Tracef("start of *embeddedCLIHandler.ServeHTTP")
    51  		defer socket.Close()
    52  
    53  		// If we get to here, no more errors to report, so we report a nil
    54  		// error.  This way the first line of the socket is always a json
    55  		// formatted simple error.
    56  		if sendErr := socket.SendInitialErrorV0(nil); sendErr != nil {
    57  			h.logger.Errorf("closing websocket, %v", sendErr)
    58  			return
    59  		}
    60  
    61  		// Here we configure the ping/pong handling for the websocket so
    62  		// the server can notice when the client goes away.
    63  		// See the long note in logsink.go for the rationale.
    64  		_ = socket.SetReadDeadline(time.Now().Add(websocket.PongDelay))
    65  		socket.SetPongHandler(func(string) error {
    66  			_ = socket.SetReadDeadline(time.Now().Add(websocket.PongDelay))
    67  			return nil
    68  		})
    69  		ticker := time.NewTicker(websocket.PingPeriod)
    70  		defer ticker.Stop()
    71  
    72  		modelUUID := httpcontext.RequestModelUUID(req)
    73  		commandCh := h.receiveCommands(socket)
    74  		for {
    75  			select {
    76  			case <-h.ctxt.stop():
    77  				return
    78  			case <-ticker.C:
    79  				deadline := time.Now().Add(websocket.WriteWait)
    80  				if err := socket.WriteControl(gorillaws.PingMessage, []byte{}, deadline); err != nil {
    81  					// This error is expected if the other end goes away. By
    82  					// returning we close the socket through the defer call.
    83  					h.logger.Debugf("failed to write ping: %s", err)
    84  					return
    85  				}
    86  			case jujuCmd := <-commandCh:
    87  				h.logger.Debugf("running embedded commands: %#v", jujuCmd)
    88  				cmdErr := h.runEmbeddedCommands(socket, modelUUID, jujuCmd)
    89  				// Only developers need this for debugging.
    90  				if cmdErr != nil && featureflag.Enabled(feature.DeveloperMode) {
    91  					h.logger.Debugf("command exec error: %v", cmdErr)
    92  				}
    93  				if err := socket.WriteJSON(params.CLICommandStatus{
    94  					Done:  true,
    95  					Error: apiservererrors.ServerError(cmdErr),
    96  				}); err != nil {
    97  					h.logger.Errorf("sending command result to caller: %v", err)
    98  				}
    99  			}
   100  		}
   101  	}
   102  	websocket.Serve(w, req, handler)
   103  }
   104  
   105  func (h *embeddedCLIHandler) receiveCommands(socket *websocket.Conn) <-chan params.CLICommands {
   106  	commandCh := make(chan params.CLICommands)
   107  
   108  	go func() {
   109  		for {
   110  			var cmd params.CLICommands
   111  			// ReadJSON() blocks until data arrives but will also be
   112  			// unblocked when the API handler calls socket.Close as it
   113  			// finishes.
   114  			if err := socket.ReadJSON(&cmd); err != nil {
   115  				// Since we don't give a list of expected error codes,
   116  				// any CloseError type is considered unexpected.
   117  				if gorillaws.IsUnexpectedCloseError(err) {
   118  					h.logger.Tracef("websocket closed")
   119  				} else {
   120  					h.logger.Errorf("embedded CLI receive error: %v", err)
   121  				}
   122  				return
   123  			}
   124  
   125  			// Send the command.
   126  			select {
   127  			case <-h.ctxt.stop():
   128  				return
   129  			case commandCh <- cmd:
   130  			}
   131  		}
   132  	}()
   133  
   134  	return commandCh
   135  }
   136  
   137  func (h *embeddedCLIHandler) runEmbeddedCommands(
   138  	ws *websocket.Conn,
   139  	modelUUID string,
   140  	commands params.CLICommands,
   141  ) error {
   142  
   143  	// Figure out what model to run the commands on.
   144  	resolvedModelUUID := modelUUID
   145  	if resolvedModelUUID == "" {
   146  		systemState, err := h.ctxt.srv.shared.statePool.SystemState()
   147  		if err != nil {
   148  			return errors.Trace(err)
   149  		}
   150  		resolvedModelUUID = systemState.ModelUUID()
   151  	}
   152  	m, closer, err := h.ctxt.srv.shared.statePool.GetModel(resolvedModelUUID)
   153  	if err != nil {
   154  		return errors.Trace(err)
   155  	}
   156  	defer closer.Release()
   157  
   158  	// Make a pipe to stream the stdout/stderr of the commands.
   159  	errCh := make(chan error, 1)
   160  	in, err := runCLICommands(m, errCh, commands, h.ctxt.srv.execEmbeddedCommand)
   161  	if err != nil {
   162  		return errors.Trace(err)
   163  	}
   164  
   165  	var cmdErr error
   166  	lines := newLineReader(in)
   167  	cmdDone := false
   168  	outputDone := false
   169  done:
   170  	for {
   171  		select {
   172  		case <-h.ctxt.stop():
   173  			return errors.New("command aborted due to server shutdown")
   174  		case line, ok := <-lines.Ch:
   175  			if !ok {
   176  				if cmdDone {
   177  					break done
   178  				}
   179  				outputDone = true
   180  				// Wait for cmd result.
   181  				continue
   182  			}
   183  			// If there's been a macaroon discharge required, we don't yet
   184  			// process it in embedded mode so just return it so the caller
   185  			// can deal with it, eg login again to get another macaroon.
   186  			// This string is hard coded in the bakery library.
   187  			idx := strings.Index(line, "cannot get discharge from")
   188  			if idx >= 0 {
   189  				return apiservererrors.ServerError(&apiservererrors.DischargeRequiredError{
   190  					Cause: &bakery.DischargeRequiredError{Message: line[idx:]},
   191  				})
   192  			}
   193  
   194  			if err := ws.WriteJSON(params.CLICommandStatus{
   195  				Output: []string{line},
   196  			}); err != nil {
   197  				h.logger.Warningf("error writing CLI output: %v", err)
   198  				cmdErr = err
   199  				break done
   200  			}
   201  		case cmdErr = <-errCh:
   202  			if outputDone {
   203  				break done
   204  			}
   205  			// Wait for cmd output to all be read.
   206  			cmdDone = true
   207  		}
   208  	}
   209  	return cmdErr
   210  }
   211  
   212  // newLineReader returns a new linereader Reader for the
   213  // provided io Reader.
   214  func newLineReader(r io.Reader) *linereader.Reader {
   215  	// Do the same as linereader.New(), with the juju
   216  	// timeout values.  Changing the timeout of the
   217  	// Reader is unsafe after calling New.
   218  	result := &linereader.Reader{
   219  		Reader:  r,
   220  		Timeout: 10 * time.Millisecond,
   221  		Ch:      make(chan string),
   222  	}
   223  	go result.Run()
   224  	return result
   225  }
   226  
   227  // ExecEmbeddedCommandFunc defines a function which runs a named Juju command
   228  // with the whitelisted sub commands.
   229  type ExecEmbeddedCommandFunc func(ctx *cmd.Context, store jujuclient.ClientStore, whitelist []string, cmdPlusArgs string) int
   230  
   231  // runCLICommands creates a CLI command instance with an in-memory copy of the controller,
   232  // model, and account details and runs the command against the host controller.
   233  func runCLICommands(m *state.Model, errCh chan<- error, commands params.CLICommands, execEmbeddedCommand ExecEmbeddedCommandFunc) (io.Reader, error) {
   234  	if commands.User == "" {
   235  		return nil, errors.NotSupportedf("CLI command for anonymous user")
   236  	}
   237  	// Check passed in username is valid.
   238  	if !names.IsValidUser(commands.User) {
   239  		return nil, errors.NotValidf("user name %q", commands.User)
   240  	}
   241  
   242  	cfg, err := m.State().ControllerConfig()
   243  	if err != nil {
   244  		return nil, errors.Trace(err)
   245  	}
   246  
   247  	// Set up a juju client store used to configure the
   248  	// embedded command to give it the controller, model
   249  	// and account details to use.
   250  	store := jujuclient.NewEmbeddedMemStore()
   251  	cert, _ := cfg.CACert()
   252  	controllerName := cfg.ControllerName()
   253  	if controllerName == "" {
   254  		controllerName = "interactive"
   255  	}
   256  	store.Controllers[controllerName] = jujuclient.ControllerDetails{
   257  		ControllerUUID: cfg.ControllerUUID(),
   258  		APIEndpoints:   []string{fmt.Sprintf("localhost:%d", cfg.APIPort())},
   259  		CACert:         cert,
   260  	}
   261  	store.CurrentControllerName = controllerName
   262  
   263  	qualifiedModelName := jujuclient.JoinOwnerModelName(m.Owner(), m.Name())
   264  	store.Models[controllerName] = &jujuclient.ControllerModels{
   265  		Models: map[string]jujuclient.ModelDetails{
   266  			qualifiedModelName: {
   267  				ModelUUID:    m.UUID(),
   268  				ModelType:    model.ModelType(m.Type()),
   269  				ActiveBranch: commands.ActiveBranch,
   270  			},
   271  		},
   272  		CurrentModel: qualifiedModelName,
   273  	}
   274  	store.Accounts[controllerName] = jujuclient.AccountDetails{
   275  		User:      commands.User,
   276  		Password:  commands.Credentials,
   277  		Macaroons: commands.Macaroons,
   278  	}
   279  
   280  	in, out := io.Pipe()
   281  	go func() {
   282  		defer in.Close()
   283  		for _, cliCmd := range commands.Commands {
   284  			ctx, err := cmd.DefaultContext()
   285  			if err != nil {
   286  				errCh <- errors.Trace(err)
   287  			}
   288  			ctx.Stdout = out
   289  			ctx.Stderr = out
   290  			code := execEmbeddedCommand(ctx, store, allowedEmbeddedCommands, cliCmd)
   291  			if code != 0 {
   292  				errCh <- errors.Annotatef(err, "command %q: exit code %d", cliCmd, code)
   293  				continue
   294  			}
   295  			errCh <- nil
   296  		}
   297  	}()
   298  	return in, nil
   299  }