github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/worker/uniter/runner/jujuc/server.go (about) 1 // Copyright 2012, 2013, 2014 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package jujuc 5 6 import ( 7 "bytes" 8 "fmt" 9 "io" 10 "net" 11 "net/rpc" 12 "os" 13 "path/filepath" 14 "sort" 15 "sync" 16 17 "github.com/juju/cmd/v3" 18 "github.com/juju/errors" 19 "github.com/juju/loggo" 20 "github.com/juju/utils/v3/exec" 21 22 jujucmd "github.com/juju/juju/cmd" 23 "github.com/juju/juju/juju/sockets" 24 ) 25 26 // This logger is fine being package level as this jujuc executable 27 // is separate from the uniter in that it is run inside hooks. 28 var logger = loggo.GetLogger("jujuc") 29 30 // ErrNoStdin is returned by Jujuc.Main if the hook tool requests 31 // stdin, and none is supplied. 32 var ErrNoStdin = errors.New("hook tool requires stdin, none supplied") 33 34 type creator func(Context) (cmd.Command, error) 35 36 // baseCommands maps Command names to creators. 37 var baseCommands = map[string]creator{ 38 "close-port": NewClosePortCommand, 39 "config-get": NewConfigGetCommand, 40 "juju-log": NewJujuLogCommand, 41 "open-port": NewOpenPortCommand, 42 "opened-ports": NewOpenedPortsCommand, 43 "relation-get": NewRelationGetCommand, 44 "relation-ids": NewRelationIdsCommand, 45 "relation-list": NewRelationListCommand, 46 "relation-set": NewRelationSetCommand, 47 "unit-get": NewUnitGetCommand, 48 "add-metric": NewAddMetricCommand, 49 "juju-reboot": NewJujuRebootCommand, 50 "status-get": NewStatusGetCommand, 51 "status-set": NewStatusSetCommand, 52 "network-get": NewNetworkGetCommand, 53 "application-version-set": NewApplicationVersionSetCommand, 54 "k8s-spec-set": constructCommandCreator("k8s-spec-set", NewK8sSpecSetCommand), 55 "k8s-spec-get": constructCommandCreator("k8s-spec-get", NewK8sSpecGetCommand), 56 "k8s-raw-set": NewK8sRawSetCommand, 57 "k8s-raw-get": NewK8sRawGetCommand, 58 // "pod" variants are deprecated. 59 "pod-spec-set": constructCommandCreator("pod-spec-set", NewK8sSpecSetCommand), 60 "pod-spec-get": constructCommandCreator("pod-spec-get", NewK8sSpecGetCommand), 61 62 "goal-state": NewGoalStateCommand, 63 "credential-get": NewCredentialGetCommand, 64 65 "action-get": NewActionGetCommand, 66 "action-set": NewActionSetCommand, 67 "action-fail": NewActionFailCommand, 68 "action-log": NewActionLogCommand, 69 70 "state-get": NewStateGetCommand, 71 "state-delete": NewStateDeleteCommand, 72 "state-set": NewStateSetCommand, 73 } 74 75 type functionCmdCreator func(Context, string) (cmd.Command, error) 76 77 func constructCommandCreator(name string, newCmd functionCmdCreator) creator { 78 return func(ctx Context) (cmd.Command, error) { 79 return newCmd(ctx, name) 80 } 81 } 82 83 var secretCommands = map[string]creator{ 84 "secret-add": NewSecretAddCommand, 85 "secret-set": NewSecretSetCommand, 86 "secret-remove": NewSecretRemoveCommand, 87 "secret-get": NewSecretGetCommand, 88 "secret-info-get": NewSecretInfoGetCommand, 89 "secret-grant": NewSecretGrantCommand, 90 "secret-revoke": NewSecretRevokeCommand, 91 "secret-ids": NewSecretIdsCommand, 92 } 93 94 var storageCommands = map[string]creator{ 95 "storage-add": NewStorageAddCommand, 96 "storage-get": NewStorageGetCommand, 97 "storage-list": NewStorageListCommand, 98 } 99 100 var leaderCommands = map[string]creator{ 101 "is-leader": NewIsLeaderCommand, 102 "leader-get": NewLeaderGetCommand, 103 "leader-set": NewLeaderSetCommand, 104 } 105 106 var resourceCommands = map[string]creator{ 107 "resource-get": NewResourceGetCmd, 108 } 109 110 var payloadCommands = map[string]creator{ 111 "payload-register": NewPayloadRegisterCmd, 112 "payload-unregister": NewPayloadUnregisterCmd, 113 "payload-status-set": NewPayloadStatusSetCmd, 114 } 115 116 func allEnabledCommands() map[string]creator { 117 all := map[string]creator{} 118 add := func(m map[string]creator) { 119 for k, v := range m { 120 all[k] = v 121 } 122 } 123 add(baseCommands) 124 add(storageCommands) 125 add(leaderCommands) 126 add(resourceCommands) 127 add(payloadCommands) 128 add(secretCommands) 129 return all 130 } 131 132 // CommandNames returns the names of all jujuc commands. 133 func CommandNames() (names []string) { 134 for name := range allEnabledCommands() { 135 names = append(names, name) 136 } 137 sort.Strings(names) 138 return 139 } 140 141 // NewCommand returns an instance of the named Command, initialized to execute 142 // against the supplied Context. 143 func NewCommand(ctx Context, name string) (cmd.Command, error) { 144 f := allEnabledCommands()[name] 145 if f == nil { 146 return nil, errors.Errorf("unknown command: %s", name) 147 } 148 command, err := f(ctx) 149 if err != nil { 150 return nil, errors.Trace(err) 151 } 152 return command, nil 153 } 154 155 // Request contains the information necessary to run a Command remotely. 156 type Request struct { 157 ContextId string 158 Dir string 159 CommandName string 160 Args []string 161 162 // StdinSet indicates whether or not the client supplied stdin. This is 163 // necessary as Stdin will be nil if the client supplied stdin but it 164 // is empty. 165 StdinSet bool 166 Stdin []byte 167 168 Token string 169 } 170 171 // CmdGetter looks up a Command implementation connected to a particular Context. 172 type CmdGetter func(contextId, cmdName string) (cmd.Command, error) 173 174 // Jujuc implements the jujuc command in the form required by net/rpc. 175 type Jujuc struct { 176 mu sync.Mutex 177 getCmd CmdGetter 178 token string 179 } 180 181 // badReqErrorf returns an error indicating a bad Request. 182 func badReqErrorf(format string, v ...interface{}) error { 183 return fmt.Errorf("bad request: "+format, v...) 184 } 185 186 // Main runs the Command specified by req, and fills in resp. A single command 187 // is run at a time. 188 func (j *Jujuc) Main(req Request, resp *exec.ExecResponse) error { 189 if req.Token != j.token { 190 return badReqErrorf("token does not match") 191 } 192 if req.CommandName == "" { 193 return badReqErrorf("command not specified") 194 } 195 if !filepath.IsAbs(req.Dir) { 196 return badReqErrorf("Dir is not absolute") 197 } 198 c, err := j.getCmd(req.ContextId, req.CommandName) 199 if err != nil { 200 return badReqErrorf("%s", err) 201 } 202 var stdin io.Reader 203 if req.StdinSet { 204 stdin = bytes.NewReader(req.Stdin) 205 } else { 206 // noStdinReader will error with ErrNoStdin 207 // if its Read method is called. 208 stdin = noStdinReader{} 209 } 210 var stdout, stderr bytes.Buffer 211 ctx := &cmd.Context{ 212 Dir: req.Dir, 213 Stdin: stdin, 214 Stdout: &stdout, 215 Stderr: &stderr, 216 } 217 j.mu.Lock() 218 defer j.mu.Unlock() 219 // Beware, reducing the log level of the following line will lead 220 // to passwords leaking if passed as args. 221 logger.Tracef("running hook tool %q %q", req.CommandName, req.Args) 222 logger.Debugf("running hook tool %q for %s", req.CommandName, req.ContextId) 223 logger.Tracef("hook context id %q; dir %q", req.ContextId, req.Dir) 224 wrapper := &cmdWrapper{c, nil} 225 resp.Code = cmd.Main(wrapper, ctx, req.Args) 226 if errors.Cause(wrapper.err) == ErrNoStdin { 227 return ErrNoStdin 228 } 229 resp.Stdout = stdout.Bytes() 230 resp.Stderr = stderr.Bytes() 231 return nil 232 } 233 234 // Server implements a server that serves command invocations via 235 // a unix domain socket. 236 type Server struct { 237 socket sockets.Socket 238 listener net.Listener 239 server *rpc.Server 240 closed chan bool 241 closing chan bool 242 wg sync.WaitGroup 243 } 244 245 // NewServer creates an RPC server bound to socketPath, which can execute 246 // remote command invocations against an appropriate Context. It will not 247 // actually do so until Run is called. 248 func NewServer(getCmd CmdGetter, socket sockets.Socket, token string) (*Server, error) { 249 server := rpc.NewServer() 250 if err := server.Register(&Jujuc{getCmd: getCmd, token: token}); err != nil { 251 return nil, err 252 } 253 listener, err := sockets.Listen(socket) 254 if err != nil { 255 return nil, errors.Annotate(err, "listening to jujuc socket") 256 } 257 s := &Server{ 258 socket: socket, 259 listener: listener, 260 server: server, 261 closed: make(chan bool), 262 closing: make(chan bool), 263 } 264 return s, nil 265 } 266 267 // Run accepts new connections until it encounters an error, or until Close is 268 // called, and then blocks until all existing connections have been closed. 269 func (s *Server) Run() (err error) { 270 var conn net.Conn 271 for { 272 conn, err = s.listener.Accept() 273 if err != nil { 274 break 275 } 276 s.wg.Add(1) 277 go func(conn net.Conn) { 278 s.server.ServeConn(conn) 279 s.wg.Done() 280 }(conn) 281 } 282 select { 283 case <-s.closing: 284 // Someone has called Close(), so it is overwhelmingly likely that 285 // the error from Accept is a direct result of the Listener being 286 // closed, and can therefore be safely ignored. 287 err = nil 288 default: 289 } 290 s.wg.Wait() 291 close(s.closed) 292 return 293 } 294 295 // Close immediately stops accepting connections, and blocks until all existing 296 // connections have been closed. 297 func (s *Server) Close() { 298 close(s.closing) 299 s.listener.Close() 300 // We need to remove the socket path because 301 // we renamed the path after opening the 302 // socket and it won't be cleaned up automatically. 303 // Ignore error as we can't do much here 304 // anyway and remove the path if we start the 305 // server again. 306 _ = os.Remove(s.socket.Address) 307 <-s.closed 308 } 309 310 type noStdinReader struct{} 311 312 // Read implements io.Reader, simply returning ErrNoStdin any time it's called. 313 func (noStdinReader) Read([]byte) (int, error) { 314 return 0, ErrNoStdin 315 } 316 317 // cmdWrapper wraps a cmd.Command's Run method so the error returned can be 318 // intercepted when the command is run via cmd.Main. 319 type cmdWrapper struct { 320 cmd.Command 321 err error 322 } 323 324 func (c *cmdWrapper) Run(ctx *cmd.Context) error { 325 c.err = c.Command.Run(ctx) 326 return c.err 327 } 328 329 func (c *cmdWrapper) Info() *cmd.Info { 330 return jujucmd.Info(c.Command.Info()) 331 }