github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/plugins/drivers/server.go (about)

     1  package drivers
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"math"
     8  
     9  	"github.com/golang/protobuf/ptypes"
    10  	"github.com/hashicorp/go-plugin"
    11  	"google.golang.org/grpc/codes"
    12  	"google.golang.org/grpc/status"
    13  
    14  	"github.com/hashicorp/nomad/nomad/structs"
    15  	"github.com/hashicorp/nomad/plugins/drivers/proto"
    16  	dstructs "github.com/hashicorp/nomad/plugins/shared/structs"
    17  	sproto "github.com/hashicorp/nomad/plugins/shared/structs/proto"
    18  )
    19  
    20  type driverPluginServer struct {
    21  	broker *plugin.GRPCBroker
    22  	impl   DriverPlugin
    23  }
    24  
    25  func (b *driverPluginServer) TaskConfigSchema(ctx context.Context, req *proto.TaskConfigSchemaRequest) (*proto.TaskConfigSchemaResponse, error) {
    26  	spec, err := b.impl.TaskConfigSchema()
    27  	if err != nil {
    28  		return nil, err
    29  	}
    30  
    31  	resp := &proto.TaskConfigSchemaResponse{
    32  		Spec: spec,
    33  	}
    34  	return resp, nil
    35  }
    36  
    37  func (b *driverPluginServer) Capabilities(ctx context.Context, req *proto.CapabilitiesRequest) (*proto.CapabilitiesResponse, error) {
    38  	caps, err := b.impl.Capabilities()
    39  	if err != nil {
    40  		return nil, err
    41  	}
    42  	resp := &proto.CapabilitiesResponse{
    43  		Capabilities: &proto.DriverCapabilities{
    44  			SendSignals:           caps.SendSignals,
    45  			Exec:                  caps.Exec,
    46  			MustCreateNetwork:     caps.MustInitiateNetwork,
    47  			NetworkIsolationModes: []proto.NetworkIsolationSpec_NetworkIsolationMode{},
    48  			RemoteTasks:           caps.RemoteTasks,
    49  		},
    50  	}
    51  
    52  	switch caps.FSIsolation {
    53  	case FSIsolationNone:
    54  		resp.Capabilities.FsIsolation = proto.DriverCapabilities_NONE
    55  	case FSIsolationChroot:
    56  		resp.Capabilities.FsIsolation = proto.DriverCapabilities_CHROOT
    57  	case FSIsolationImage:
    58  		resp.Capabilities.FsIsolation = proto.DriverCapabilities_IMAGE
    59  	default:
    60  		resp.Capabilities.FsIsolation = proto.DriverCapabilities_NONE
    61  	}
    62  
    63  	for _, mode := range caps.NetIsolationModes {
    64  		resp.Capabilities.NetworkIsolationModes = append(resp.Capabilities.NetworkIsolationModes, netIsolationModeToProto(mode))
    65  	}
    66  	return resp, nil
    67  }
    68  
    69  func (b *driverPluginServer) Fingerprint(req *proto.FingerprintRequest, srv proto.Driver_FingerprintServer) error {
    70  	ctx := srv.Context()
    71  	ch, err := b.impl.Fingerprint(ctx)
    72  	if err != nil {
    73  		return err
    74  	}
    75  
    76  	for {
    77  		select {
    78  		case <-ctx.Done():
    79  			return nil
    80  		case f, ok := <-ch:
    81  
    82  			if !ok {
    83  				return nil
    84  			}
    85  			resp := &proto.FingerprintResponse{
    86  				Attributes:        dstructs.ConvertStructAttributeMap(f.Attributes),
    87  				Health:            healthStateToProto(f.Health),
    88  				HealthDescription: f.HealthDescription,
    89  			}
    90  
    91  			if err := srv.Send(resp); err != nil {
    92  				return err
    93  			}
    94  		}
    95  	}
    96  }
    97  
    98  func (b *driverPluginServer) RecoverTask(ctx context.Context, req *proto.RecoverTaskRequest) (*proto.RecoverTaskResponse, error) {
    99  	err := b.impl.RecoverTask(taskHandleFromProto(req.Handle))
   100  	if err != nil {
   101  		return nil, err
   102  	}
   103  
   104  	return &proto.RecoverTaskResponse{}, nil
   105  }
   106  
   107  func (b *driverPluginServer) StartTask(ctx context.Context, req *proto.StartTaskRequest) (*proto.StartTaskResponse, error) {
   108  	handle, net, err := b.impl.StartTask(taskConfigFromProto(req.Task))
   109  	if err != nil {
   110  		if rec, ok := err.(structs.Recoverable); ok {
   111  			st := status.New(codes.FailedPrecondition, rec.Error())
   112  			st, err := st.WithDetails(&sproto.RecoverableError{Recoverable: rec.IsRecoverable()})
   113  			if err != nil {
   114  				// If this error, it will always error
   115  				panic(err)
   116  			}
   117  			return nil, st.Err()
   118  		}
   119  		return nil, err
   120  	}
   121  
   122  	var pbNet *proto.NetworkOverride
   123  	if net != nil {
   124  		pbNet = &proto.NetworkOverride{
   125  			PortMap:       map[string]int32{},
   126  			Addr:          net.IP,
   127  			AutoAdvertise: net.AutoAdvertise,
   128  		}
   129  		for k, v := range net.PortMap {
   130  			if v > math.MaxInt32 {
   131  				return nil, fmt.Errorf("port map out of bounds")
   132  			}
   133  			pbNet.PortMap[k] = int32(v)
   134  		}
   135  	}
   136  
   137  	resp := &proto.StartTaskResponse{
   138  		Handle:          taskHandleToProto(handle),
   139  		NetworkOverride: pbNet,
   140  	}
   141  
   142  	return resp, nil
   143  }
   144  
   145  func (b *driverPluginServer) WaitTask(ctx context.Context, req *proto.WaitTaskRequest) (*proto.WaitTaskResponse, error) {
   146  	ch, err := b.impl.WaitTask(ctx, req.TaskId)
   147  	if err != nil {
   148  		return nil, err
   149  	}
   150  
   151  	var ok bool
   152  	var result *ExitResult
   153  	select {
   154  	case <-ctx.Done():
   155  		return nil, ctx.Err()
   156  	case result, ok = <-ch:
   157  		if !ok {
   158  			return &proto.WaitTaskResponse{
   159  				Err: "channel closed",
   160  			}, nil
   161  		}
   162  	}
   163  
   164  	var errStr string
   165  	if result.Err != nil {
   166  		errStr = result.Err.Error()
   167  	}
   168  
   169  	resp := &proto.WaitTaskResponse{
   170  		Err: errStr,
   171  		Result: &proto.ExitResult{
   172  			ExitCode:  int32(result.ExitCode),
   173  			Signal:    int32(result.Signal),
   174  			OomKilled: result.OOMKilled,
   175  		},
   176  	}
   177  
   178  	return resp, nil
   179  }
   180  
   181  func (b *driverPluginServer) StopTask(ctx context.Context, req *proto.StopTaskRequest) (*proto.StopTaskResponse, error) {
   182  	timeout, err := ptypes.Duration(req.Timeout)
   183  	if err != nil {
   184  		return nil, err
   185  	}
   186  
   187  	err = b.impl.StopTask(req.TaskId, timeout, req.Signal)
   188  	if err != nil {
   189  		return nil, err
   190  	}
   191  	return &proto.StopTaskResponse{}, nil
   192  }
   193  
   194  func (b *driverPluginServer) DestroyTask(ctx context.Context, req *proto.DestroyTaskRequest) (*proto.DestroyTaskResponse, error) {
   195  	err := b.impl.DestroyTask(req.TaskId, req.Force)
   196  	if err != nil {
   197  		return nil, err
   198  	}
   199  	return &proto.DestroyTaskResponse{}, nil
   200  }
   201  
   202  func (b *driverPluginServer) InspectTask(ctx context.Context, req *proto.InspectTaskRequest) (*proto.InspectTaskResponse, error) {
   203  	status, err := b.impl.InspectTask(req.TaskId)
   204  	if err != nil {
   205  		return nil, err
   206  	}
   207  
   208  	protoStatus, err := taskStatusToProto(status)
   209  	if err != nil {
   210  		return nil, err
   211  	}
   212  
   213  	var pbNet *proto.NetworkOverride
   214  	if status.NetworkOverride != nil {
   215  		pbNet = &proto.NetworkOverride{
   216  			PortMap:       map[string]int32{},
   217  			Addr:          status.NetworkOverride.IP,
   218  			AutoAdvertise: status.NetworkOverride.AutoAdvertise,
   219  		}
   220  		for k, v := range status.NetworkOverride.PortMap {
   221  			pbNet.PortMap[k] = int32(v)
   222  		}
   223  	}
   224  
   225  	resp := &proto.InspectTaskResponse{
   226  		Task: protoStatus,
   227  		Driver: &proto.TaskDriverStatus{
   228  			Attributes: status.DriverAttributes,
   229  		},
   230  		NetworkOverride: pbNet,
   231  	}
   232  
   233  	return resp, nil
   234  }
   235  
   236  func (b *driverPluginServer) TaskStats(req *proto.TaskStatsRequest, srv proto.Driver_TaskStatsServer) error {
   237  	interval, err := ptypes.Duration(req.CollectionInterval)
   238  	if err != nil {
   239  		return fmt.Errorf("failed to parse collection interval: %v", err)
   240  	}
   241  
   242  	ch, err := b.impl.TaskStats(srv.Context(), req.TaskId, interval)
   243  	if err != nil {
   244  		if rec, ok := err.(structs.Recoverable); ok {
   245  			st := status.New(codes.FailedPrecondition, rec.Error())
   246  			st, err := st.WithDetails(&sproto.RecoverableError{Recoverable: rec.IsRecoverable()})
   247  			if err != nil {
   248  				// If this error, it will always error
   249  				panic(err)
   250  			}
   251  			return st.Err()
   252  		}
   253  		return err
   254  	}
   255  
   256  	for stats := range ch {
   257  		pb, err := TaskStatsToProto(stats)
   258  		if err != nil {
   259  			return fmt.Errorf("failed to encode task stats: %v", err)
   260  		}
   261  
   262  		if err = srv.Send(&proto.TaskStatsResponse{Stats: pb}); err == io.EOF {
   263  			break
   264  		} else if err != nil {
   265  			return err
   266  		}
   267  
   268  	}
   269  
   270  	return nil
   271  }
   272  
   273  func (b *driverPluginServer) ExecTask(ctx context.Context, req *proto.ExecTaskRequest) (*proto.ExecTaskResponse, error) {
   274  	timeout, err := ptypes.Duration(req.Timeout)
   275  	if err != nil {
   276  		return nil, err
   277  	}
   278  
   279  	result, err := b.impl.ExecTask(req.TaskId, req.Command, timeout)
   280  	if err != nil {
   281  		return nil, err
   282  	}
   283  	resp := &proto.ExecTaskResponse{
   284  		Stdout: result.Stdout,
   285  		Stderr: result.Stderr,
   286  		Result: exitResultToProto(result.ExitResult),
   287  	}
   288  
   289  	return resp, nil
   290  }
   291  
   292  func (b *driverPluginServer) ExecTaskStreaming(server proto.Driver_ExecTaskStreamingServer) error {
   293  	msg, err := server.Recv()
   294  	if err != nil {
   295  		return fmt.Errorf("failed to receive initial message: %v", err)
   296  	}
   297  
   298  	if msg.Setup == nil {
   299  		return fmt.Errorf("first message should always be setup")
   300  	}
   301  
   302  	if impl, ok := b.impl.(ExecTaskStreamingRawDriver); ok {
   303  		return impl.ExecTaskStreamingRaw(server.Context(),
   304  			msg.Setup.TaskId, msg.Setup.Command, msg.Setup.Tty,
   305  			server)
   306  	}
   307  
   308  	d, ok := b.impl.(ExecTaskStreamingDriver)
   309  	if !ok {
   310  		return fmt.Errorf("driver does not support exec")
   311  	}
   312  
   313  	execOpts, errCh := StreamToExecOptions(server.Context(),
   314  		msg.Setup.Command, msg.Setup.Tty,
   315  		server)
   316  
   317  	result, err := d.ExecTaskStreaming(server.Context(),
   318  		msg.Setup.TaskId, execOpts)
   319  
   320  	execOpts.Stdout.Close()
   321  	execOpts.Stderr.Close()
   322  
   323  	if err != nil {
   324  		return err
   325  	}
   326  
   327  	// wait for copy to be done
   328  	select {
   329  	case err = <-errCh:
   330  	case <-server.Context().Done():
   331  		err = fmt.Errorf("exec timed out: %v", server.Context().Err())
   332  	}
   333  
   334  	if err != nil {
   335  		return err
   336  	}
   337  
   338  	server.Send(&ExecTaskStreamingResponseMsg{
   339  		Exited: true,
   340  		Result: exitResultToProto(result),
   341  	})
   342  
   343  	return err
   344  }
   345  
   346  func (b *driverPluginServer) SignalTask(ctx context.Context, req *proto.SignalTaskRequest) (*proto.SignalTaskResponse, error) {
   347  	err := b.impl.SignalTask(req.TaskId, req.Signal)
   348  	if err != nil {
   349  		return nil, err
   350  	}
   351  
   352  	resp := &proto.SignalTaskResponse{}
   353  	return resp, nil
   354  }
   355  
   356  func (b *driverPluginServer) TaskEvents(req *proto.TaskEventsRequest, srv proto.Driver_TaskEventsServer) error {
   357  	ch, err := b.impl.TaskEvents(srv.Context())
   358  	if err != nil {
   359  		return err
   360  	}
   361  
   362  	for {
   363  		event := <-ch
   364  		if event == nil {
   365  			break
   366  		}
   367  		pbTimestamp, err := ptypes.TimestampProto(event.Timestamp)
   368  		if err != nil {
   369  			return err
   370  		}
   371  
   372  		pbEvent := &proto.DriverTaskEvent{
   373  			TaskId:      event.TaskID,
   374  			AllocId:     event.AllocID,
   375  			TaskName:    event.TaskName,
   376  			Timestamp:   pbTimestamp,
   377  			Message:     event.Message,
   378  			Annotations: event.Annotations,
   379  		}
   380  
   381  		if err = srv.Send(pbEvent); err == io.EOF {
   382  			break
   383  		} else if err != nil {
   384  			return err
   385  		}
   386  	}
   387  	return nil
   388  }
   389  
   390  func (b *driverPluginServer) CreateNetwork(ctx context.Context, req *proto.CreateNetworkRequest) (*proto.CreateNetworkResponse, error) {
   391  	nm, ok := b.impl.(DriverNetworkManager)
   392  	if !ok {
   393  		return nil, fmt.Errorf("CreateNetwork RPC not supported by driver")
   394  	}
   395  
   396  	spec, created, err := nm.CreateNetwork(req.GetAllocId(), networkCreateRequestFromProto(req))
   397  	if err != nil {
   398  		return nil, err
   399  	}
   400  
   401  	return &proto.CreateNetworkResponse{
   402  		IsolationSpec: NetworkIsolationSpecToProto(spec),
   403  		Created:       created,
   404  	}, nil
   405  }
   406  
   407  func (b *driverPluginServer) DestroyNetwork(ctx context.Context, req *proto.DestroyNetworkRequest) (*proto.DestroyNetworkResponse, error) {
   408  	nm, ok := b.impl.(DriverNetworkManager)
   409  	if !ok {
   410  		return nil, fmt.Errorf("DestroyNetwork RPC not supported by driver")
   411  	}
   412  
   413  	err := nm.DestroyNetwork(req.AllocId, NetworkIsolationSpecFromProto(req.IsolationSpec))
   414  	if err != nil {
   415  		return nil, err
   416  	}
   417  
   418  	return &proto.DestroyNetworkResponse{}, nil
   419  }