github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/plugins/drivers/server.go (about) 1 package drivers 2 3 import ( 4 "context" 5 "fmt" 6 "io" 7 "math" 8 9 "github.com/golang/protobuf/ptypes" 10 "github.com/hashicorp/go-plugin" 11 "google.golang.org/grpc/codes" 12 "google.golang.org/grpc/status" 13 14 "github.com/hashicorp/nomad/nomad/structs" 15 "github.com/hashicorp/nomad/plugins/drivers/proto" 16 dstructs "github.com/hashicorp/nomad/plugins/shared/structs" 17 sproto "github.com/hashicorp/nomad/plugins/shared/structs/proto" 18 ) 19 20 type driverPluginServer struct { 21 broker *plugin.GRPCBroker 22 impl DriverPlugin 23 } 24 25 func (b *driverPluginServer) TaskConfigSchema(ctx context.Context, req *proto.TaskConfigSchemaRequest) (*proto.TaskConfigSchemaResponse, error) { 26 spec, err := b.impl.TaskConfigSchema() 27 if err != nil { 28 return nil, err 29 } 30 31 resp := &proto.TaskConfigSchemaResponse{ 32 Spec: spec, 33 } 34 return resp, nil 35 } 36 37 func (b *driverPluginServer) Capabilities(ctx context.Context, req *proto.CapabilitiesRequest) (*proto.CapabilitiesResponse, error) { 38 caps, err := b.impl.Capabilities() 39 if err != nil { 40 return nil, err 41 } 42 resp := &proto.CapabilitiesResponse{ 43 Capabilities: &proto.DriverCapabilities{ 44 SendSignals: caps.SendSignals, 45 Exec: caps.Exec, 46 MustCreateNetwork: caps.MustInitiateNetwork, 47 NetworkIsolationModes: []proto.NetworkIsolationSpec_NetworkIsolationMode{}, 48 RemoteTasks: caps.RemoteTasks, 49 }, 50 } 51 52 switch caps.FSIsolation { 53 case FSIsolationNone: 54 resp.Capabilities.FsIsolation = proto.DriverCapabilities_NONE 55 case FSIsolationChroot: 56 resp.Capabilities.FsIsolation = proto.DriverCapabilities_CHROOT 57 case FSIsolationImage: 58 resp.Capabilities.FsIsolation = proto.DriverCapabilities_IMAGE 59 default: 60 resp.Capabilities.FsIsolation = proto.DriverCapabilities_NONE 61 } 62 63 for _, mode := range caps.NetIsolationModes { 64 resp.Capabilities.NetworkIsolationModes = append(resp.Capabilities.NetworkIsolationModes, netIsolationModeToProto(mode)) 65 } 66 return resp, nil 67 } 68 69 func (b *driverPluginServer) Fingerprint(req *proto.FingerprintRequest, srv proto.Driver_FingerprintServer) error { 70 ctx := srv.Context() 71 ch, err := b.impl.Fingerprint(ctx) 72 if err != nil { 73 return err 74 } 75 76 for { 77 select { 78 case <-ctx.Done(): 79 return nil 80 case f, ok := <-ch: 81 82 if !ok { 83 return nil 84 } 85 resp := &proto.FingerprintResponse{ 86 Attributes: dstructs.ConvertStructAttributeMap(f.Attributes), 87 Health: healthStateToProto(f.Health), 88 HealthDescription: f.HealthDescription, 89 } 90 91 if err := srv.Send(resp); err != nil { 92 return err 93 } 94 } 95 } 96 } 97 98 func (b *driverPluginServer) RecoverTask(ctx context.Context, req *proto.RecoverTaskRequest) (*proto.RecoverTaskResponse, error) { 99 err := b.impl.RecoverTask(taskHandleFromProto(req.Handle)) 100 if err != nil { 101 return nil, err 102 } 103 104 return &proto.RecoverTaskResponse{}, nil 105 } 106 107 func (b *driverPluginServer) StartTask(ctx context.Context, req *proto.StartTaskRequest) (*proto.StartTaskResponse, error) { 108 handle, net, err := b.impl.StartTask(taskConfigFromProto(req.Task)) 109 if err != nil { 110 if rec, ok := err.(structs.Recoverable); ok { 111 st := status.New(codes.FailedPrecondition, rec.Error()) 112 st, err := st.WithDetails(&sproto.RecoverableError{Recoverable: rec.IsRecoverable()}) 113 if err != nil { 114 // If this error, it will always error 115 panic(err) 116 } 117 return nil, st.Err() 118 } 119 return nil, err 120 } 121 122 var pbNet *proto.NetworkOverride 123 if net != nil { 124 pbNet = &proto.NetworkOverride{ 125 PortMap: map[string]int32{}, 126 Addr: net.IP, 127 AutoAdvertise: net.AutoAdvertise, 128 } 129 for k, v := range net.PortMap { 130 if v > math.MaxInt32 { 131 return nil, fmt.Errorf("port map out of bounds") 132 } 133 pbNet.PortMap[k] = int32(v) 134 } 135 } 136 137 resp := &proto.StartTaskResponse{ 138 Handle: taskHandleToProto(handle), 139 NetworkOverride: pbNet, 140 } 141 142 return resp, nil 143 } 144 145 func (b *driverPluginServer) WaitTask(ctx context.Context, req *proto.WaitTaskRequest) (*proto.WaitTaskResponse, error) { 146 ch, err := b.impl.WaitTask(ctx, req.TaskId) 147 if err != nil { 148 return nil, err 149 } 150 151 var ok bool 152 var result *ExitResult 153 select { 154 case <-ctx.Done(): 155 return nil, ctx.Err() 156 case result, ok = <-ch: 157 if !ok { 158 return &proto.WaitTaskResponse{ 159 Err: "channel closed", 160 }, nil 161 } 162 } 163 164 var errStr string 165 if result.Err != nil { 166 errStr = result.Err.Error() 167 } 168 169 resp := &proto.WaitTaskResponse{ 170 Err: errStr, 171 Result: &proto.ExitResult{ 172 ExitCode: int32(result.ExitCode), 173 Signal: int32(result.Signal), 174 OomKilled: result.OOMKilled, 175 }, 176 } 177 178 return resp, nil 179 } 180 181 func (b *driverPluginServer) StopTask(ctx context.Context, req *proto.StopTaskRequest) (*proto.StopTaskResponse, error) { 182 timeout, err := ptypes.Duration(req.Timeout) 183 if err != nil { 184 return nil, err 185 } 186 187 err = b.impl.StopTask(req.TaskId, timeout, req.Signal) 188 if err != nil { 189 return nil, err 190 } 191 return &proto.StopTaskResponse{}, nil 192 } 193 194 func (b *driverPluginServer) DestroyTask(ctx context.Context, req *proto.DestroyTaskRequest) (*proto.DestroyTaskResponse, error) { 195 err := b.impl.DestroyTask(req.TaskId, req.Force) 196 if err != nil { 197 return nil, err 198 } 199 return &proto.DestroyTaskResponse{}, nil 200 } 201 202 func (b *driverPluginServer) InspectTask(ctx context.Context, req *proto.InspectTaskRequest) (*proto.InspectTaskResponse, error) { 203 status, err := b.impl.InspectTask(req.TaskId) 204 if err != nil { 205 return nil, err 206 } 207 208 protoStatus, err := taskStatusToProto(status) 209 if err != nil { 210 return nil, err 211 } 212 213 var pbNet *proto.NetworkOverride 214 if status.NetworkOverride != nil { 215 pbNet = &proto.NetworkOverride{ 216 PortMap: map[string]int32{}, 217 Addr: status.NetworkOverride.IP, 218 AutoAdvertise: status.NetworkOverride.AutoAdvertise, 219 } 220 for k, v := range status.NetworkOverride.PortMap { 221 pbNet.PortMap[k] = int32(v) 222 } 223 } 224 225 resp := &proto.InspectTaskResponse{ 226 Task: protoStatus, 227 Driver: &proto.TaskDriverStatus{ 228 Attributes: status.DriverAttributes, 229 }, 230 NetworkOverride: pbNet, 231 } 232 233 return resp, nil 234 } 235 236 func (b *driverPluginServer) TaskStats(req *proto.TaskStatsRequest, srv proto.Driver_TaskStatsServer) error { 237 interval, err := ptypes.Duration(req.CollectionInterval) 238 if err != nil { 239 return fmt.Errorf("failed to parse collection interval: %v", err) 240 } 241 242 ch, err := b.impl.TaskStats(srv.Context(), req.TaskId, interval) 243 if err != nil { 244 if rec, ok := err.(structs.Recoverable); ok { 245 st := status.New(codes.FailedPrecondition, rec.Error()) 246 st, err := st.WithDetails(&sproto.RecoverableError{Recoverable: rec.IsRecoverable()}) 247 if err != nil { 248 // If this error, it will always error 249 panic(err) 250 } 251 return st.Err() 252 } 253 return err 254 } 255 256 for stats := range ch { 257 pb, err := TaskStatsToProto(stats) 258 if err != nil { 259 return fmt.Errorf("failed to encode task stats: %v", err) 260 } 261 262 if err = srv.Send(&proto.TaskStatsResponse{Stats: pb}); err == io.EOF { 263 break 264 } else if err != nil { 265 return err 266 } 267 268 } 269 270 return nil 271 } 272 273 func (b *driverPluginServer) ExecTask(ctx context.Context, req *proto.ExecTaskRequest) (*proto.ExecTaskResponse, error) { 274 timeout, err := ptypes.Duration(req.Timeout) 275 if err != nil { 276 return nil, err 277 } 278 279 result, err := b.impl.ExecTask(req.TaskId, req.Command, timeout) 280 if err != nil { 281 return nil, err 282 } 283 resp := &proto.ExecTaskResponse{ 284 Stdout: result.Stdout, 285 Stderr: result.Stderr, 286 Result: exitResultToProto(result.ExitResult), 287 } 288 289 return resp, nil 290 } 291 292 func (b *driverPluginServer) ExecTaskStreaming(server proto.Driver_ExecTaskStreamingServer) error { 293 msg, err := server.Recv() 294 if err != nil { 295 return fmt.Errorf("failed to receive initial message: %v", err) 296 } 297 298 if msg.Setup == nil { 299 return fmt.Errorf("first message should always be setup") 300 } 301 302 if impl, ok := b.impl.(ExecTaskStreamingRawDriver); ok { 303 return impl.ExecTaskStreamingRaw(server.Context(), 304 msg.Setup.TaskId, msg.Setup.Command, msg.Setup.Tty, 305 server) 306 } 307 308 d, ok := b.impl.(ExecTaskStreamingDriver) 309 if !ok { 310 return fmt.Errorf("driver does not support exec") 311 } 312 313 execOpts, errCh := StreamToExecOptions(server.Context(), 314 msg.Setup.Command, msg.Setup.Tty, 315 server) 316 317 result, err := d.ExecTaskStreaming(server.Context(), 318 msg.Setup.TaskId, execOpts) 319 320 execOpts.Stdout.Close() 321 execOpts.Stderr.Close() 322 323 if err != nil { 324 return err 325 } 326 327 // wait for copy to be done 328 select { 329 case err = <-errCh: 330 case <-server.Context().Done(): 331 err = fmt.Errorf("exec timed out: %v", server.Context().Err()) 332 } 333 334 if err != nil { 335 return err 336 } 337 338 server.Send(&ExecTaskStreamingResponseMsg{ 339 Exited: true, 340 Result: exitResultToProto(result), 341 }) 342 343 return err 344 } 345 346 func (b *driverPluginServer) SignalTask(ctx context.Context, req *proto.SignalTaskRequest) (*proto.SignalTaskResponse, error) { 347 err := b.impl.SignalTask(req.TaskId, req.Signal) 348 if err != nil { 349 return nil, err 350 } 351 352 resp := &proto.SignalTaskResponse{} 353 return resp, nil 354 } 355 356 func (b *driverPluginServer) TaskEvents(req *proto.TaskEventsRequest, srv proto.Driver_TaskEventsServer) error { 357 ch, err := b.impl.TaskEvents(srv.Context()) 358 if err != nil { 359 return err 360 } 361 362 for { 363 event := <-ch 364 if event == nil { 365 break 366 } 367 pbTimestamp, err := ptypes.TimestampProto(event.Timestamp) 368 if err != nil { 369 return err 370 } 371 372 pbEvent := &proto.DriverTaskEvent{ 373 TaskId: event.TaskID, 374 AllocId: event.AllocID, 375 TaskName: event.TaskName, 376 Timestamp: pbTimestamp, 377 Message: event.Message, 378 Annotations: event.Annotations, 379 } 380 381 if err = srv.Send(pbEvent); err == io.EOF { 382 break 383 } else if err != nil { 384 return err 385 } 386 } 387 return nil 388 } 389 390 func (b *driverPluginServer) CreateNetwork(ctx context.Context, req *proto.CreateNetworkRequest) (*proto.CreateNetworkResponse, error) { 391 nm, ok := b.impl.(DriverNetworkManager) 392 if !ok { 393 return nil, fmt.Errorf("CreateNetwork RPC not supported by driver") 394 } 395 396 spec, created, err := nm.CreateNetwork(req.GetAllocId(), networkCreateRequestFromProto(req)) 397 if err != nil { 398 return nil, err 399 } 400 401 return &proto.CreateNetworkResponse{ 402 IsolationSpec: NetworkIsolationSpecToProto(spec), 403 Created: created, 404 }, nil 405 } 406 407 func (b *driverPluginServer) DestroyNetwork(ctx context.Context, req *proto.DestroyNetworkRequest) (*proto.DestroyNetworkResponse, error) { 408 nm, ok := b.impl.(DriverNetworkManager) 409 if !ok { 410 return nil, fmt.Errorf("DestroyNetwork RPC not supported by driver") 411 } 412 413 err := nm.DestroyNetwork(req.AllocId, NetworkIsolationSpecFromProto(req.IsolationSpec)) 414 if err != nil { 415 return nil, err 416 } 417 418 return &proto.DestroyNetworkResponse{}, nil 419 }