github.com/wallyworld/juju@v0.0.0-20161013125918-6cf1bc9d917a/worker/uniter/runner/jujuc/server.go (about)

     1  // Copyright 2012, 2013, 2014 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
// Package jujuc implements the server side of the jujuc proxy tool, which
// forwards command invocations to the unit agent process so that they can be
// executed against specific state.
     7  package jujuc
     8  
     9  import (
    10  	"bytes"
    11  	"fmt"
    12  	"io"
    13  	"net"
    14  	"net/rpc"
    15  	"path/filepath"
    16  	"sort"
    17  	"sync"
    18  
    19  	"github.com/juju/cmd"
    20  	"github.com/juju/errors"
    21  	"github.com/juju/loggo"
    22  	"github.com/juju/utils/exec"
    23  
    24  	"github.com/juju/juju/juju/sockets"
    25  )
    26  
    27  // CmdSuffix is the filename suffix to use for executables.
    28  const CmdSuffix = cmdSuffix
    29  
    30  var logger = loggo.GetLogger("worker.uniter.jujuc")
    31  
    32  // ErrNoStdin is returned by Jujuc.Main if the hook tool requests
    33  // stdin, and none is supplied.
    34  var ErrNoStdin = errors.New("hook tool requires stdin, none supplied")
    35  
    36  type creator func(Context) (cmd.Command, error)
    37  
    38  var registeredCommands = map[string]creator{}
    39  
    40  func RegisterCommand(name string, f creator) {
    41  	registeredCommands[name+cmdSuffix] = f
    42  }
    43  
    44  // baseCommands maps Command names to creators.
    45  var baseCommands = map[string]creator{
    46  	"close-port" + cmdSuffix:              NewClosePortCommand,
    47  	"config-get" + cmdSuffix:              NewConfigGetCommand,
    48  	"juju-log" + cmdSuffix:                NewJujuLogCommand,
    49  	"open-port" + cmdSuffix:               NewOpenPortCommand,
    50  	"opened-ports" + cmdSuffix:            NewOpenedPortsCommand,
    51  	"relation-get" + cmdSuffix:            NewRelationGetCommand,
    52  	"action-get" + cmdSuffix:              NewActionGetCommand,
    53  	"action-set" + cmdSuffix:              NewActionSetCommand,
    54  	"action-fail" + cmdSuffix:             NewActionFailCommand,
    55  	"relation-ids" + cmdSuffix:            NewRelationIdsCommand,
    56  	"relation-list" + cmdSuffix:           NewRelationListCommand,
    57  	"relation-set" + cmdSuffix:            NewRelationSetCommand,
    58  	"unit-get" + cmdSuffix:                NewUnitGetCommand,
    59  	"add-metric" + cmdSuffix:              NewAddMetricCommand,
    60  	"juju-reboot" + cmdSuffix:             NewJujuRebootCommand,
    61  	"status-get" + cmdSuffix:              NewStatusGetCommand,
    62  	"status-set" + cmdSuffix:              NewStatusSetCommand,
    63  	"network-get" + cmdSuffix:             NewNetworkGetCommand,
    64  	"application-version-set" + cmdSuffix: NewApplicationVersionSetCommand,
    65  }
    66  
    67  var storageCommands = map[string]creator{
    68  	"storage-add" + cmdSuffix:  NewStorageAddCommand,
    69  	"storage-get" + cmdSuffix:  NewStorageGetCommand,
    70  	"storage-list" + cmdSuffix: NewStorageListCommand,
    71  }
    72  
    73  var leaderCommands = map[string]creator{
    74  	"is-leader" + cmdSuffix:  NewIsLeaderCommand,
    75  	"leader-get" + cmdSuffix: NewLeaderGetCommand,
    76  	"leader-set" + cmdSuffix: NewLeaderSetCommand,
    77  }
    78  
    79  func allEnabledCommands() map[string]creator {
    80  	all := map[string]creator{}
    81  	add := func(m map[string]creator) {
    82  		for k, v := range m {
    83  			all[k] = v
    84  		}
    85  	}
    86  	add(baseCommands)
    87  	add(storageCommands)
    88  	add(leaderCommands)
    89  	add(registeredCommands)
    90  	return all
    91  }
    92  
    93  // CommandNames returns the names of all jujuc commands.
    94  func CommandNames() (names []string) {
    95  	for name := range allEnabledCommands() {
    96  		names = append(names, name)
    97  	}
    98  	sort.Strings(names)
    99  	return
   100  }
   101  
   102  // NewCommand returns an instance of the named Command, initialized to execute
   103  // against the supplied Context.
   104  func NewCommand(ctx Context, name string) (cmd.Command, error) {
   105  	f := allEnabledCommands()[name]
   106  	if f == nil {
   107  		return nil, errors.Errorf("unknown command: %s", name)
   108  	}
   109  	command, err := f(ctx)
   110  	if err != nil {
   111  		return nil, errors.Trace(err)
   112  	}
   113  	return command, nil
   114  }
   115  
// Request contains the information necessary to run a Command remotely.
// It is the argument received by Jujuc.Main.
type Request struct {
	// ContextId identifies the context the command runs against; Main
	// passes it, together with CommandName, to the CmdGetter.
	ContextId string

	// Dir is the working directory for the command. Main rejects the
	// request unless it is an absolute path.
	Dir string

	// CommandName names the hook tool to run. Main rejects an empty name.
	CommandName string

	// Args are the command-line arguments handed to the command.
	Args []string

	// StdinSet indicates whether or not the client supplied stdin. This is
	// necessary as Stdin will be nil if the client supplied stdin but it
	// is empty.
	StdinSet bool

	// Stdin holds the client-supplied stdin; Main only reads it when
	// StdinSet is true.
	Stdin []byte
}
   129  
// CmdGetter looks up a Command implementation connected to a particular Context.
// Jujuc.Main calls it with the request's ContextId and CommandName, and reports
// any error it returns to the client as a bad request.
type CmdGetter func(contextId, cmdName string) (cmd.Command, error)

// Jujuc implements the jujuc command in the form required by net/rpc.
// NewServer registers a *Jujuc with an rpc.Server, so its exported Main
// method is what remote clients invoke.
type Jujuc struct {
	mu     sync.Mutex // held by Main while a command runs, so only one runs at a time
	getCmd CmdGetter  // resolves the command to run for each request
}
   138  
   139  // badReqErrorf returns an error indicating a bad Request.
   140  func badReqErrorf(format string, v ...interface{}) error {
   141  	return fmt.Errorf("bad request: "+format, v...)
   142  }
   143  
   144  // Main runs the Command specified by req, and fills in resp. A single command
   145  // is run at a time.
   146  func (j *Jujuc) Main(req Request, resp *exec.ExecResponse) error {
   147  	if req.CommandName == "" {
   148  		return badReqErrorf("command not specified")
   149  	}
   150  	if !filepath.IsAbs(req.Dir) {
   151  		return badReqErrorf("Dir is not absolute")
   152  	}
   153  	c, err := j.getCmd(req.ContextId, req.CommandName)
   154  	if err != nil {
   155  		return badReqErrorf("%s", err)
   156  	}
   157  	var stdin io.Reader
   158  	if req.StdinSet {
   159  		stdin = bytes.NewReader(req.Stdin)
   160  	} else {
   161  		// noStdinReader will error with ErrNoStdin
   162  		// if its Read method is called.
   163  		stdin = noStdinReader{}
   164  	}
   165  	var stdout, stderr bytes.Buffer
   166  	ctx := &cmd.Context{
   167  		Dir:    req.Dir,
   168  		Stdin:  stdin,
   169  		Stdout: &stdout,
   170  		Stderr: &stderr,
   171  	}
   172  	j.mu.Lock()
   173  	defer j.mu.Unlock()
   174  	// Beware, reducing the log level of the following line will lead
   175  	// to passwords leaking if passed as args.
   176  	logger.Tracef("running hook tool %q %q", req.CommandName, req.Args)
   177  	logger.Tracef("running hook tool %q", req.CommandName)
   178  	logger.Debugf("hook context id %q; dir %q", req.ContextId, req.Dir)
   179  	wrapper := &cmdWrapper{c, nil}
   180  	resp.Code = cmd.Main(wrapper, ctx, req.Args)
   181  	if errors.Cause(wrapper.err) == ErrNoStdin {
   182  		return ErrNoStdin
   183  	}
   184  	resp.Stdout = stdout.Bytes()
   185  	resp.Stderr = stderr.Bytes()
   186  	return nil
   187  }
   188  
// Server implements a server that serves command invocations via
// a unix domain socket.
type Server struct {
	socketPath string         // path handed to sockets.Listen by NewServer
	listener   net.Listener   // accepts client connections; closed by Close
	server     *rpc.Server    // has a *Jujuc registered; serves each connection
	closed     chan bool      // closed by Run once it has completely finished
	closing    chan bool      // closed by Close so Run can ignore the resulting Accept error
	wg         sync.WaitGroup // counts connections still being served
}
   199  
   200  // NewServer creates an RPC server bound to socketPath, which can execute
   201  // remote command invocations against an appropriate Context. It will not
   202  // actually do so until Run is called.
   203  func NewServer(getCmd CmdGetter, socketPath string) (*Server, error) {
   204  	server := rpc.NewServer()
   205  	if err := server.Register(&Jujuc{getCmd: getCmd}); err != nil {
   206  		return nil, err
   207  	}
   208  	listener, err := sockets.Listen(socketPath)
   209  	if err != nil {
   210  		return nil, err
   211  	}
   212  	s := &Server{
   213  		socketPath: socketPath,
   214  		listener:   listener,
   215  		server:     server,
   216  		closed:     make(chan bool),
   217  		closing:    make(chan bool),
   218  	}
   219  	return s, nil
   220  }
   221  
   222  // Run accepts new connections until it encounters an error, or until Close is
   223  // called, and then blocks until all existing connections have been closed.
   224  func (s *Server) Run() (err error) {
   225  	var conn net.Conn
   226  	for {
   227  		conn, err = s.listener.Accept()
   228  		if err != nil {
   229  			break
   230  		}
   231  		s.wg.Add(1)
   232  		go func(conn net.Conn) {
   233  			s.server.ServeConn(conn)
   234  			s.wg.Done()
   235  		}(conn)
   236  	}
   237  	select {
   238  	case <-s.closing:
   239  		// Someone has called Close(), so it is overwhelmingly likely that
   240  		// the error from Accept is a direct result of the Listener being
   241  		// closed, and can therefore be safely ignored.
   242  		err = nil
   243  	default:
   244  	}
   245  	s.wg.Wait()
   246  	close(s.closed)
   247  	return
   248  }
   249  
   250  // Close immediately stops accepting connections, and blocks until all existing
   251  // connections have been closed.
   252  func (s *Server) Close() {
   253  	close(s.closing)
   254  	s.listener.Close()
   255  	<-s.closed
   256  }
   257  
   258  type noStdinReader struct{}
   259  
   260  // Read implements io.Reader, simply returning ErrNoStdin any time it's called.
   261  func (noStdinReader) Read([]byte) (int, error) {
   262  	return 0, ErrNoStdin
   263  }
   264  
   265  // cmdWrapper wraps a cmd.Command's Run method so the error returned can be
   266  // intercepted when the command is run via cmd.Main.
   267  type cmdWrapper struct {
   268  	cmd.Command
   269  	err error
   270  }
   271  
   272  func (c *cmdWrapper) Run(ctx *cmd.Context) error {
   273  	c.err = c.Command.Run(ctx)
   274  	return c.err
   275  }