github.com/ferranbt/nomad@v0.9.3-0.20190607002617-85c449b7667c/plugins/drivers/server.go (about)

     1  package drivers
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  
     7  	"github.com/golang/protobuf/ptypes"
     8  	plugin "github.com/hashicorp/go-plugin"
     9  	"github.com/hashicorp/nomad/nomad/structs"
    10  	"github.com/hashicorp/nomad/plugins/drivers/proto"
    11  	dstructs "github.com/hashicorp/nomad/plugins/shared/structs"
    12  	sproto "github.com/hashicorp/nomad/plugins/shared/structs/proto"
    13  	context "golang.org/x/net/context"
    14  	"google.golang.org/grpc/codes"
    15  	"google.golang.org/grpc/status"
    16  )
    17  
    18  type driverPluginServer struct {
    19  	broker *plugin.GRPCBroker
    20  	impl   DriverPlugin
    21  }
    22  
    23  func (b *driverPluginServer) TaskConfigSchema(ctx context.Context, req *proto.TaskConfigSchemaRequest) (*proto.TaskConfigSchemaResponse, error) {
    24  	spec, err := b.impl.TaskConfigSchema()
    25  	if err != nil {
    26  		return nil, err
    27  	}
    28  
    29  	resp := &proto.TaskConfigSchemaResponse{
    30  		Spec: spec,
    31  	}
    32  	return resp, nil
    33  }
    34  
    35  func (b *driverPluginServer) Capabilities(ctx context.Context, req *proto.CapabilitiesRequest) (*proto.CapabilitiesResponse, error) {
    36  	caps, err := b.impl.Capabilities()
    37  	if err != nil {
    38  		return nil, err
    39  	}
    40  	resp := &proto.CapabilitiesResponse{
    41  		Capabilities: &proto.DriverCapabilities{
    42  			SendSignals: caps.SendSignals,
    43  			Exec:        caps.Exec,
    44  		},
    45  	}
    46  
    47  	switch caps.FSIsolation {
    48  	case FSIsolationNone:
    49  		resp.Capabilities.FsIsolation = proto.DriverCapabilities_NONE
    50  	case FSIsolationChroot:
    51  		resp.Capabilities.FsIsolation = proto.DriverCapabilities_CHROOT
    52  	case FSIsolationImage:
    53  		resp.Capabilities.FsIsolation = proto.DriverCapabilities_IMAGE
    54  	default:
    55  		resp.Capabilities.FsIsolation = proto.DriverCapabilities_NONE
    56  	}
    57  	return resp, nil
    58  }
    59  
    60  func (b *driverPluginServer) Fingerprint(req *proto.FingerprintRequest, srv proto.Driver_FingerprintServer) error {
    61  	ctx := srv.Context()
    62  	ch, err := b.impl.Fingerprint(ctx)
    63  	if err != nil {
    64  		return err
    65  	}
    66  
    67  	for {
    68  		select {
    69  		case <-ctx.Done():
    70  			return nil
    71  		case f, ok := <-ch:
    72  
    73  			if !ok {
    74  				return nil
    75  			}
    76  			resp := &proto.FingerprintResponse{
    77  				Attributes:        dstructs.ConvertStructAttributeMap(f.Attributes),
    78  				Health:            healthStateToProto(f.Health),
    79  				HealthDescription: f.HealthDescription,
    80  			}
    81  
    82  			if err := srv.Send(resp); err != nil {
    83  				return err
    84  			}
    85  		}
    86  	}
    87  }
    88  
    89  func (b *driverPluginServer) RecoverTask(ctx context.Context, req *proto.RecoverTaskRequest) (*proto.RecoverTaskResponse, error) {
    90  	err := b.impl.RecoverTask(taskHandleFromProto(req.Handle))
    91  	if err != nil {
    92  		return nil, err
    93  	}
    94  
    95  	return &proto.RecoverTaskResponse{}, nil
    96  }
    97  
    98  func (b *driverPluginServer) StartTask(ctx context.Context, req *proto.StartTaskRequest) (*proto.StartTaskResponse, error) {
    99  	handle, net, err := b.impl.StartTask(taskConfigFromProto(req.Task))
   100  	if err != nil {
   101  		if rec, ok := err.(structs.Recoverable); ok {
   102  			st := status.New(codes.FailedPrecondition, rec.Error())
   103  			st, err := st.WithDetails(&sproto.RecoverableError{Recoverable: rec.IsRecoverable()})
   104  			if err != nil {
   105  				// If this error, it will always error
   106  				panic(err)
   107  			}
   108  			return nil, st.Err()
   109  		}
   110  		return nil, err
   111  	}
   112  
   113  	var pbNet *proto.NetworkOverride
   114  	if net != nil {
   115  		pbNet = &proto.NetworkOverride{
   116  			PortMap:       map[string]int32{},
   117  			Addr:          net.IP,
   118  			AutoAdvertise: net.AutoAdvertise,
   119  		}
   120  		for k, v := range net.PortMap {
   121  			pbNet.PortMap[k] = int32(v)
   122  		}
   123  	}
   124  
   125  	resp := &proto.StartTaskResponse{
   126  		Handle:          taskHandleToProto(handle),
   127  		NetworkOverride: pbNet,
   128  	}
   129  
   130  	return resp, nil
   131  }
   132  
   133  func (b *driverPluginServer) WaitTask(ctx context.Context, req *proto.WaitTaskRequest) (*proto.WaitTaskResponse, error) {
   134  	ch, err := b.impl.WaitTask(ctx, req.TaskId)
   135  	if err != nil {
   136  		return nil, err
   137  	}
   138  
   139  	var ok bool
   140  	var result *ExitResult
   141  	select {
   142  	case <-ctx.Done():
   143  		return nil, ctx.Err()
   144  	case result, ok = <-ch:
   145  		if !ok {
   146  			return &proto.WaitTaskResponse{
   147  				Err: "channel closed",
   148  			}, nil
   149  		}
   150  	}
   151  
   152  	var errStr string
   153  	if result.Err != nil {
   154  		errStr = result.Err.Error()
   155  	}
   156  
   157  	resp := &proto.WaitTaskResponse{
   158  		Err: errStr,
   159  		Result: &proto.ExitResult{
   160  			ExitCode:  int32(result.ExitCode),
   161  			Signal:    int32(result.Signal),
   162  			OomKilled: result.OOMKilled,
   163  		},
   164  	}
   165  
   166  	return resp, nil
   167  }
   168  
   169  func (b *driverPluginServer) StopTask(ctx context.Context, req *proto.StopTaskRequest) (*proto.StopTaskResponse, error) {
   170  	timeout, err := ptypes.Duration(req.Timeout)
   171  	if err != nil {
   172  		return nil, err
   173  	}
   174  
   175  	err = b.impl.StopTask(req.TaskId, timeout, req.Signal)
   176  	if err != nil {
   177  		return nil, err
   178  	}
   179  	return &proto.StopTaskResponse{}, nil
   180  }
   181  
   182  func (b *driverPluginServer) DestroyTask(ctx context.Context, req *proto.DestroyTaskRequest) (*proto.DestroyTaskResponse, error) {
   183  	err := b.impl.DestroyTask(req.TaskId, req.Force)
   184  	if err != nil {
   185  		return nil, err
   186  	}
   187  	return &proto.DestroyTaskResponse{}, nil
   188  }
   189  
   190  func (b *driverPluginServer) InspectTask(ctx context.Context, req *proto.InspectTaskRequest) (*proto.InspectTaskResponse, error) {
   191  	status, err := b.impl.InspectTask(req.TaskId)
   192  	if err != nil {
   193  		return nil, err
   194  	}
   195  
   196  	protoStatus, err := taskStatusToProto(status)
   197  	if err != nil {
   198  		return nil, err
   199  	}
   200  
   201  	var pbNet *proto.NetworkOverride
   202  	if status.NetworkOverride != nil {
   203  		pbNet = &proto.NetworkOverride{
   204  			PortMap:       map[string]int32{},
   205  			Addr:          status.NetworkOverride.IP,
   206  			AutoAdvertise: status.NetworkOverride.AutoAdvertise,
   207  		}
   208  		for k, v := range status.NetworkOverride.PortMap {
   209  			pbNet.PortMap[k] = int32(v)
   210  		}
   211  	}
   212  
   213  	resp := &proto.InspectTaskResponse{
   214  		Task: protoStatus,
   215  		Driver: &proto.TaskDriverStatus{
   216  			Attributes: status.DriverAttributes,
   217  		},
   218  		NetworkOverride: pbNet,
   219  	}
   220  
   221  	return resp, nil
   222  }
   223  
   224  func (b *driverPluginServer) TaskStats(req *proto.TaskStatsRequest, srv proto.Driver_TaskStatsServer) error {
   225  	interval, err := ptypes.Duration(req.CollectionInterval)
   226  	if err != nil {
   227  		return fmt.Errorf("failed to parse collection interval: %v", err)
   228  	}
   229  
   230  	ch, err := b.impl.TaskStats(srv.Context(), req.TaskId, interval)
   231  	if err != nil {
   232  		if rec, ok := err.(structs.Recoverable); ok {
   233  			st := status.New(codes.FailedPrecondition, rec.Error())
   234  			st, err := st.WithDetails(&sproto.RecoverableError{Recoverable: rec.IsRecoverable()})
   235  			if err != nil {
   236  				// If this error, it will always error
   237  				panic(err)
   238  			}
   239  			return st.Err()
   240  		}
   241  		return err
   242  	}
   243  
   244  	for stats := range ch {
   245  		pb, err := TaskStatsToProto(stats)
   246  		if err != nil {
   247  			return fmt.Errorf("failed to encode task stats: %v", err)
   248  		}
   249  
   250  		if err = srv.Send(&proto.TaskStatsResponse{Stats: pb}); err == io.EOF {
   251  			break
   252  		} else if err != nil {
   253  			return err
   254  		}
   255  
   256  	}
   257  
   258  	return nil
   259  }
   260  
   261  func (b *driverPluginServer) ExecTask(ctx context.Context, req *proto.ExecTaskRequest) (*proto.ExecTaskResponse, error) {
   262  	timeout, err := ptypes.Duration(req.Timeout)
   263  	if err != nil {
   264  		return nil, err
   265  	}
   266  
   267  	result, err := b.impl.ExecTask(req.TaskId, req.Command, timeout)
   268  	if err != nil {
   269  		return nil, err
   270  	}
   271  	resp := &proto.ExecTaskResponse{
   272  		Stdout: result.Stdout,
   273  		Stderr: result.Stderr,
   274  		Result: exitResultToProto(result.ExitResult),
   275  	}
   276  
   277  	return resp, nil
   278  }
   279  
   280  func (b *driverPluginServer) ExecTaskStreaming(server proto.Driver_ExecTaskStreamingServer) error {
   281  	msg, err := server.Recv()
   282  	if err != nil {
   283  		return fmt.Errorf("failed to receive initial message: %v", err)
   284  	}
   285  
   286  	if msg.Setup == nil {
   287  		return fmt.Errorf("first message should always be setup")
   288  	}
   289  
   290  	if impl, ok := b.impl.(ExecTaskStreamingRawDriver); ok {
   291  		return impl.ExecTaskStreamingRaw(server.Context(),
   292  			msg.Setup.TaskId, msg.Setup.Command, msg.Setup.Tty,
   293  			server)
   294  	}
   295  
   296  	d, ok := b.impl.(ExecTaskStreamingDriver)
   297  	if !ok {
   298  		return fmt.Errorf("driver does not support exec")
   299  	}
   300  
   301  	execOpts, errCh := StreamToExecOptions(server.Context(),
   302  		msg.Setup.Command, msg.Setup.Tty,
   303  		server)
   304  
   305  	result, err := d.ExecTaskStreaming(server.Context(),
   306  		msg.Setup.TaskId, execOpts)
   307  
   308  	execOpts.Stdout.Close()
   309  	execOpts.Stderr.Close()
   310  
   311  	if err != nil {
   312  		return err
   313  	}
   314  
   315  	// wait for copy to be done
   316  	select {
   317  	case err = <-errCh:
   318  	case <-server.Context().Done():
   319  		err = fmt.Errorf("exec timed out: %v", server.Context().Err())
   320  	}
   321  
   322  	if err != nil {
   323  		return err
   324  	}
   325  
   326  	server.Send(&ExecTaskStreamingResponseMsg{
   327  		Exited: true,
   328  		Result: exitResultToProto(result),
   329  	})
   330  
   331  	return err
   332  }
   333  
   334  func (b *driverPluginServer) SignalTask(ctx context.Context, req *proto.SignalTaskRequest) (*proto.SignalTaskResponse, error) {
   335  	err := b.impl.SignalTask(req.TaskId, req.Signal)
   336  	if err != nil {
   337  		return nil, err
   338  	}
   339  
   340  	resp := &proto.SignalTaskResponse{}
   341  	return resp, nil
   342  }
   343  
   344  func (b *driverPluginServer) TaskEvents(req *proto.TaskEventsRequest, srv proto.Driver_TaskEventsServer) error {
   345  	ch, err := b.impl.TaskEvents(srv.Context())
   346  	if err != nil {
   347  		return err
   348  	}
   349  
   350  	for {
   351  		event := <-ch
   352  		if event == nil {
   353  			break
   354  		}
   355  		pbTimestamp, err := ptypes.TimestampProto(event.Timestamp)
   356  		if err != nil {
   357  			return err
   358  		}
   359  
   360  		pbEvent := &proto.DriverTaskEvent{
   361  			TaskId:      event.TaskID,
   362  			AllocId:     event.AllocID,
   363  			TaskName:    event.TaskName,
   364  			Timestamp:   pbTimestamp,
   365  			Message:     event.Message,
   366  			Annotations: event.Annotations,
   367  		}
   368  
   369  		if err = srv.Send(pbEvent); err == io.EOF {
   370  			break
   371  		} else if err != nil {
   372  			return err
   373  		}
   374  	}
   375  	return nil
   376  }