github.com/niedbalski/juju@v0.0.0-20190215020005-8ff100488e47/worker/uniter/runner/jujuc/server.go (about)

     1  // Copyright 2012, 2013, 2014 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  // The worker/uniter/runner/jujuc package implements the server side of the
     5  // jujuc proxy tool, which forwards command invocations to the unit agent
     6  // process so that they can be executed against specific state.
     7  package jujuc
     8  
     9  import (
    10  	"bytes"
    11  	"fmt"
    12  	"io"
    13  	"net"
    14  	"net/rpc"
    15  	"os"
    16  	"path/filepath"
    17  	"sort"
    18  	"sync"
    19  
    20  	"github.com/juju/cmd"
    21  	"github.com/juju/errors"
    22  	"github.com/juju/loggo"
    23  	"github.com/juju/utils/exec"
    24  
    25  	jujucmd "github.com/juju/juju/cmd"
    26  	"github.com/juju/juju/juju/sockets"
    27  )
    28  
    29  // CmdSuffix is the filename suffix to use for executables.
    30  const CmdSuffix = cmdSuffix
    31  
    32  var logger = loggo.GetLogger("worker.uniter.jujuc")
    33  
    34  // ErrNoStdin is returned by Jujuc.Main if the hook tool requests
    35  // stdin, and none is supplied.
    36  var ErrNoStdin = errors.New("hook tool requires stdin, none supplied")
    37  
    38  type creator func(Context) (cmd.Command, error)
    39  
    40  var registeredCommands = map[string]creator{}
    41  
    42  func RegisterCommand(name string, f creator) {
    43  	registeredCommands[name+cmdSuffix] = f
    44  }
    45  
    46  // baseCommands maps Command names to creators.
    47  var baseCommands = map[string]creator{
    48  	"close-port" + cmdSuffix:              NewClosePortCommand,
    49  	"config-get" + cmdSuffix:              NewConfigGetCommand,
    50  	"juju-log" + cmdSuffix:                NewJujuLogCommand,
    51  	"open-port" + cmdSuffix:               NewOpenPortCommand,
    52  	"opened-ports" + cmdSuffix:            NewOpenedPortsCommand,
    53  	"relation-get" + cmdSuffix:            NewRelationGetCommand,
    54  	"action-get" + cmdSuffix:              NewActionGetCommand,
    55  	"action-set" + cmdSuffix:              NewActionSetCommand,
    56  	"action-fail" + cmdSuffix:             NewActionFailCommand,
    57  	"relation-ids" + cmdSuffix:            NewRelationIdsCommand,
    58  	"relation-list" + cmdSuffix:           NewRelationListCommand,
    59  	"relation-set" + cmdSuffix:            NewRelationSetCommand,
    60  	"unit-get" + cmdSuffix:                NewUnitGetCommand,
    61  	"add-metric" + cmdSuffix:              NewAddMetricCommand,
    62  	"juju-reboot" + cmdSuffix:             NewJujuRebootCommand,
    63  	"status-get" + cmdSuffix:              NewStatusGetCommand,
    64  	"status-set" + cmdSuffix:              NewStatusSetCommand,
    65  	"network-get" + cmdSuffix:             NewNetworkGetCommand,
    66  	"application-version-set" + cmdSuffix: NewApplicationVersionSetCommand,
    67  	"pod-spec-set" + cmdSuffix:            NewPodSpecSetCommand,
    68  	"goal-state" + cmdSuffix:              NewGoalStateCommand,
    69  	"credential-get" + cmdSuffix:          NewCredentialGetCommand,
    70  }
    71  
    72  var storageCommands = map[string]creator{
    73  	"storage-add" + cmdSuffix:  NewStorageAddCommand,
    74  	"storage-get" + cmdSuffix:  NewStorageGetCommand,
    75  	"storage-list" + cmdSuffix: NewStorageListCommand,
    76  }
    77  
    78  var leaderCommands = map[string]creator{
    79  	"is-leader" + cmdSuffix:  NewIsLeaderCommand,
    80  	"leader-get" + cmdSuffix: NewLeaderGetCommand,
    81  	"leader-set" + cmdSuffix: NewLeaderSetCommand,
    82  }
    83  
    84  func allEnabledCommands() map[string]creator {
    85  	all := map[string]creator{}
    86  	add := func(m map[string]creator) {
    87  		for k, v := range m {
    88  			all[k] = v
    89  		}
    90  	}
    91  	add(baseCommands)
    92  	add(storageCommands)
    93  	add(leaderCommands)
    94  	add(registeredCommands)
    95  	return all
    96  }
    97  
    98  // CommandNames returns the names of all jujuc commands.
    99  func CommandNames() (names []string) {
   100  	for name := range allEnabledCommands() {
   101  		names = append(names, name)
   102  	}
   103  	sort.Strings(names)
   104  	return
   105  }
   106  
   107  // NewCommand returns an instance of the named Command, initialized to execute
   108  // against the supplied Context.
   109  func NewCommand(ctx Context, name string) (cmd.Command, error) {
   110  	f := allEnabledCommands()[name]
   111  	if f == nil {
   112  		return nil, errors.Errorf("unknown command: %s", name)
   113  	}
   114  	command, err := f(ctx)
   115  	if err != nil {
   116  		return nil, errors.Trace(err)
   117  	}
   118  	return command, nil
   119  }
   120  
   121  // Request contains the information necessary to run a Command remotely.
   122  type Request struct {
   123  	ContextId   string
   124  	Dir         string
   125  	CommandName string
   126  	Args        []string
   127  
   128  	// StdinSet indicates whether or not the client supplied stdin. This is
   129  	// necessary as Stdin will be nil if the client supplied stdin but it
   130  	// is empty.
   131  	StdinSet bool
   132  	Stdin    []byte
   133  }
   134  
   135  // CmdGetter looks up a Command implementation connected to a particular Context.
   136  type CmdGetter func(contextId, cmdName string) (cmd.Command, error)
   137  
   138  // Jujuc implements the jujuc command in the form required by net/rpc.
   139  type Jujuc struct {
   140  	mu     sync.Mutex
   141  	getCmd CmdGetter
   142  }
   143  
   144  // badReqErrorf returns an error indicating a bad Request.
   145  func badReqErrorf(format string, v ...interface{}) error {
   146  	return fmt.Errorf("bad request: "+format, v...)
   147  }
   148  
   149  // Main runs the Command specified by req, and fills in resp. A single command
   150  // is run at a time.
   151  func (j *Jujuc) Main(req Request, resp *exec.ExecResponse) error {
   152  	if req.CommandName == "" {
   153  		return badReqErrorf("command not specified")
   154  	}
   155  	if !filepath.IsAbs(req.Dir) {
   156  		return badReqErrorf("Dir is not absolute")
   157  	}
   158  	c, err := j.getCmd(req.ContextId, req.CommandName)
   159  	if err != nil {
   160  		return badReqErrorf("%s", err)
   161  	}
   162  	var stdin io.Reader
   163  	if req.StdinSet {
   164  		stdin = bytes.NewReader(req.Stdin)
   165  	} else {
   166  		// noStdinReader will error with ErrNoStdin
   167  		// if its Read method is called.
   168  		stdin = noStdinReader{}
   169  	}
   170  	var stdout, stderr bytes.Buffer
   171  	ctx := &cmd.Context{
   172  		Dir:    req.Dir,
   173  		Stdin:  stdin,
   174  		Stdout: &stdout,
   175  		Stderr: &stderr,
   176  	}
   177  	j.mu.Lock()
   178  	defer j.mu.Unlock()
   179  	// Beware, reducing the log level of the following line will lead
   180  	// to passwords leaking if passed as args.
   181  	logger.Tracef("running hook tool %q %q", req.CommandName, req.Args)
   182  	logger.Debugf("running hook tool %q", req.CommandName)
   183  	logger.Tracef("hook context id %q; dir %q", req.ContextId, req.Dir)
   184  	wrapper := &cmdWrapper{c, nil}
   185  	resp.Code = cmd.Main(wrapper, ctx, req.Args)
   186  	if errors.Cause(wrapper.err) == ErrNoStdin {
   187  		return ErrNoStdin
   188  	}
   189  	resp.Stdout = stdout.Bytes()
   190  	resp.Stderr = stderr.Bytes()
   191  	return nil
   192  }
   193  
   194  // Server implements a server that serves command invocations via
   195  // a unix domain socket.
   196  type Server struct {
   197  	socketPath string
   198  	listener   net.Listener
   199  	server     *rpc.Server
   200  	closed     chan bool
   201  	closing    chan bool
   202  	wg         sync.WaitGroup
   203  }
   204  
   205  // NewServer creates an RPC server bound to socketPath, which can execute
   206  // remote command invocations against an appropriate Context. It will not
   207  // actually do so until Run is called.
   208  func NewServer(getCmd CmdGetter, socketPath string) (*Server, error) {
   209  	server := rpc.NewServer()
   210  	if err := server.Register(&Jujuc{getCmd: getCmd}); err != nil {
   211  		return nil, err
   212  	}
   213  	listener, err := sockets.Listen(socketPath)
   214  	if err != nil {
   215  		return nil, errors.Annotate(err, "listening to jujuc socket")
   216  	}
   217  	s := &Server{
   218  		socketPath: socketPath,
   219  		listener:   listener,
   220  		server:     server,
   221  		closed:     make(chan bool),
   222  		closing:    make(chan bool),
   223  	}
   224  	return s, nil
   225  }
   226  
   227  // Run accepts new connections until it encounters an error, or until Close is
   228  // called, and then blocks until all existing connections have been closed.
   229  func (s *Server) Run() (err error) {
   230  	var conn net.Conn
   231  	for {
   232  		conn, err = s.listener.Accept()
   233  		if err != nil {
   234  			break
   235  		}
   236  		s.wg.Add(1)
   237  		go func(conn net.Conn) {
   238  			s.server.ServeConn(conn)
   239  			s.wg.Done()
   240  		}(conn)
   241  	}
   242  	select {
   243  	case <-s.closing:
   244  		// Someone has called Close(), so it is overwhelmingly likely that
   245  		// the error from Accept is a direct result of the Listener being
   246  		// closed, and can therefore be safely ignored.
   247  		err = nil
   248  	default:
   249  	}
   250  	s.wg.Wait()
   251  	close(s.closed)
   252  	return
   253  }
   254  
   255  // Close immediately stops accepting connections, and blocks until all existing
   256  // connections have been closed.
   257  func (s *Server) Close() {
   258  	close(s.closing)
   259  	s.listener.Close()
   260  	// We need to remove the socket path because
   261  	// we renamed the path after opening the
   262  	// socket and it won't be cleaned up automatically.
   263  	// Ignore error as we can't do much here
   264  	// anyway and remove the path if we start the
   265  	// server again.
   266  	os.Remove(s.socketPath)
   267  	<-s.closed
   268  }
   269  
   270  type noStdinReader struct{}
   271  
   272  // Read implements io.Reader, simply returning ErrNoStdin any time it's called.
   273  func (noStdinReader) Read([]byte) (int, error) {
   274  	return 0, ErrNoStdin
   275  }
   276  
   277  // cmdWrapper wraps a cmd.Command's Run method so the error returned can be
   278  // intercepted when the command is run via cmd.Main.
   279  type cmdWrapper struct {
   280  	cmd.Command
   281  	err error
   282  }
   283  
   284  func (c *cmdWrapper) Run(ctx *cmd.Context) error {
   285  	c.err = c.Command.Run(ctx)
   286  	return c.err
   287  }
   288  
   289  func (c *cmdWrapper) Info() *cmd.Info {
   290  	return jujucmd.Info(c.Command.Info())
   291  }