github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/plugins/drivers/client.go (about)

     1  package drivers
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"io"
     7  	"time"
     8  
     9  	"github.com/LK4D4/joincontext"
    10  	"github.com/golang/protobuf/ptypes"
    11  	hclog "github.com/hashicorp/go-hclog"
    12  	cstructs "github.com/hashicorp/nomad/client/structs"
    13  	"github.com/hashicorp/nomad/helper/pluginutils/grpcutils"
    14  	"github.com/hashicorp/nomad/nomad/structs"
    15  	"github.com/hashicorp/nomad/plugins/base"
    16  	"github.com/hashicorp/nomad/plugins/drivers/proto"
    17  	"github.com/hashicorp/nomad/plugins/shared/hclspec"
    18  	pstructs "github.com/hashicorp/nomad/plugins/shared/structs"
    19  	sproto "github.com/hashicorp/nomad/plugins/shared/structs/proto"
    20  	"google.golang.org/grpc/status"
    21  )
    22  
    23  var _ DriverPlugin = &driverPluginClient{}
    24  
    25  type driverPluginClient struct {
    26  	*base.BasePluginClient
    27  
    28  	client proto.DriverClient
    29  	logger hclog.Logger
    30  
    31  	// doneCtx is closed when the plugin exits
    32  	doneCtx context.Context
    33  }
    34  
    35  func (d *driverPluginClient) TaskConfigSchema() (*hclspec.Spec, error) {
    36  	req := &proto.TaskConfigSchemaRequest{}
    37  
    38  	resp, err := d.client.TaskConfigSchema(d.doneCtx, req)
    39  	if err != nil {
    40  		return nil, grpcutils.HandleGrpcErr(err, d.doneCtx)
    41  	}
    42  
    43  	return resp.Spec, nil
    44  }
    45  
    46  func (d *driverPluginClient) Capabilities() (*Capabilities, error) {
    47  	req := &proto.CapabilitiesRequest{}
    48  
    49  	resp, err := d.client.Capabilities(d.doneCtx, req)
    50  	if err != nil {
    51  		return nil, grpcutils.HandleGrpcErr(err, d.doneCtx)
    52  	}
    53  
    54  	caps := &Capabilities{}
    55  	if resp.Capabilities != nil {
    56  		caps.SendSignals = resp.Capabilities.SendSignals
    57  		caps.Exec = resp.Capabilities.Exec
    58  		caps.MustInitiateNetwork = resp.Capabilities.MustCreateNetwork
    59  
    60  		for _, mode := range resp.Capabilities.NetworkIsolationModes {
    61  			caps.NetIsolationModes = append(caps.NetIsolationModes, netIsolationModeFromProto(mode))
    62  		}
    63  
    64  		switch resp.Capabilities.FsIsolation {
    65  		case proto.DriverCapabilities_NONE:
    66  			caps.FSIsolation = FSIsolationNone
    67  		case proto.DriverCapabilities_CHROOT:
    68  			caps.FSIsolation = FSIsolationChroot
    69  		case proto.DriverCapabilities_IMAGE:
    70  			caps.FSIsolation = FSIsolationImage
    71  		default:
    72  			caps.FSIsolation = FSIsolationNone
    73  		}
    74  
    75  		caps.MountConfigs = MountConfigSupport(resp.Capabilities.MountConfigs)
    76  		caps.RemoteTasks = resp.Capabilities.RemoteTasks
    77  	}
    78  
    79  	return caps, nil
    80  }
    81  
    82  // Fingerprint the driver, return a chan that will be pushed to periodically and on changes to health
    83  func (d *driverPluginClient) Fingerprint(ctx context.Context) (<-chan *Fingerprint, error) {
    84  	req := &proto.FingerprintRequest{}
    85  
    86  	// Join the passed context and the shutdown context
    87  	joinedCtx, _ := joincontext.Join(ctx, d.doneCtx)
    88  
    89  	stream, err := d.client.Fingerprint(joinedCtx, req)
    90  	if err != nil {
    91  		return nil, grpcutils.HandleReqCtxGrpcErr(err, ctx, d.doneCtx)
    92  	}
    93  
    94  	ch := make(chan *Fingerprint, 1)
    95  	go d.handleFingerprint(ctx, ch, stream)
    96  
    97  	return ch, nil
    98  }
    99  
   100  func (d *driverPluginClient) handleFingerprint(reqCtx context.Context, ch chan *Fingerprint, stream proto.Driver_FingerprintClient) {
   101  	defer close(ch)
   102  	for {
   103  		pb, err := stream.Recv()
   104  		if err != nil {
   105  			if err != io.EOF {
   106  				ch <- &Fingerprint{
   107  					Err: grpcutils.HandleReqCtxGrpcErr(err, reqCtx, d.doneCtx),
   108  				}
   109  			}
   110  
   111  			// End the stream
   112  			return
   113  		}
   114  
   115  		f := &Fingerprint{
   116  			Attributes:        pstructs.ConvertProtoAttributeMap(pb.Attributes),
   117  			Health:            healthStateFromProto(pb.Health),
   118  			HealthDescription: pb.HealthDescription,
   119  		}
   120  
   121  		select {
   122  		case <-reqCtx.Done():
   123  			return
   124  		case ch <- f:
   125  		}
   126  	}
   127  }
   128  
   129  // RecoverTask does internal state recovery to be able to control the task of
   130  // the given TaskHandle
   131  func (d *driverPluginClient) RecoverTask(h *TaskHandle) error {
   132  	req := &proto.RecoverTaskRequest{Handle: taskHandleToProto(h)}
   133  
   134  	_, err := d.client.RecoverTask(d.doneCtx, req)
   135  	return grpcutils.HandleGrpcErr(err, d.doneCtx)
   136  }
   137  
   138  // StartTask starts execution of a task with the given TaskConfig. A TaskHandle
   139  // is returned to the caller that can be used to recover state of the task,
   140  // should the driver crash or exit prematurely.
   141  func (d *driverPluginClient) StartTask(c *TaskConfig) (*TaskHandle, *DriverNetwork, error) {
   142  	req := &proto.StartTaskRequest{
   143  		Task: taskConfigToProto(c),
   144  	}
   145  
   146  	resp, err := d.client.StartTask(d.doneCtx, req)
   147  	if err != nil {
   148  		st := status.Convert(err)
   149  		if len(st.Details()) > 0 {
   150  			if rec, ok := st.Details()[0].(*sproto.RecoverableError); ok {
   151  				return nil, nil, structs.NewRecoverableError(err, rec.Recoverable)
   152  			}
   153  		}
   154  		return nil, nil, grpcutils.HandleGrpcErr(err, d.doneCtx)
   155  	}
   156  
   157  	var net *DriverNetwork
   158  	if resp.NetworkOverride != nil {
   159  		net = &DriverNetwork{
   160  			PortMap:       map[string]int{},
   161  			IP:            resp.NetworkOverride.Addr,
   162  			AutoAdvertise: resp.NetworkOverride.AutoAdvertise,
   163  		}
   164  		for k, v := range resp.NetworkOverride.PortMap {
   165  			net.PortMap[k] = int(v)
   166  		}
   167  	}
   168  
   169  	return taskHandleFromProto(resp.Handle), net, nil
   170  }
   171  
   172  // WaitTask returns a channel that will have an ExitResult pushed to it once when the task
   173  // exits on its own or is killed. If WaitTask is called after the task has exited, the channel
   174  // will immedialy return the ExitResult. WaitTask can be called multiple times for
   175  // the same task without issue.
   176  func (d *driverPluginClient) WaitTask(ctx context.Context, id string) (<-chan *ExitResult, error) {
   177  	ch := make(chan *ExitResult)
   178  	go d.handleWaitTask(ctx, id, ch)
   179  	return ch, nil
   180  }
   181  
   182  func (d *driverPluginClient) handleWaitTask(ctx context.Context, id string, ch chan *ExitResult) {
   183  	defer close(ch)
   184  	var result ExitResult
   185  	req := &proto.WaitTaskRequest{
   186  		TaskId: id,
   187  	}
   188  
   189  	// Join the passed context and the shutdown context
   190  	joinedCtx, joinedCtxCancel := joincontext.Join(ctx, d.doneCtx)
   191  	defer joinedCtxCancel()
   192  
   193  	resp, err := d.client.WaitTask(joinedCtx, req)
   194  	if err != nil {
   195  		result.Err = grpcutils.HandleReqCtxGrpcErr(err, ctx, d.doneCtx)
   196  	} else {
   197  		result.ExitCode = int(resp.Result.ExitCode)
   198  		result.Signal = int(resp.Result.Signal)
   199  		result.OOMKilled = resp.Result.OomKilled
   200  		if len(resp.Err) > 0 {
   201  			result.Err = errors.New(resp.Err)
   202  		}
   203  	}
   204  	ch <- &result
   205  }
   206  
   207  // StopTask stops the task with the given taskID. A timeout and signal can be
   208  // given to control a graceful termination of the task. The driver will send the
   209  // given signal to the task and wait for the given timeout for it to exit. If the
   210  // task does not exit within the timeout it will be forcefully killed.
   211  func (d *driverPluginClient) StopTask(taskID string, timeout time.Duration, signal string) error {
   212  	req := &proto.StopTaskRequest{
   213  		TaskId:  taskID,
   214  		Timeout: ptypes.DurationProto(timeout),
   215  		Signal:  signal,
   216  	}
   217  
   218  	_, err := d.client.StopTask(d.doneCtx, req)
   219  	return grpcutils.HandleGrpcErr(err, d.doneCtx)
   220  }
   221  
   222  // DestroyTask removes the task from the driver's in memory state. The task
   223  // cannot be running unless force is set to true. If force is set to true the
   224  // driver will forcefully terminate the task before removing it.
   225  func (d *driverPluginClient) DestroyTask(taskID string, force bool) error {
   226  	req := &proto.DestroyTaskRequest{
   227  		TaskId: taskID,
   228  		Force:  force,
   229  	}
   230  
   231  	_, err := d.client.DestroyTask(d.doneCtx, req)
   232  	return grpcutils.HandleGrpcErr(err, d.doneCtx)
   233  }
   234  
   235  // InspectTask returns status information for a task
   236  func (d *driverPluginClient) InspectTask(taskID string) (*TaskStatus, error) {
   237  	req := &proto.InspectTaskRequest{TaskId: taskID}
   238  
   239  	resp, err := d.client.InspectTask(d.doneCtx, req)
   240  	if err != nil {
   241  		return nil, grpcutils.HandleGrpcErr(err, d.doneCtx)
   242  	}
   243  
   244  	status, err := taskStatusFromProto(resp.Task)
   245  	if err != nil {
   246  		return nil, err
   247  	}
   248  
   249  	if resp.Driver != nil {
   250  		status.DriverAttributes = resp.Driver.Attributes
   251  	}
   252  	if resp.NetworkOverride != nil {
   253  		status.NetworkOverride = &DriverNetwork{
   254  			PortMap:       map[string]int{},
   255  			IP:            resp.NetworkOverride.Addr,
   256  			AutoAdvertise: resp.NetworkOverride.AutoAdvertise,
   257  		}
   258  		for k, v := range resp.NetworkOverride.PortMap {
   259  			status.NetworkOverride.PortMap[k] = int(v)
   260  		}
   261  	}
   262  
   263  	return status, nil
   264  }
   265  
   266  // TaskStats returns resource usage statistics for the task
   267  func (d *driverPluginClient) TaskStats(ctx context.Context, taskID string, interval time.Duration) (<-chan *cstructs.TaskResourceUsage, error) {
   268  	req := &proto.TaskStatsRequest{
   269  		TaskId:             taskID,
   270  		CollectionInterval: ptypes.DurationProto(interval),
   271  	}
   272  	ctx, _ = joincontext.Join(ctx, d.doneCtx)
   273  	stream, err := d.client.TaskStats(ctx, req)
   274  	if err != nil {
   275  		st := status.Convert(err)
   276  		if len(st.Details()) > 0 {
   277  			if rec, ok := st.Details()[0].(*sproto.RecoverableError); ok {
   278  				return nil, structs.NewRecoverableError(err, rec.Recoverable)
   279  			}
   280  		}
   281  		return nil, grpcutils.HandleGrpcErr(err, d.doneCtx)
   282  	}
   283  
   284  	ch := make(chan *cstructs.TaskResourceUsage, 1)
   285  	go d.handleStats(ctx, ch, stream)
   286  
   287  	return ch, nil
   288  }
   289  
   290  func (d *driverPluginClient) handleStats(ctx context.Context, ch chan<- *cstructs.TaskResourceUsage, stream proto.Driver_TaskStatsClient) {
   291  	defer close(ch)
   292  	for {
   293  		resp, err := stream.Recv()
   294  		if ctx.Err() != nil {
   295  			// Context canceled; exit gracefully
   296  			return
   297  		}
   298  
   299  		if err != nil {
   300  			if err != io.EOF {
   301  				d.logger.Error("error receiving stream from TaskStats driver RPC, closing stream", "error", err)
   302  			}
   303  
   304  			// End of stream
   305  			return
   306  		}
   307  
   308  		stats, err := TaskStatsFromProto(resp.Stats)
   309  		if err != nil {
   310  			d.logger.Error("failed to decode stats from RPC", "error", err, "stats", resp.Stats)
   311  			continue
   312  		}
   313  
   314  		select {
   315  		case ch <- stats:
   316  		case <-ctx.Done():
   317  			return
   318  		}
   319  	}
   320  }
   321  
   322  // TaskEvents returns a channel that will receive events from the driver about all
   323  // tasks such as lifecycle events, terminal errors, etc.
   324  func (d *driverPluginClient) TaskEvents(ctx context.Context) (<-chan *TaskEvent, error) {
   325  	req := &proto.TaskEventsRequest{}
   326  
   327  	// Join the passed context and the shutdown context
   328  	joinedCtx, _ := joincontext.Join(ctx, d.doneCtx)
   329  
   330  	stream, err := d.client.TaskEvents(joinedCtx, req)
   331  	if err != nil {
   332  		return nil, grpcutils.HandleReqCtxGrpcErr(err, ctx, d.doneCtx)
   333  	}
   334  
   335  	ch := make(chan *TaskEvent, 1)
   336  	go d.handleTaskEvents(ctx, ch, stream)
   337  	return ch, nil
   338  }
   339  
   340  func (d *driverPluginClient) handleTaskEvents(reqCtx context.Context, ch chan *TaskEvent, stream proto.Driver_TaskEventsClient) {
   341  	defer close(ch)
   342  	for {
   343  		ev, err := stream.Recv()
   344  		if err != nil {
   345  			if err != io.EOF {
   346  				ch <- &TaskEvent{
   347  					Err: grpcutils.HandleReqCtxGrpcErr(err, reqCtx, d.doneCtx),
   348  				}
   349  			}
   350  
   351  			// End the stream
   352  			return
   353  		}
   354  
   355  		timestamp, _ := ptypes.Timestamp(ev.Timestamp)
   356  		event := &TaskEvent{
   357  			TaskID:      ev.TaskId,
   358  			AllocID:     ev.AllocId,
   359  			TaskName:    ev.TaskName,
   360  			Annotations: ev.Annotations,
   361  			Message:     ev.Message,
   362  			Timestamp:   timestamp,
   363  		}
   364  		select {
   365  		case <-reqCtx.Done():
   366  			return
   367  		case ch <- event:
   368  		}
   369  	}
   370  }
   371  
   372  // SignalTask will send the given signal to the specified task
   373  func (d *driverPluginClient) SignalTask(taskID string, signal string) error {
   374  	req := &proto.SignalTaskRequest{
   375  		TaskId: taskID,
   376  		Signal: signal,
   377  	}
   378  	_, err := d.client.SignalTask(d.doneCtx, req)
   379  	return grpcutils.HandleGrpcErr(err, d.doneCtx)
   380  }
   381  
   382  // ExecTask will run the given command within the execution context of the task.
   383  // The driver will wait for the given timeout for the command to complete before
   384  // terminating it. The stdout and stderr of the command will be return to the caller,
   385  // along with other exit information such as exit code.
   386  func (d *driverPluginClient) ExecTask(taskID string, cmd []string, timeout time.Duration) (*ExecTaskResult, error) {
   387  	req := &proto.ExecTaskRequest{
   388  		TaskId:  taskID,
   389  		Command: cmd,
   390  		Timeout: ptypes.DurationProto(timeout),
   391  	}
   392  
   393  	resp, err := d.client.ExecTask(d.doneCtx, req)
   394  	if err != nil {
   395  		return nil, grpcutils.HandleGrpcErr(err, d.doneCtx)
   396  	}
   397  
   398  	result := &ExecTaskResult{
   399  		Stdout:     resp.Stdout,
   400  		Stderr:     resp.Stderr,
   401  		ExitResult: exitResultFromProto(resp.Result),
   402  	}
   403  
   404  	return result, nil
   405  }
   406  
   407  var _ ExecTaskStreamingRawDriver = (*driverPluginClient)(nil)
   408  
   409  func (d *driverPluginClient) ExecTaskStreamingRaw(ctx context.Context,
   410  	taskID string,
   411  	command []string,
   412  	tty bool,
   413  	execStream ExecTaskStream) error {
   414  
   415  	stream, err := d.client.ExecTaskStreaming(ctx)
   416  	if err != nil {
   417  		return grpcutils.HandleGrpcErr(err, d.doneCtx)
   418  	}
   419  
   420  	err = stream.Send(&proto.ExecTaskStreamingRequest{
   421  		Setup: &proto.ExecTaskStreamingRequest_Setup{
   422  			TaskId:  taskID,
   423  			Command: command,
   424  			Tty:     tty,
   425  		},
   426  	})
   427  	if err != nil {
   428  		return grpcutils.HandleGrpcErr(err, d.doneCtx)
   429  	}
   430  
   431  	errCh := make(chan error, 1)
   432  
   433  	go func() {
   434  		for {
   435  			m, err := execStream.Recv()
   436  			if err == io.EOF {
   437  				return
   438  			} else if err != nil {
   439  				errCh <- err
   440  				return
   441  			}
   442  
   443  			if err := stream.Send(m); err != nil {
   444  				errCh <- err
   445  				return
   446  			}
   447  
   448  		}
   449  	}()
   450  
   451  	for {
   452  		select {
   453  		case err := <-errCh:
   454  			return err
   455  		default:
   456  		}
   457  
   458  		m, err := stream.Recv()
   459  		if err == io.EOF {
   460  			// Once we get to the end of stream successfully, we can ignore errCh:
   461  			// e.g. input write failures after process terminates shouldn't cause method to fail
   462  			return nil
   463  		} else if err != nil {
   464  			return err
   465  		}
   466  
   467  		if err := execStream.Send(m); err != nil {
   468  			return err
   469  		}
   470  	}
   471  }
   472  
   473  var _ DriverNetworkManager = (*driverPluginClient)(nil)
   474  
   475  func (d *driverPluginClient) CreateNetwork(allocID string, _ *NetworkCreateRequest) (*NetworkIsolationSpec, bool, error) {
   476  	req := &proto.CreateNetworkRequest{
   477  		AllocId: allocID,
   478  	}
   479  
   480  	resp, err := d.client.CreateNetwork(d.doneCtx, req)
   481  	if err != nil {
   482  		return nil, false, grpcutils.HandleGrpcErr(err, d.doneCtx)
   483  	}
   484  
   485  	return NetworkIsolationSpecFromProto(resp.IsolationSpec), resp.Created, nil
   486  }
   487  
   488  func (d *driverPluginClient) DestroyNetwork(allocID string, spec *NetworkIsolationSpec) error {
   489  	req := &proto.DestroyNetworkRequest{
   490  		AllocId:       allocID,
   491  		IsolationSpec: NetworkIsolationSpecToProto(spec),
   492  	}
   493  
   494  	_, err := d.client.DestroyNetwork(d.doneCtx, req)
   495  	if err != nil {
   496  		return grpcutils.HandleGrpcErr(err, d.doneCtx)
   497  	}
   498  
   499  	return nil
   500  }