github.com/juju/juju@v0.0.0-20240327075706-a90865de2538/worker/uniter/runlistener.go (about)

     1  // Copyright 2013 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
// The run listener is a worker goroutine that listens on either a unix
// socket or a tcp connection for juju-exec commands.
     6  
     7  package uniter
     8  
import (
	"crypto/subtle"
	"net"
	"net/rpc"
	"os"
	"path/filepath"
	"sync"

	"github.com/juju/errors"
	"github.com/juju/names/v5"
	"github.com/juju/utils/v3/exec"
	"github.com/juju/worker/v3"
	"gopkg.in/tomb.v2"
	"gopkg.in/yaml.v2"

	"github.com/juju/juju/agent"
	"github.com/juju/juju/caas"
	agentconfig "github.com/juju/juju/cmd/jujud/agent/config"
	"github.com/juju/juju/juju/sockets"
	"github.com/juju/juju/worker/uniter/operation"
	"github.com/juju/juju/worker/uniter/runcommands"
	"github.com/juju/juju/worker/uniter/runner"
)
    31  
    32  const JujuExecEndpoint = "JujuExecServer.RunCommands"
    33  
    34  var errCommandAborted = errors.New("command execution aborted")
    35  
    36  // RunCommandsArgs stores the arguments for a RunCommands call.
    37  type RunCommandsArgs struct {
    38  	// Commands is the arbitrary commands to execute on the unit
    39  	Commands string
    40  	// RelationId is the relation context to execute the commands in.
    41  	RelationId int
    42  	// RemoteUnitName is the remote unit for the relation context.
    43  	RemoteUnitName string
    44  	// RemoteUnitName is the remote unit for the relation context.
    45  	RemoteApplicationName string
    46  	// ForceRemoteUnit skips relation membership and existence validation.
    47  	ForceRemoteUnit bool
    48  	// UnitName is the unit for which the command is being run.
    49  	UnitName string
    50  	// Token is the unit token when run under CAAS environments for auth.
    51  	Token string
    52  	// Operator is true when the command should be run on the operator.
    53  	// This only affects k8s workload charms.
    54  	Operator bool
    55  }
    56  
    57  // A CommandRunner is something that will actually execute the commands and
    58  // return the results of that execution in the exec.ExecResponse (which
    59  // contains stdout, stderr, and return code).
    60  type CommandRunner interface {
    61  	RunCommands(RunCommandsArgs RunCommandsArgs) (results *exec.ExecResponse, err error)
    62  }
    63  
    64  // RunListener is responsible for listening on the network connection and
    65  // setting up the rpc server on that net connection. Also starts the go routine
    66  // that listens and hands off the work.
    67  type RunListener struct {
    68  	logger Logger
    69  
    70  	mu sync.Mutex
    71  
    72  	// commandRunners holds the CommandRunner that will run commands
    73  	// for each unit name.
    74  	commandRunners map[string]CommandRunner
    75  
    76  	listener net.Listener
    77  	server   *rpc.Server
    78  	closed   chan struct{}
    79  	closing  chan struct{}
    80  	wg       sync.WaitGroup
    81  
    82  	requiresAuth bool
    83  }
    84  
    85  // NewRunListener returns a new RunListener that is listening on given
    86  // socket or named pipe passed in. If a valid RunListener is returned, is
    87  // has the go routine running, and should be closed by the creator
    88  // when they are done with it.
    89  func NewRunListener(socket sockets.Socket, logger Logger) (*RunListener, error) {
    90  	listener, err := sockets.Listen(socket)
    91  	if err != nil {
    92  		return nil, errors.Trace(err)
    93  	}
    94  	runListener := &RunListener{
    95  		logger:         logger,
    96  		listener:       listener,
    97  		commandRunners: make(map[string]CommandRunner),
    98  		server:         rpc.NewServer(),
    99  		closed:         make(chan struct{}),
   100  		closing:        make(chan struct{}),
   101  	}
   102  	if socket.Network == "tcp" || socket.TLSConfig != nil {
   103  		runListener.requiresAuth = true
   104  	}
   105  	if err := runListener.server.Register(&JujuExecServer{runListener, logger}); err != nil {
   106  		return nil, errors.Trace(err)
   107  	}
   108  	// TODO (stickupkid) - We should probably log out when an accept fails, so
   109  	// we can at least track it.
   110  	go func() { _ = runListener.Run() }()
   111  	return runListener, nil
   112  }
   113  
   114  // Run accepts new connections until it encounters an error, or until Close is
   115  // called, and then blocks until all existing connections have been closed.
   116  func (r *RunListener) Run() (err error) {
   117  	r.logger.Debugf("juju-exec listener running")
   118  	var conn net.Conn
   119  	for {
   120  		conn, err = r.listener.Accept()
   121  		if err != nil {
   122  			break
   123  		}
   124  		r.wg.Add(1)
   125  		go func(conn net.Conn) {
   126  			r.server.ServeConn(conn)
   127  			r.wg.Done()
   128  		}(conn)
   129  	}
   130  	r.logger.Debugf("juju-exec listener stopping")
   131  	select {
   132  	case <-r.closing:
   133  		// Someone has called Close(), so it is overwhelmingly likely that
   134  		// the error from Accept is a direct result of the Listener being
   135  		// closed, and can therefore be safely ignored.
   136  		err = nil
   137  	default:
   138  	}
   139  	r.wg.Wait()
   140  	close(r.closed)
   141  	return
   142  }
   143  
   144  // Close immediately stops accepting connections, and blocks until all existing
   145  // connections have been closed.
   146  func (r *RunListener) Close() error {
   147  	defer func() {
   148  		<-r.closed
   149  		r.logger.Debugf("juju-exec listener stopped")
   150  	}()
   151  	close(r.closing)
   152  	return r.listener.Close()
   153  }
   154  
   155  // RegisterRunner registers a command runner for a given unit.
   156  func (r *RunListener) RegisterRunner(unitName string, runner CommandRunner) {
   157  	r.mu.Lock()
   158  	r.commandRunners[unitName] = runner
   159  	r.mu.Unlock()
   160  }
   161  
   162  // UnregisterRunner unregisters a command runner for a given unit.
   163  func (r *RunListener) UnregisterRunner(unitName string) {
   164  	r.mu.Lock()
   165  	delete(r.commandRunners, unitName)
   166  	r.mu.Unlock()
   167  }
   168  
   169  // RunCommands executes the supplied commands in a hook context.
   170  func (r *RunListener) RunCommands(args RunCommandsArgs) (results *exec.ExecResponse, err error) {
   171  	r.logger.Debugf("run commands on unit %v: %s", args.UnitName, args.Commands)
   172  	if args.UnitName == "" {
   173  		return nil, errors.New("missing unit name running command")
   174  	}
   175  	r.mu.Lock()
   176  	runner, ok := r.commandRunners[args.UnitName]
   177  	r.mu.Unlock()
   178  	if !ok {
   179  		return nil, errors.Errorf("no runner is registered for unit %v", args.UnitName)
   180  	}
   181  
   182  	if r.requiresAuth {
   183  		// TODO: Cache unit password
   184  		baseDir := agent.Dir(agentconfig.DataDir, names.NewUnitTag(args.UnitName))
   185  		infoFilePath := filepath.Join(baseDir, caas.OperatorClientInfoCacheFile)
   186  		d, err := os.ReadFile(infoFilePath)
   187  		if err != nil {
   188  			return nil, errors.Annotatef(err, "reading %s", infoFilePath)
   189  		}
   190  		op := caas.OperatorClientInfo{}
   191  		err = yaml.Unmarshal(d, &op)
   192  		if err != nil {
   193  			return nil, errors.Trace(err)
   194  		}
   195  		if args.Token != op.Token {
   196  			return nil, errors.Forbiddenf("unit token mismatch")
   197  		}
   198  	}
   199  
   200  	return runner.RunCommands(args)
   201  }
   202  
   203  // NewRunListenerWrapper returns a worker that will Close the supplied run
   204  // listener when the worker is killed. The Wait() method will never return
   205  // an error -- NewRunListener just drops the Run error on the floor and that's
   206  // not what I'm fixing here.
   207  func NewRunListenerWrapper(rl *RunListener, logger Logger) worker.Worker {
   208  	rlw := &runListenerWrapper{logger: logger, rl: rl}
   209  	rlw.tomb.Go(func() error {
   210  		defer rlw.tearDown()
   211  		<-rlw.tomb.Dying()
   212  		return nil
   213  	})
   214  	return rlw
   215  }
   216  
   217  type runListenerWrapper struct {
   218  	logger Logger
   219  	tomb   tomb.Tomb
   220  	rl     *RunListener
   221  }
   222  
   223  func (rlw *runListenerWrapper) tearDown() {
   224  	if err := rlw.rl.Close(); err != nil {
   225  		rlw.logger.Warningf("error closing runlistener: %v", err)
   226  	}
   227  }
   228  
   229  // Kill is part of the worker.Worker interface.
   230  func (rlw *runListenerWrapper) Kill() {
   231  	rlw.tomb.Kill(nil)
   232  }
   233  
   234  // Wait is part of the worker.Worker interface.
   235  func (rlw *runListenerWrapper) Wait() error {
   236  	return rlw.tomb.Wait()
   237  }
   238  
   239  // The JujuExecServer is the entity that has the methods that are called over
   240  // the rpc connection.
   241  type JujuExecServer struct {
   242  	runner CommandRunner
   243  	logger Logger
   244  }
   245  
   246  // RunCommands delegates the actual running to the runner and populates the
   247  // response structure.
   248  func (r *JujuExecServer) RunCommands(args RunCommandsArgs, result *exec.ExecResponse) error {
   249  	r.logger.Debugf("RunCommands: %+v", args)
   250  	runResult, err := r.runner.RunCommands(args)
   251  	if err != nil {
   252  		return errors.Annotate(err, "r.runner.RunCommands")
   253  	}
   254  	*result = *runResult
   255  	return err
   256  }
   257  
   258  // ChannelCommandRunnerConfig contains the configuration for a ChannelCommandRunner.
   259  type ChannelCommandRunnerConfig struct {
   260  	// Abort is a channel that will be closed when the runner should abort
   261  	// the execution of run commands.
   262  	Abort <-chan struct{}
   263  
   264  	// Commands is used to add commands received from the listener.
   265  	Commands runcommands.Commands
   266  
   267  	// CommandChannel will be sent the IDs of commands added to Commands.
   268  	CommandChannel chan<- string
   269  }
   270  
   271  func (cfg ChannelCommandRunnerConfig) Validate() error {
   272  	if cfg.Abort == nil {
   273  		return errors.NotValidf("Abort unspecified")
   274  	}
   275  	if cfg.Commands == nil {
   276  		return errors.NotValidf("Commands unspecified")
   277  	}
   278  	if cfg.CommandChannel == nil {
   279  		return errors.NotValidf("CommandChannel unspecified")
   280  	}
   281  	return nil
   282  }
   283  
   284  // ChannelCommandRunner is a CommandRunner that registers command
   285  // arguments in a runcommands.Commands, sends the returned IDs to
   286  // a channel and waits for response callbacks.
   287  type ChannelCommandRunner struct {
   288  	config ChannelCommandRunnerConfig
   289  }
   290  
   291  // NewChannelCommandRunner returns a new ChannelCommandRunner with the
   292  // given configuration.
   293  func NewChannelCommandRunner(cfg ChannelCommandRunnerConfig) (*ChannelCommandRunner, error) {
   294  	if err := cfg.Validate(); err != nil {
   295  		return nil, errors.Trace(err)
   296  	}
   297  	return &ChannelCommandRunner{cfg}, nil
   298  }
   299  
   300  // RunCommands executes the supplied run commands by registering the
   301  // arguments in a runcommands.Commands, and then sending the returned
   302  // ID to a channel and waiting for a response callback.
   303  func (c *ChannelCommandRunner) RunCommands(args RunCommandsArgs) (results *exec.ExecResponse, err error) {
   304  	runLocation := runner.Workload
   305  	if args.Operator {
   306  		runLocation = runner.Operator
   307  	}
   308  	operationArgs := operation.CommandArgs{
   309  		Commands:       args.Commands,
   310  		RelationId:     args.RelationId,
   311  		RemoteUnitName: args.RemoteUnitName,
   312  		// TODO(jam): 2019-10-24 Include RemoteAppName
   313  		ForceRemoteUnit: args.ForceRemoteUnit,
   314  		RunLocation:     runLocation,
   315  	}
   316  	if err := operationArgs.Validate(); err != nil {
   317  		return nil, errors.Trace(err)
   318  	}
   319  
   320  	type responseInfo struct {
   321  		response *exec.ExecResponse
   322  		err      error
   323  	}
   324  
   325  	// NOTE(axw) the response channel must be synchronous so that the
   326  	// response is received before the uniter resumes operation, and
   327  	// potentially aborts. This prevents a race when rebooting.
   328  	responseChan := make(chan responseInfo)
   329  	responseFunc := func(response *exec.ExecResponse, err error) bool {
   330  		select {
   331  		case <-c.config.Abort:
   332  			return false
   333  		case responseChan <- responseInfo{response, err}:
   334  			return true
   335  		}
   336  	}
   337  
   338  	id := c.config.Commands.AddCommand(operationArgs, responseFunc)
   339  	select {
   340  	case <-c.config.Abort:
   341  		return nil, errCommandAborted
   342  	case c.config.CommandChannel <- id:
   343  	}
   344  
   345  	select {
   346  	case <-c.config.Abort:
   347  		return nil, errCommandAborted
   348  	case response := <-responseChan:
   349  		return response.response, response.err
   350  	}
   351  }