github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/worker/uniter/runner/jujuc/server.go (about)

     1  // Copyright 2012, 2013, 2014 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package jujuc
     5  
import (
	"bytes"
	"crypto/subtle"
	"fmt"
	"io"
	"net"
	"net/rpc"
	"os"
	"path/filepath"
	"sort"
	"sync"

	"github.com/juju/cmd/v3"
	"github.com/juju/errors"
	"github.com/juju/loggo"
	"github.com/juju/utils/v3/exec"

	jujucmd "github.com/juju/juju/cmd"
	"github.com/juju/juju/juju/sockets"
)
    25  
    26  // This logger is fine being package level as this jujuc executable
    27  // is separate from the uniter in that it is run inside hooks.
    28  var logger = loggo.GetLogger("jujuc")
    29  
    30  // ErrNoStdin is returned by Jujuc.Main if the hook tool requests
    31  // stdin, and none is supplied.
    32  var ErrNoStdin = errors.New("hook tool requires stdin, none supplied")
    33  
    34  type creator func(Context) (cmd.Command, error)
    35  
    36  // baseCommands maps Command names to creators.
    37  var baseCommands = map[string]creator{
    38  	"close-port":              NewClosePortCommand,
    39  	"config-get":              NewConfigGetCommand,
    40  	"juju-log":                NewJujuLogCommand,
    41  	"open-port":               NewOpenPortCommand,
    42  	"opened-ports":            NewOpenedPortsCommand,
    43  	"relation-get":            NewRelationGetCommand,
    44  	"relation-ids":            NewRelationIdsCommand,
    45  	"relation-list":           NewRelationListCommand,
    46  	"relation-set":            NewRelationSetCommand,
    47  	"unit-get":                NewUnitGetCommand,
    48  	"add-metric":              NewAddMetricCommand,
    49  	"juju-reboot":             NewJujuRebootCommand,
    50  	"status-get":              NewStatusGetCommand,
    51  	"status-set":              NewStatusSetCommand,
    52  	"network-get":             NewNetworkGetCommand,
    53  	"application-version-set": NewApplicationVersionSetCommand,
    54  	"k8s-spec-set":            constructCommandCreator("k8s-spec-set", NewK8sSpecSetCommand),
    55  	"k8s-spec-get":            constructCommandCreator("k8s-spec-get", NewK8sSpecGetCommand),
    56  	"k8s-raw-set":             NewK8sRawSetCommand,
    57  	"k8s-raw-get":             NewK8sRawGetCommand,
    58  	// "pod" variants are deprecated.
    59  	"pod-spec-set": constructCommandCreator("pod-spec-set", NewK8sSpecSetCommand),
    60  	"pod-spec-get": constructCommandCreator("pod-spec-get", NewK8sSpecGetCommand),
    61  
    62  	"goal-state":     NewGoalStateCommand,
    63  	"credential-get": NewCredentialGetCommand,
    64  
    65  	"action-get":  NewActionGetCommand,
    66  	"action-set":  NewActionSetCommand,
    67  	"action-fail": NewActionFailCommand,
    68  	"action-log":  NewActionLogCommand,
    69  
    70  	"state-get":    NewStateGetCommand,
    71  	"state-delete": NewStateDeleteCommand,
    72  	"state-set":    NewStateSetCommand,
    73  }
    74  
    75  type functionCmdCreator func(Context, string) (cmd.Command, error)
    76  
    77  func constructCommandCreator(name string, newCmd functionCmdCreator) creator {
    78  	return func(ctx Context) (cmd.Command, error) {
    79  		return newCmd(ctx, name)
    80  	}
    81  }
    82  
    83  var secretCommands = map[string]creator{
    84  	"secret-add":      NewSecretAddCommand,
    85  	"secret-set":      NewSecretSetCommand,
    86  	"secret-remove":   NewSecretRemoveCommand,
    87  	"secret-get":      NewSecretGetCommand,
    88  	"secret-info-get": NewSecretInfoGetCommand,
    89  	"secret-grant":    NewSecretGrantCommand,
    90  	"secret-revoke":   NewSecretRevokeCommand,
    91  	"secret-ids":      NewSecretIdsCommand,
    92  }
    93  
    94  var storageCommands = map[string]creator{
    95  	"storage-add":  NewStorageAddCommand,
    96  	"storage-get":  NewStorageGetCommand,
    97  	"storage-list": NewStorageListCommand,
    98  }
    99  
   100  var leaderCommands = map[string]creator{
   101  	"is-leader":  NewIsLeaderCommand,
   102  	"leader-get": NewLeaderGetCommand,
   103  	"leader-set": NewLeaderSetCommand,
   104  }
   105  
   106  var resourceCommands = map[string]creator{
   107  	"resource-get": NewResourceGetCmd,
   108  }
   109  
   110  var payloadCommands = map[string]creator{
   111  	"payload-register":   NewPayloadRegisterCmd,
   112  	"payload-unregister": NewPayloadUnregisterCmd,
   113  	"payload-status-set": NewPayloadStatusSetCmd,
   114  }
   115  
   116  func allEnabledCommands() map[string]creator {
   117  	all := map[string]creator{}
   118  	add := func(m map[string]creator) {
   119  		for k, v := range m {
   120  			all[k] = v
   121  		}
   122  	}
   123  	add(baseCommands)
   124  	add(storageCommands)
   125  	add(leaderCommands)
   126  	add(resourceCommands)
   127  	add(payloadCommands)
   128  	add(secretCommands)
   129  	return all
   130  }
   131  
   132  // CommandNames returns the names of all jujuc commands.
   133  func CommandNames() (names []string) {
   134  	for name := range allEnabledCommands() {
   135  		names = append(names, name)
   136  	}
   137  	sort.Strings(names)
   138  	return
   139  }
   140  
   141  // NewCommand returns an instance of the named Command, initialized to execute
   142  // against the supplied Context.
   143  func NewCommand(ctx Context, name string) (cmd.Command, error) {
   144  	f := allEnabledCommands()[name]
   145  	if f == nil {
   146  		return nil, errors.Errorf("unknown command: %s", name)
   147  	}
   148  	command, err := f(ctx)
   149  	if err != nil {
   150  		return nil, errors.Trace(err)
   151  	}
   152  	return command, nil
   153  }
   154  
// Request contains the information necessary to run a Command remotely.
// It is the argument of the Jujuc.Main RPC method.
type Request struct {
	// ContextId identifies the context in which to run the command; Main
	// passes it to the CmdGetter together with CommandName.
	ContextId   string
	// Dir is the working directory for the command. Main rejects the
	// request unless it is an absolute path.
	Dir         string
	// CommandName is the name of the hook tool to run; it must not be empty.
	CommandName string
	// Args are the command-line arguments handed to the command via cmd.Main.
	Args        []string

	// StdinSet indicates whether or not the client supplied stdin. This is
	// necessary as Stdin will be nil if the client supplied stdin but it
	// is empty.
	StdinSet bool
	// Stdin holds the supplied stdin; it is only read when StdinSet is true.
	Stdin    []byte

	// Token must equal the token the server was created with, otherwise
	// Main rejects the request.
	Token string
}

// CmdGetter looks up a Command implementation connected to a particular Context.
// Main calls it with Request.ContextId and Request.CommandName; any error it
// returns is reported to the client as a bad request.
type CmdGetter func(contextId, cmdName string) (cmd.Command, error)

// Jujuc implements the jujuc command in the form required by net/rpc.
type Jujuc struct {
	// mu is held by Main while a command runs, so commands run one at a time.
	mu     sync.Mutex
	// getCmd resolves the command named in each Request.
	getCmd CmdGetter
	// token is the value every Request.Token must match.
	token  string
}
   180  
   181  // badReqErrorf returns an error indicating a bad Request.
   182  func badReqErrorf(format string, v ...interface{}) error {
   183  	return fmt.Errorf("bad request: "+format, v...)
   184  }
   185  
   186  // Main runs the Command specified by req, and fills in resp. A single command
   187  // is run at a time.
   188  func (j *Jujuc) Main(req Request, resp *exec.ExecResponse) error {
   189  	if req.Token != j.token {
   190  		return badReqErrorf("token does not match")
   191  	}
   192  	if req.CommandName == "" {
   193  		return badReqErrorf("command not specified")
   194  	}
   195  	if !filepath.IsAbs(req.Dir) {
   196  		return badReqErrorf("Dir is not absolute")
   197  	}
   198  	c, err := j.getCmd(req.ContextId, req.CommandName)
   199  	if err != nil {
   200  		return badReqErrorf("%s", err)
   201  	}
   202  	var stdin io.Reader
   203  	if req.StdinSet {
   204  		stdin = bytes.NewReader(req.Stdin)
   205  	} else {
   206  		// noStdinReader will error with ErrNoStdin
   207  		// if its Read method is called.
   208  		stdin = noStdinReader{}
   209  	}
   210  	var stdout, stderr bytes.Buffer
   211  	ctx := &cmd.Context{
   212  		Dir:    req.Dir,
   213  		Stdin:  stdin,
   214  		Stdout: &stdout,
   215  		Stderr: &stderr,
   216  	}
   217  	j.mu.Lock()
   218  	defer j.mu.Unlock()
   219  	// Beware, reducing the log level of the following line will lead
   220  	// to passwords leaking if passed as args.
   221  	logger.Tracef("running hook tool %q %q", req.CommandName, req.Args)
   222  	logger.Debugf("running hook tool %q for %s", req.CommandName, req.ContextId)
   223  	logger.Tracef("hook context id %q; dir %q", req.ContextId, req.Dir)
   224  	wrapper := &cmdWrapper{c, nil}
   225  	resp.Code = cmd.Main(wrapper, ctx, req.Args)
   226  	if errors.Cause(wrapper.err) == ErrNoStdin {
   227  		return ErrNoStdin
   228  	}
   229  	resp.Stdout = stdout.Bytes()
   230  	resp.Stderr = stderr.Bytes()
   231  	return nil
   232  }
   233  
// Server implements a server that serves command invocations via
// a unix domain socket.
type Server struct {
	// socket is what the server listens on; Close removes its Address path.
	socket   sockets.Socket
	// listener yields incoming connections to Run; Close closes it.
	listener net.Listener
	// server serves RPC calls on each accepted connection.
	server   *rpc.Server
	// closed is closed by Run just before it returns; Close waits on it.
	closed   chan bool
	// closing is closed by Close so Run knows the Accept error is expected.
	closing  chan bool
	// wg counts connections that Run is still serving.
	wg       sync.WaitGroup
}
   244  
   245  // NewServer creates an RPC server bound to socketPath, which can execute
   246  // remote command invocations against an appropriate Context. It will not
   247  // actually do so until Run is called.
   248  func NewServer(getCmd CmdGetter, socket sockets.Socket, token string) (*Server, error) {
   249  	server := rpc.NewServer()
   250  	if err := server.Register(&Jujuc{getCmd: getCmd, token: token}); err != nil {
   251  		return nil, err
   252  	}
   253  	listener, err := sockets.Listen(socket)
   254  	if err != nil {
   255  		return nil, errors.Annotate(err, "listening to jujuc socket")
   256  	}
   257  	s := &Server{
   258  		socket:   socket,
   259  		listener: listener,
   260  		server:   server,
   261  		closed:   make(chan bool),
   262  		closing:  make(chan bool),
   263  	}
   264  	return s, nil
   265  }
   266  
   267  // Run accepts new connections until it encounters an error, or until Close is
   268  // called, and then blocks until all existing connections have been closed.
   269  func (s *Server) Run() (err error) {
   270  	var conn net.Conn
   271  	for {
   272  		conn, err = s.listener.Accept()
   273  		if err != nil {
   274  			break
   275  		}
   276  		s.wg.Add(1)
   277  		go func(conn net.Conn) {
   278  			s.server.ServeConn(conn)
   279  			s.wg.Done()
   280  		}(conn)
   281  	}
   282  	select {
   283  	case <-s.closing:
   284  		// Someone has called Close(), so it is overwhelmingly likely that
   285  		// the error from Accept is a direct result of the Listener being
   286  		// closed, and can therefore be safely ignored.
   287  		err = nil
   288  	default:
   289  	}
   290  	s.wg.Wait()
   291  	close(s.closed)
   292  	return
   293  }
   294  
// Close immediately stops accepting connections, and blocks until all existing
// connections have been closed.
//
// NOTE(review): within this file s.closed is only closed by Run, so Close
// blocks forever if Run was never started. Calling Close a second time panics,
// because s.closing is closed unconditionally.
func (s *Server) Close() {
	// Signal Run before closing the listener, so Run treats the resulting
	// Accept error as expected rather than returning it.
	close(s.closing)
	s.listener.Close()
	// We need to remove the socket path because
	// we renamed the path after opening the
	// socket and it won't be cleaned up automatically.
	// Ignore error as we can't do much here
	// anyway and remove the path if we start the
	// server again.
	_ = os.Remove(s.socket.Address)
	// Wait for Run to finish serving existing connections and return.
	<-s.closed
}
   309  
   310  type noStdinReader struct{}
   311  
   312  // Read implements io.Reader, simply returning ErrNoStdin any time it's called.
   313  func (noStdinReader) Read([]byte) (int, error) {
   314  	return 0, ErrNoStdin
   315  }
   316  
   317  // cmdWrapper wraps a cmd.Command's Run method so the error returned can be
   318  // intercepted when the command is run via cmd.Main.
   319  type cmdWrapper struct {
   320  	cmd.Command
   321  	err error
   322  }
   323  
   324  func (c *cmdWrapper) Run(ctx *cmd.Context) error {
   325  	c.err = c.Command.Run(ctx)
   326  	return c.err
   327  }
   328  
   329  func (c *cmdWrapper) Info() *cmd.Info {
   330  	return jujucmd.Info(c.Command.Info())
   331  }