src.elv.sh@v0.21.0-dev.0.20240515223629-06979efb9a2a/pkg/daemon/server.go (about)

     1  // Package daemon implements a service for mediating access to the data store,
     2  // and its client.
     3  //
     4  // Most RPCs exposed by the service correspond to the methods of Store in the
     5  // store package and are not documented here.
     6  package daemon
     7  
     8  import (
     9  	"net"
    10  	"os"
    11  	"os/signal"
    12  	"syscall"
    13  
    14  	"src.elv.sh/pkg/daemon/internal/api"
    15  	"src.elv.sh/pkg/logutil"
    16  	"src.elv.sh/pkg/prog"
    17  	"src.elv.sh/pkg/rpc"
    18  	"src.elv.sh/pkg/store"
    19  )
    20  
    21  var logger = logutil.GetLogger("[daemon] ")
    22  
    23  // Program is the daemon subprogram.
    24  type Program struct {
    25  	run   bool
    26  	paths *prog.DaemonPaths
    27  	// Used in tests.
    28  	serveOpts ServeOpts
    29  }
    30  
    31  func (p *Program) RegisterFlags(fs *prog.FlagSet) {
    32  	fs.BoolVar(&p.run, "daemon", false,
    33  		"[internal flag] Run the storage daemon instead of an Elvish shell")
    34  	p.paths = fs.DaemonPaths()
    35  }
    36  
    37  func (p *Program) Run(fds [3]*os.File, args []string) error {
    38  	if !p.run {
    39  		return prog.NextProgram()
    40  	}
    41  	if len(args) > 0 {
    42  		return prog.BadUsage("arguments are not allowed with -daemon")
    43  	}
    44  
    45  	// The stdout is redirected to a unique log file (see the spawn function),
    46  	// so just use it for logging.
    47  	logutil.SetOutput(fds[1])
    48  	setUmaskForDaemon()
    49  	exit := Serve(p.paths.Sock, p.paths.DB, p.serveOpts)
    50  	return prog.Exit(exit)
    51  }
    52  
    53  // ServeOpts keeps options that can be passed to Serve.
    54  type ServeOpts struct {
    55  	// If not nil, will be closed when the daemon is ready to serve requests.
    56  	Ready chan<- struct{}
    57  	// Causes the daemon to abort if closed or sent any date. If nil, Serve will
    58  	// set up its own signal channel by listening to SIGINT and SIGTERM.
    59  	Signals <-chan os.Signal
    60  	// If not nil, overrides the response of the Version RPC.
    61  	Version *int
    62  }
    63  
    64  // Serve runs the daemon service, listening on the socket specified by sockpath
    65  // and serving data from dbpath until all clients have exited. See doc for
    66  // ServeOpts for additional options.
    67  func Serve(sockpath, dbpath string, opts ServeOpts) int {
    68  	logger.Println("pid is", syscall.Getpid())
    69  	logger.Println("going to listen", sockpath)
    70  	listener, err := net.Listen("unix", sockpath)
    71  	if err != nil {
    72  		logger.Printf("failed to listen on %s: %v", sockpath, err)
    73  		logger.Println("aborting")
    74  		return 2
    75  	}
    76  
    77  	st, err := store.NewStore(dbpath)
    78  	if err != nil {
    79  		logger.Printf("failed to create storage: %v", err)
    80  		logger.Printf("serving anyway")
    81  	}
    82  
    83  	server := rpc.NewServer()
    84  	version := api.Version
    85  	if opts.Version != nil {
    86  		version = *opts.Version
    87  	}
    88  	server.RegisterName(api.ServiceName, &service{version, st, err})
    89  
    90  	connCh := make(chan net.Conn, 10)
    91  	listenErrCh := make(chan error, 1)
    92  	go func() {
    93  		for {
    94  			conn, err := listener.Accept()
    95  			if err != nil {
    96  				listenErrCh <- err
    97  				close(listenErrCh)
    98  				return
    99  			}
   100  			connCh <- conn
   101  		}
   102  	}()
   103  
   104  	sigCh := opts.Signals
   105  	if sigCh == nil {
   106  		ch := make(chan os.Signal, 1)
   107  		signal.Notify(ch, syscall.SIGTERM, syscall.SIGINT)
   108  		sigCh = ch
   109  	}
   110  
   111  	conns := make(map[net.Conn]struct{})
   112  	connDoneCh := make(chan net.Conn, 10)
   113  
   114  	interrupt := func() {
   115  		if len(conns) == 0 {
   116  			logger.Println("exiting since there are no clients")
   117  		}
   118  		logger.Printf("going to close %v active connections", len(conns))
   119  		for conn := range conns {
   120  			// Ignore the error - if we can't close the connection it's because
   121  			// the client has closed it. There is nothing we can do anyway.
   122  			conn.Close()
   123  		}
   124  	}
   125  
   126  	if opts.Ready != nil {
   127  		close(opts.Ready)
   128  	}
   129  
   130  loop:
   131  	for {
   132  		select {
   133  		case sig := <-sigCh:
   134  			logger.Printf("received signal %v", sig)
   135  			interrupt()
   136  			break loop
   137  		case err := <-listenErrCh:
   138  			logger.Println("could not listen:", err)
   139  			if len(conns) == 0 {
   140  				logger.Println("exiting since there are no clients")
   141  				break loop
   142  			}
   143  			logger.Println("continuing to serve until all existing clients exit")
   144  		case conn := <-connCh:
   145  			conns[conn] = struct{}{}
   146  			go func() {
   147  				server.ServeConn(conn)
   148  				connDoneCh <- conn
   149  			}()
   150  		case conn := <-connDoneCh:
   151  			delete(conns, conn)
   152  			if len(conns) == 0 {
   153  				logger.Println("all clients disconnected, exiting")
   154  				break loop
   155  			}
   156  		}
   157  	}
   158  
   159  	err = os.Remove(sockpath)
   160  	if err != nil {
   161  		logger.Printf("failed to remove socket %s: %v", sockpath, err)
   162  	}
   163  	if st != nil {
   164  		err = st.Close()
   165  		if err != nil {
   166  			logger.Printf("failed to close storage: %v", err)
   167  		}
   168  	}
   169  	err = listener.Close()
   170  	if err != nil {
   171  		logger.Printf("failed to close listener: %v", err)
   172  	}
   173  	// Ensure that the listener goroutine has exited before returning
   174  	<-listenErrCh
   175  	return 0
   176  }