github.com/whamcloud/lemur@v0.0.0-20190827193804-4655df8a52af/cmd/lhsmd/transport/grpc/rpc.go (about)

     1  // Copyright (c) 2018 DDN. All rights reserved.
     2  // Use of this source code is governed by a MIT-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package rpc
     6  
     7  import (
     8  	"fmt"
     9  	"net"
    10  	"os"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/pkg/errors"
    15  
    16  	"github.com/intel-hpdd/lemur/cmd/lhsmd/agent"
    17  	pb "github.com/intel-hpdd/lemur/pdm"
    18  	"github.com/intel-hpdd/logging/debug"
    19  	"golang.org/x/net/context"
    20  	"google.golang.org/grpc"
    21  )
    22  
    23  const (
    24  	// TransportType is the name of this transport
    25  	TransportType = "grpc"
    26  	// Connected indicates a connected endpoint
    27  	Connected = EndpointState(iota)
    28  	// Disconnected indicates a disconnected endpoint
    29  	Disconnected
    30  )
    31  
    32  type (
    33  	rpcTransport struct {
    34  		mu     sync.Mutex
    35  		server *grpc.Server
    36  	}
    37  
    38  	dmRPCServer struct {
    39  		stats *messageStats
    40  		agent *agent.HsmAgent
    41  	}
    42  
    43  	// EndpointState represents the connectedness state of an Endpoint
    44  	EndpointState int
    45  
    46  	// AgentEndpoint represents the agent side of a data mover connection
    47  	AgentEndpoint struct {
    48  		state    EndpointState
    49  		actionCh chan *agent.Action
    50  		mu       sync.Mutex
    51  		actions  map[agent.ActionID]*agent.Action
    52  	}
    53  )
    54  
    55  func init() {
    56  	agent.RegisterTransport(TransportType, &rpcTransport{})
    57  }
    58  
    59  func (t *rpcTransport) Init(conf *agent.Config, a *agent.HsmAgent) error {
    60  	if conf.Transport.Type != TransportType {
    61  		return nil
    62  	}
    63  
    64  	debug.Printf("Initializing grpc transport: %s", conf.Transport.ConnectionString())
    65  
    66  	// Ensure path is a directory and create if needed
    67  	if err := os.MkdirAll(conf.Transport.SocketDir, 0755); err != nil {
    68  		return errors.Wrap(err, "MkdirAll")
    69  	}
    70  
    71  	sock, err := net.Listen("unix", conf.Transport.ConnectionString())
    72  	if err != nil {
    73  		return errors.Errorf("Failed to listen: %v", err)
    74  	}
    75  
    76  	t.mu.Lock()
    77  	t.server = grpc.NewServer()
    78  	t.mu.Unlock()
    79  	pb.RegisterDataMoverServer(t.server, newServer(a))
    80  	go t.server.Serve(sock)
    81  
    82  	return nil
    83  }
    84  
    85  func (t *rpcTransport) Shutdown() {
    86  	t.mu.Lock()
    87  	t.server.Stop()
    88  	t.mu.Unlock()
    89  	debug.Print("shut down grpc transport")
    90  }
    91  
    92  // Send delivers an agent action to the backend
    93  func (ep *AgentEndpoint) Send(action *agent.Action) {
    94  	ep.actionCh <- action
    95  }
    96  
    97  // Register a data mover backend (aka Endpoint). When a backend starts, it first must
    98  // identify itself and its archive ID with the agent. The agent returns a unique
    99  // cookie that the backend uses for the rest of that session.
   100  //
   101  // If the Endpoint for this archive id already exists and is Connected, then this means
   102  // this is already a backend receiving messages for this archive, and we reject
   103  // this registration.  If it exists and is Disconnected, then currently the new backend
   104  // takes over this Endpoint. Existing in progress messages should be flushed, however.
   105  func (s *dmRPCServer) Register(context context.Context, e *pb.Endpoint) (*pb.Handle, error) {
   106  	ep, ok := s.agent.Endpoints.Get(e.Archive)
   107  	var handle *agent.Handle
   108  	var err error
   109  	if ok {
   110  		rpcEp, ok := ep.(*AgentEndpoint)
   111  		if !ok {
   112  			debug.Printf("not an rpc endpoint: %#v", ep)
   113  			return nil, errors.Errorf("not an rpc endpoint: %#v", ep)
   114  		}
   115  		if rpcEp.state == Connected {
   116  			debug.Printf("register rejected for  %v already connected", e)
   117  			return nil, errors.New("Archived already connected")
   118  		}
   119  		// TODO: should flush and perhaps even delete the existing Endpoint
   120  		// instead of just reusing it.
   121  		handle, err = s.agent.Endpoints.NewHandle(e.Archive)
   122  		if err != nil {
   123  			return nil, err
   124  		}
   125  	} else {
   126  		handle, err = s.agent.Endpoints.Add(e.Archive, &AgentEndpoint{
   127  			state:    Disconnected,
   128  			actions:  make(map[agent.ActionID]*agent.Action),
   129  			actionCh: make(chan *agent.Action),
   130  		})
   131  		if err != nil {
   132  			return nil, err
   133  		}
   134  	}
   135  	return &pb.Handle{Id: uint64(*handle)}, nil
   136  
   137  }
   138  
   139  // GetActions establish a connection the backend for a particular archive ID. The Endpoint
   140  // remains in Connected status as long as the backend is receiving messages from the agent.
   141  func (s *dmRPCServer) GetActions(h *pb.Handle, stream pb.DataMover_GetActionsServer) error {
   142  	temp, ok := s.agent.Endpoints.GetWithHandle((*agent.Handle)(&h.Id))
   143  	if !ok {
   144  		debug.Printf("bad cookie  %v", h.Id)
   145  		return errors.New("bad cookie")
   146  	}
   147  	ep, ok := temp.(*AgentEndpoint)
   148  	if !ok {
   149  		debug.Printf("not an rpc endpoint: %#v", ep)
   150  		return errors.Errorf("not an rpc endpoint: %#v", ep)
   151  	}
   152  
   153  	// Should use atomic CAS here
   154  	ep.state = Connected
   155  	defer func() {
   156  		debug.Printf("user disconnected %v", h)
   157  		ep.state = Disconnected
   158  		s.agent.Endpoints.RemoveHandle((*agent.Handle)(&h.Id))
   159  	}()
   160  
   161  	for {
   162  		select {
   163  		case <-stream.Context().Done():
   164  			return stream.Context().Err()
   165  		case action := <-ep.actionCh:
   166  			s.stats.Count.Inc(1)
   167  			s.stats.Rate.Mark(1)
   168  
   169  			ep.mu.Lock()
   170  			ep.actions[action.ID()] = action
   171  			ep.mu.Unlock()
   172  
   173  			if err := stream.Send(action.AsMessage()); err != nil {
   174  				debug.Printf("error while sending action: %s", err)
   175  				action.Fail(-1)
   176  
   177  				ep.mu.Lock()
   178  				delete(ep.actions, action.ID())
   179  				ep.mu.Unlock()
   180  
   181  				return errors.Wrap(err, "sending action failed")
   182  			}
   183  		}
   184  	}
   185  }
   186  
   187  // StatusStream provides the server with a stream of replies from the backend.
   188  // The backend includes its cookie in each reply. In theory it's possible for
   189  // replies to arrive for a Disconnected Endpoint, so we'll need proper protection
   190  // from various kinds of races here.
   191  func (s *dmRPCServer) StatusStream(stream pb.DataMover_StatusStreamServer) error {
   192  	for {
   193  		status, err := stream.Recv()
   194  		if err != nil {
   195  			return errors.Wrap(err, "status receive failed")
   196  		}
   197  		temp, ok := s.agent.Endpoints.GetWithHandle((*agent.Handle)(&status.Handle.Id))
   198  		if !ok {
   199  			debug.Printf("bad handle %v", status.Handle)
   200  			return errors.New("bad endpoint handle")
   201  		}
   202  		ep, ok := temp.(*AgentEndpoint)
   203  		if !ok {
   204  			debug.Printf("not an rpc endpoint: %#v", ep)
   205  			return errors.Errorf("not an rpc endpoint: %#v", ep)
   206  		}
   207  
   208  		ep.mu.Lock()
   209  		action, ok := ep.actions[agent.ActionID(status.Id)]
   210  		ep.mu.Unlock()
   211  		if ok {
   212  			completed, err := action.Update(status)
   213  			if completed {
   214  				ep.mu.Lock()
   215  				delete(ep.actions, agent.ActionID(status.Id))
   216  				ep.mu.Unlock()
   217  			} else if err != nil {
   218  				debug.Printf("Status update for 0x%x did not complete: %s", status.Id, err)
   219  				ep.mu.Lock()
   220  				delete(ep.actions, agent.ActionID(status.Id))
   221  				ep.mu.Unlock()
   222  
   223  				// send cancel to mover
   224  			}
   225  		} else {
   226  			debug.Printf("! unknown id: %x", status.Id)
   227  		}
   228  
   229  	}
   230  }
   231  
   232  func (s *dmRPCServer) startStats() {
   233  	go func() {
   234  		for {
   235  			fmt.Println(s.stats)
   236  			time.Sleep(10 * time.Second)
   237  		}
   238  	}()
   239  }
   240  
   241  func newServer(a *agent.HsmAgent) *dmRPCServer {
   242  	srv := &dmRPCServer{
   243  		stats: newMessageStats(),
   244  		agent: a,
   245  	}
   246  
   247  	//	srv.startStats()
   248  
   249  	return srv
   250  }