github.com/rohankumardubey/nomad@v0.11.8/plugins/drivers/server.go (about)

     1  package drivers
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  
     7  	"github.com/golang/protobuf/ptypes"
     8  	plugin "github.com/hashicorp/go-plugin"
     9  	"github.com/hashicorp/nomad/nomad/structs"
    10  	"github.com/hashicorp/nomad/plugins/drivers/proto"
    11  	dstructs "github.com/hashicorp/nomad/plugins/shared/structs"
    12  	sproto "github.com/hashicorp/nomad/plugins/shared/structs/proto"
    13  	context "golang.org/x/net/context"
    14  	"google.golang.org/grpc/codes"
    15  	"google.golang.org/grpc/status"
    16  )
    17  
    18  type driverPluginServer struct {
    19  	broker *plugin.GRPCBroker
    20  	impl   DriverPlugin
    21  }
    22  
    23  func (b *driverPluginServer) TaskConfigSchema(ctx context.Context, req *proto.TaskConfigSchemaRequest) (*proto.TaskConfigSchemaResponse, error) {
    24  	spec, err := b.impl.TaskConfigSchema()
    25  	if err != nil {
    26  		return nil, err
    27  	}
    28  
    29  	resp := &proto.TaskConfigSchemaResponse{
    30  		Spec: spec,
    31  	}
    32  	return resp, nil
    33  }
    34  
    35  func (b *driverPluginServer) Capabilities(ctx context.Context, req *proto.CapabilitiesRequest) (*proto.CapabilitiesResponse, error) {
    36  	caps, err := b.impl.Capabilities()
    37  	if err != nil {
    38  		return nil, err
    39  	}
    40  	resp := &proto.CapabilitiesResponse{
    41  		Capabilities: &proto.DriverCapabilities{
    42  			SendSignals:           caps.SendSignals,
    43  			Exec:                  caps.Exec,
    44  			MustCreateNetwork:     caps.MustInitiateNetwork,
    45  			NetworkIsolationModes: []proto.NetworkIsolationSpec_NetworkIsolationMode{},
    46  		},
    47  	}
    48  
    49  	switch caps.FSIsolation {
    50  	case FSIsolationNone:
    51  		resp.Capabilities.FsIsolation = proto.DriverCapabilities_NONE
    52  	case FSIsolationChroot:
    53  		resp.Capabilities.FsIsolation = proto.DriverCapabilities_CHROOT
    54  	case FSIsolationImage:
    55  		resp.Capabilities.FsIsolation = proto.DriverCapabilities_IMAGE
    56  	default:
    57  		resp.Capabilities.FsIsolation = proto.DriverCapabilities_NONE
    58  	}
    59  
    60  	for _, mode := range caps.NetIsolationModes {
    61  		resp.Capabilities.NetworkIsolationModes = append(resp.Capabilities.NetworkIsolationModes, netIsolationModeToProto(mode))
    62  	}
    63  	return resp, nil
    64  }
    65  
    66  func (b *driverPluginServer) Fingerprint(req *proto.FingerprintRequest, srv proto.Driver_FingerprintServer) error {
    67  	ctx := srv.Context()
    68  	ch, err := b.impl.Fingerprint(ctx)
    69  	if err != nil {
    70  		return err
    71  	}
    72  
    73  	for {
    74  		select {
    75  		case <-ctx.Done():
    76  			return nil
    77  		case f, ok := <-ch:
    78  
    79  			if !ok {
    80  				return nil
    81  			}
    82  			resp := &proto.FingerprintResponse{
    83  				Attributes:        dstructs.ConvertStructAttributeMap(f.Attributes),
    84  				Health:            healthStateToProto(f.Health),
    85  				HealthDescription: f.HealthDescription,
    86  			}
    87  
    88  			if err := srv.Send(resp); err != nil {
    89  				return err
    90  			}
    91  		}
    92  	}
    93  }
    94  
    95  func (b *driverPluginServer) RecoverTask(ctx context.Context, req *proto.RecoverTaskRequest) (*proto.RecoverTaskResponse, error) {
    96  	err := b.impl.RecoverTask(taskHandleFromProto(req.Handle))
    97  	if err != nil {
    98  		return nil, err
    99  	}
   100  
   101  	return &proto.RecoverTaskResponse{}, nil
   102  }
   103  
   104  func (b *driverPluginServer) StartTask(ctx context.Context, req *proto.StartTaskRequest) (*proto.StartTaskResponse, error) {
   105  	handle, net, err := b.impl.StartTask(taskConfigFromProto(req.Task))
   106  	if err != nil {
   107  		if rec, ok := err.(structs.Recoverable); ok {
   108  			st := status.New(codes.FailedPrecondition, rec.Error())
   109  			st, err := st.WithDetails(&sproto.RecoverableError{Recoverable: rec.IsRecoverable()})
   110  			if err != nil {
   111  				// If this error, it will always error
   112  				panic(err)
   113  			}
   114  			return nil, st.Err()
   115  		}
   116  		return nil, err
   117  	}
   118  
   119  	var pbNet *proto.NetworkOverride
   120  	if net != nil {
   121  		pbNet = &proto.NetworkOverride{
   122  			PortMap:       map[string]int32{},
   123  			Addr:          net.IP,
   124  			AutoAdvertise: net.AutoAdvertise,
   125  		}
   126  		for k, v := range net.PortMap {
   127  			pbNet.PortMap[k] = int32(v)
   128  		}
   129  	}
   130  
   131  	resp := &proto.StartTaskResponse{
   132  		Handle:          taskHandleToProto(handle),
   133  		NetworkOverride: pbNet,
   134  	}
   135  
   136  	return resp, nil
   137  }
   138  
   139  func (b *driverPluginServer) WaitTask(ctx context.Context, req *proto.WaitTaskRequest) (*proto.WaitTaskResponse, error) {
   140  	ch, err := b.impl.WaitTask(ctx, req.TaskId)
   141  	if err != nil {
   142  		return nil, err
   143  	}
   144  
   145  	var ok bool
   146  	var result *ExitResult
   147  	select {
   148  	case <-ctx.Done():
   149  		return nil, ctx.Err()
   150  	case result, ok = <-ch:
   151  		if !ok {
   152  			return &proto.WaitTaskResponse{
   153  				Err: "channel closed",
   154  			}, nil
   155  		}
   156  	}
   157  
   158  	var errStr string
   159  	if result.Err != nil {
   160  		errStr = result.Err.Error()
   161  	}
   162  
   163  	resp := &proto.WaitTaskResponse{
   164  		Err: errStr,
   165  		Result: &proto.ExitResult{
   166  			ExitCode:  int32(result.ExitCode),
   167  			Signal:    int32(result.Signal),
   168  			OomKilled: result.OOMKilled,
   169  		},
   170  	}
   171  
   172  	return resp, nil
   173  }
   174  
   175  func (b *driverPluginServer) StopTask(ctx context.Context, req *proto.StopTaskRequest) (*proto.StopTaskResponse, error) {
   176  	timeout, err := ptypes.Duration(req.Timeout)
   177  	if err != nil {
   178  		return nil, err
   179  	}
   180  
   181  	err = b.impl.StopTask(req.TaskId, timeout, req.Signal)
   182  	if err != nil {
   183  		return nil, err
   184  	}
   185  	return &proto.StopTaskResponse{}, nil
   186  }
   187  
   188  func (b *driverPluginServer) DestroyTask(ctx context.Context, req *proto.DestroyTaskRequest) (*proto.DestroyTaskResponse, error) {
   189  	err := b.impl.DestroyTask(req.TaskId, req.Force)
   190  	if err != nil {
   191  		return nil, err
   192  	}
   193  	return &proto.DestroyTaskResponse{}, nil
   194  }
   195  
   196  func (b *driverPluginServer) InspectTask(ctx context.Context, req *proto.InspectTaskRequest) (*proto.InspectTaskResponse, error) {
   197  	status, err := b.impl.InspectTask(req.TaskId)
   198  	if err != nil {
   199  		return nil, err
   200  	}
   201  
   202  	protoStatus, err := taskStatusToProto(status)
   203  	if err != nil {
   204  		return nil, err
   205  	}
   206  
   207  	var pbNet *proto.NetworkOverride
   208  	if status.NetworkOverride != nil {
   209  		pbNet = &proto.NetworkOverride{
   210  			PortMap:       map[string]int32{},
   211  			Addr:          status.NetworkOverride.IP,
   212  			AutoAdvertise: status.NetworkOverride.AutoAdvertise,
   213  		}
   214  		for k, v := range status.NetworkOverride.PortMap {
   215  			pbNet.PortMap[k] = int32(v)
   216  		}
   217  	}
   218  
   219  	resp := &proto.InspectTaskResponse{
   220  		Task: protoStatus,
   221  		Driver: &proto.TaskDriverStatus{
   222  			Attributes: status.DriverAttributes,
   223  		},
   224  		NetworkOverride: pbNet,
   225  	}
   226  
   227  	return resp, nil
   228  }
   229  
   230  func (b *driverPluginServer) TaskStats(req *proto.TaskStatsRequest, srv proto.Driver_TaskStatsServer) error {
   231  	interval, err := ptypes.Duration(req.CollectionInterval)
   232  	if err != nil {
   233  		return fmt.Errorf("failed to parse collection interval: %v", err)
   234  	}
   235  
   236  	ch, err := b.impl.TaskStats(srv.Context(), req.TaskId, interval)
   237  	if err != nil {
   238  		if rec, ok := err.(structs.Recoverable); ok {
   239  			st := status.New(codes.FailedPrecondition, rec.Error())
   240  			st, err := st.WithDetails(&sproto.RecoverableError{Recoverable: rec.IsRecoverable()})
   241  			if err != nil {
   242  				// If this error, it will always error
   243  				panic(err)
   244  			}
   245  			return st.Err()
   246  		}
   247  		return err
   248  	}
   249  
   250  	for stats := range ch {
   251  		pb, err := TaskStatsToProto(stats)
   252  		if err != nil {
   253  			return fmt.Errorf("failed to encode task stats: %v", err)
   254  		}
   255  
   256  		if err = srv.Send(&proto.TaskStatsResponse{Stats: pb}); err == io.EOF {
   257  			break
   258  		} else if err != nil {
   259  			return err
   260  		}
   261  
   262  	}
   263  
   264  	return nil
   265  }
   266  
   267  func (b *driverPluginServer) ExecTask(ctx context.Context, req *proto.ExecTaskRequest) (*proto.ExecTaskResponse, error) {
   268  	timeout, err := ptypes.Duration(req.Timeout)
   269  	if err != nil {
   270  		return nil, err
   271  	}
   272  
   273  	result, err := b.impl.ExecTask(req.TaskId, req.Command, timeout)
   274  	if err != nil {
   275  		return nil, err
   276  	}
   277  	resp := &proto.ExecTaskResponse{
   278  		Stdout: result.Stdout,
   279  		Stderr: result.Stderr,
   280  		Result: exitResultToProto(result.ExitResult),
   281  	}
   282  
   283  	return resp, nil
   284  }
   285  
   286  func (b *driverPluginServer) ExecTaskStreaming(server proto.Driver_ExecTaskStreamingServer) error {
   287  	msg, err := server.Recv()
   288  	if err != nil {
   289  		return fmt.Errorf("failed to receive initial message: %v", err)
   290  	}
   291  
   292  	if msg.Setup == nil {
   293  		return fmt.Errorf("first message should always be setup")
   294  	}
   295  
   296  	if impl, ok := b.impl.(ExecTaskStreamingRawDriver); ok {
   297  		return impl.ExecTaskStreamingRaw(server.Context(),
   298  			msg.Setup.TaskId, msg.Setup.Command, msg.Setup.Tty,
   299  			server)
   300  	}
   301  
   302  	d, ok := b.impl.(ExecTaskStreamingDriver)
   303  	if !ok {
   304  		return fmt.Errorf("driver does not support exec")
   305  	}
   306  
   307  	execOpts, errCh := StreamToExecOptions(server.Context(),
   308  		msg.Setup.Command, msg.Setup.Tty,
   309  		server)
   310  
   311  	result, err := d.ExecTaskStreaming(server.Context(),
   312  		msg.Setup.TaskId, execOpts)
   313  
   314  	execOpts.Stdout.Close()
   315  	execOpts.Stderr.Close()
   316  
   317  	if err != nil {
   318  		return err
   319  	}
   320  
   321  	// wait for copy to be done
   322  	select {
   323  	case err = <-errCh:
   324  	case <-server.Context().Done():
   325  		err = fmt.Errorf("exec timed out: %v", server.Context().Err())
   326  	}
   327  
   328  	if err != nil {
   329  		return err
   330  	}
   331  
   332  	server.Send(&ExecTaskStreamingResponseMsg{
   333  		Exited: true,
   334  		Result: exitResultToProto(result),
   335  	})
   336  
   337  	return err
   338  }
   339  
   340  func (b *driverPluginServer) SignalTask(ctx context.Context, req *proto.SignalTaskRequest) (*proto.SignalTaskResponse, error) {
   341  	err := b.impl.SignalTask(req.TaskId, req.Signal)
   342  	if err != nil {
   343  		return nil, err
   344  	}
   345  
   346  	resp := &proto.SignalTaskResponse{}
   347  	return resp, nil
   348  }
   349  
   350  func (b *driverPluginServer) TaskEvents(req *proto.TaskEventsRequest, srv proto.Driver_TaskEventsServer) error {
   351  	ch, err := b.impl.TaskEvents(srv.Context())
   352  	if err != nil {
   353  		return err
   354  	}
   355  
   356  	for {
   357  		event := <-ch
   358  		if event == nil {
   359  			break
   360  		}
   361  		pbTimestamp, err := ptypes.TimestampProto(event.Timestamp)
   362  		if err != nil {
   363  			return err
   364  		}
   365  
   366  		pbEvent := &proto.DriverTaskEvent{
   367  			TaskId:      event.TaskID,
   368  			AllocId:     event.AllocID,
   369  			TaskName:    event.TaskName,
   370  			Timestamp:   pbTimestamp,
   371  			Message:     event.Message,
   372  			Annotations: event.Annotations,
   373  		}
   374  
   375  		if err = srv.Send(pbEvent); err == io.EOF {
   376  			break
   377  		} else if err != nil {
   378  			return err
   379  		}
   380  	}
   381  	return nil
   382  }
   383  
   384  func (b *driverPluginServer) CreateNetwork(ctx context.Context, req *proto.CreateNetworkRequest) (*proto.CreateNetworkResponse, error) {
   385  	nm, ok := b.impl.(DriverNetworkManager)
   386  	if !ok {
   387  		return nil, fmt.Errorf("CreateNetwork RPC not supported by driver")
   388  	}
   389  
   390  	spec, created, err := nm.CreateNetwork(req.AllocId)
   391  	if err != nil {
   392  		return nil, err
   393  	}
   394  
   395  	return &proto.CreateNetworkResponse{
   396  		IsolationSpec: NetworkIsolationSpecToProto(spec),
   397  		Created:       created,
   398  	}, nil
   399  
   400  }
   401  
   402  func (b *driverPluginServer) DestroyNetwork(ctx context.Context, req *proto.DestroyNetworkRequest) (*proto.DestroyNetworkResponse, error) {
   403  	nm, ok := b.impl.(DriverNetworkManager)
   404  	if !ok {
   405  		return nil, fmt.Errorf("DestroyNetwork RPC not supported by driver")
   406  	}
   407  
   408  	err := nm.DestroyNetwork(req.AllocId, NetworkIsolationSpecFromProto(req.IsolationSpec))
   409  	if err != nil {
   410  		return nil, err
   411  	}
   412  
   413  	return &proto.DestroyNetworkResponse{}, nil
   414  }