github.com/cs3org/reva/v2@v2.27.7/cmd/revad/internal/grace/grace.go (about)

     1  // Copyright 2018-2021 CERN
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  //
    15  // In applying this license, CERN does not waive the privileges and immunities
    16  // granted to it by virtue of its status as an Intergovernmental Organization
    17  // or submit itself to any jurisdiction.
    18  
    19  package grace
    20  
    21  import (
    22  	"fmt"
    23  	"net"
    24  	"os"
    25  	"os/signal"
    26  	"path/filepath"
    27  	"strconv"
    28  	"strings"
    29  	"syscall"
    30  	"time"
    31  
    32  	"github.com/pkg/errors"
    33  	"github.com/rs/zerolog"
    34  )
    35  
    36  // Watcher watches a process for a graceful restart
    37  // preserving open network sockets to avoid packet loss.
    38  type Watcher struct {
    39  	log                     zerolog.Logger
    40  	graceful                bool
    41  	ppid                    int
    42  	lns                     map[string]net.Listener
    43  	ss                      map[string]Server
    44  	pidFile                 string
    45  	childPIDs               []int
    46  	gracefulShutdownTimeout int
    47  }
    48  
    49  // Option represent an option.
    50  type Option func(w *Watcher)
    51  
    52  // WithLogger adds a logger to the Watcher.
    53  func WithLogger(l zerolog.Logger) Option {
    54  	return func(w *Watcher) {
    55  		w.log = l
    56  	}
    57  }
    58  
    59  // WithPIDFile specifies the pid file to use.
    60  func WithPIDFile(fn string) Option {
    61  	return func(w *Watcher) {
    62  		w.pidFile = fn
    63  	}
    64  }
    65  
    66  func WithGracefuleShutdownTimeout(seconds int) Option {
    67  	return func(w *Watcher) {
    68  		w.gracefulShutdownTimeout = seconds
    69  	}
    70  }
    71  
    72  // NewWatcher creates a Watcher.
    73  func NewWatcher(opts ...Option) *Watcher {
    74  	w := &Watcher{
    75  		log:      zerolog.Nop(),
    76  		graceful: os.Getenv("GRACEFUL") == "true",
    77  		ppid:     os.Getppid(),
    78  		ss:       map[string]Server{},
    79  	}
    80  
    81  	for _, opt := range opts {
    82  		opt(w)
    83  	}
    84  
    85  	return w
    86  }
    87  
    88  // Exit exits the current process cleaning up
    89  // existing pid files.
    90  func (w *Watcher) Exit(errc int) {
    91  	err := w.clean()
    92  	if err != nil {
    93  		w.log.Warn().Err(err).Msg("error removing pid file")
    94  	} else {
    95  		w.log.Info().Msgf("pid file %q got removed", w.pidFile)
    96  	}
    97  	os.Exit(errc)
    98  }
    99  
   100  func (w *Watcher) clean() error {
   101  	// only remove PID file if the PID has been written by us
   102  	filePID, err := w.readPID()
   103  	if err != nil {
   104  		return err
   105  	}
   106  
   107  	if filePID != os.Getpid() {
   108  		// the pidfile may have been changed by a forked child
   109  		// TODO(labkode): is there a way to list children pids for current process?
   110  		return fmt.Errorf("pid:%d in pidfile is different from pid:%d, it can be a leftover from a hard shutdown or that a reload was triggered", filePID, os.Getpid())
   111  	}
   112  
   113  	return os.Remove(w.pidFile)
   114  }
   115  
   116  func (w *Watcher) readPID() (int, error) {
   117  	piddata, err := os.ReadFile(w.pidFile)
   118  	if err != nil {
   119  		return 0, err
   120  	}
   121  	// Convert the file contents to an integer.
   122  	pid, err := strconv.Atoi(string(piddata))
   123  	if err != nil {
   124  		return 0, err
   125  	}
   126  	return pid, nil
   127  }
   128  
   129  // GetProcessFromFile reads the pidfile and returns the running process or error if the process or file
   130  // are not available.
   131  func GetProcessFromFile(pfile string) (*os.Process, error) {
   132  	data, err := os.ReadFile(pfile)
   133  	if err != nil {
   134  		return nil, err
   135  	}
   136  
   137  	pid, err := strconv.Atoi(string(data))
   138  	if err != nil {
   139  		return nil, err
   140  	}
   141  
   142  	process, err := os.FindProcess(pid)
   143  	if err != nil {
   144  		return nil, err
   145  	}
   146  
   147  	return process, nil
   148  }
   149  
   150  // WritePID writes the pid to the configured pid file.
   151  func (w *Watcher) WritePID() error {
   152  	// Read in the pid file as a slice of bytes.
   153  	if piddata, err := os.ReadFile(w.pidFile); err == nil {
   154  		// Convert the file contents to an integer.
   155  		if pid, err := strconv.Atoi(string(piddata)); err == nil {
   156  			// Look for the pid in the process list.
   157  			if process, err := os.FindProcess(pid); err == nil {
   158  				// Send the process a signal zero kill.
   159  				if err := process.Signal(syscall.Signal(0)); err == nil {
   160  					if !w.graceful {
   161  						return fmt.Errorf("pid already running: %d", pid)
   162  					}
   163  
   164  					if pid != w.ppid { // overwrite only if parent pid is pidfile
   165  						// We only get an error if the pid isn't running, or it's not ours.
   166  						return fmt.Errorf("pid %d is not this process parent", pid)
   167  					}
   168  				} else {
   169  					w.log.Warn().Err(err).Msg("error sending zero kill signal to current process")
   170  				}
   171  			} else {
   172  				w.log.Warn().Msgf("pid:%d not found", pid)
   173  			}
   174  		} else {
   175  			w.log.Warn().Msg("error casting contents of pidfile to pid(int)")
   176  		}
   177  	} // else {
   178  	// w.log.Info().Msg("error reading pidfile")
   179  	//}
   180  
   181  	// If we get here, then the pidfile didn't exist or we are are in graceful reload and thus we overwrite
   182  	// or the pid in it doesn't belong to the user running this app.
   183  	err := os.WriteFile(w.pidFile, []byte(fmt.Sprintf("%d", os.Getpid())), 0664)
   184  	if err != nil {
   185  		return err
   186  	}
   187  	w.log.Info().Msgf("pidfile saved at: %s", w.pidFile)
   188  	return nil
   189  }
   190  
   191  func newListener(network, addr string) (net.Listener, error) {
   192  	return net.Listen(network, addr)
   193  }
   194  
   195  // GetListeners return grpc listener first and http listener second.
   196  func (w *Watcher) GetListeners(servers map[string]Server) (map[string]net.Listener, error) {
   197  	w.ss = servers
   198  	lns := map[string]net.Listener{}
   199  	if w.graceful {
   200  		w.log.Info().Msg("graceful restart, inheriting parent ln fds for grpc and http")
   201  		count := 3
   202  		for k, s := range servers {
   203  			network, addr := s.Network(), s.Address()
   204  			fd := os.NewFile(uintptr(count), "") // 3 because ExtraFile passed to new process
   205  			count++
   206  			ln, err := net.FileListener(fd)
   207  			if err != nil {
   208  				w.log.Error().Err(err).Msg("error creating net.Listener from fd")
   209  				// create new fd
   210  				ln, err := newListener(network, addr)
   211  				if err != nil {
   212  					return nil, err
   213  				}
   214  				lns[k] = ln
   215  			} else {
   216  				lns[k] = ln
   217  			}
   218  
   219  		}
   220  
   221  		// kill parent
   222  		// TODO(labkode): maybe race condition here?
   223  		// What do we do if we cannot kill the parent but we have valid fds?
   224  		// Do we abort running the forked child? Probably yes, as if the parent cannot be
   225  		// killed that means we run two version of the code indefinitely.
   226  		w.log.Info().Msgf("killing parent pid gracefully with SIGQUIT: %d", w.ppid)
   227  		p, err := os.FindProcess(w.ppid)
   228  		if err != nil {
   229  			w.log.Error().Err(err).Msgf("error finding parent process with ppid:%d", w.ppid)
   230  			err = errors.Wrap(err, "error finding parent process")
   231  			return nil, err
   232  		}
   233  		err = p.Kill()
   234  		if err != nil {
   235  			w.log.Error().Err(err).Msgf("error killing parent process with ppid:%d", w.ppid)
   236  			err = errors.Wrap(err, "error killing parent process")
   237  			return nil, err
   238  		}
   239  		w.lns = lns
   240  		return lns, nil
   241  	}
   242  
   243  	// create two listeners for grpc and http
   244  	for k, s := range servers {
   245  		network, addr := s.Network(), s.Address()
   246  		ln, err := newListener(network, addr)
   247  		if err != nil {
   248  			return nil, err
   249  		}
   250  		lns[k] = ln
   251  
   252  	}
   253  	w.lns = lns
   254  	return lns, nil
   255  }
   256  
   257  // Server is the interface that servers like HTTP or gRPC
   258  // servers need to implement.
   259  type Server interface {
   260  	Stop() error
   261  	GracefulStop() error
   262  	Network() string
   263  	Address() string
   264  }
   265  
   266  // TrapSignals captures the OS signal.
   267  func (w *Watcher) TrapSignals() {
   268  	signalCh := make(chan os.Signal, 1024)
   269  	signal.Notify(signalCh, syscall.SIGHUP, syscall.SIGINT, syscall.SIGQUIT)
   270  	for {
   271  		s := <-signalCh
   272  		w.log.Info().Msgf("%v signal received", s)
   273  
   274  		switch s {
   275  		case syscall.SIGHUP:
   276  			w.log.Info().Msg("preparing for a hot-reload, forking child process...")
   277  
   278  			// Fork a child process.
   279  			listeners := w.lns
   280  			p, err := forkChild(listeners)
   281  			if err != nil {
   282  				w.log.Error().Err(err).Msgf("unable to fork child process")
   283  			} else {
   284  				w.log.Info().Msgf("child forked with new pid %d", p.Pid)
   285  				w.childPIDs = append(w.childPIDs, p.Pid)
   286  			}
   287  
   288  		case syscall.SIGQUIT:
   289  			gracefulShutdown(w)
   290  		case syscall.SIGINT, syscall.SIGTERM:
   291  			if w.gracefulShutdownTimeout == 0 {
   292  				hardShutdown(w)
   293  			}
   294  			gracefulShutdown(w)
   295  		}
   296  	}
   297  }
   298  
   299  // TODO: Ideally this would call exit() but properly return an error. The
   300  // exit() is problematic (i.e. racey) especiaily when orchestrating multiple
   301  // reva services from some external runtime (like in the "ocis server" case
   302  func gracefulShutdown(w *Watcher) {
   303  	w.log.Info().Int("Timeout", w.gracefulShutdownTimeout).Msg("preparing for a graceful shutdown with deadline")
   304  	go func() {
   305  		count := w.gracefulShutdownTimeout
   306  		ticker := time.NewTicker(time.Second)
   307  		for ; true; <-ticker.C {
   308  			w.log.Info().Msgf("shutting down in %d seconds", count-1)
   309  			count--
   310  			if count <= 0 {
   311  				w.log.Info().Msg("deadline reached before draining active conns, hard stopping ...")
   312  				for _, s := range w.ss {
   313  					err := s.Stop()
   314  					if err != nil {
   315  						w.log.Error().Err(err).Msg("error stopping server")
   316  					}
   317  					w.log.Info().Msgf("fd to %s:%s abruptly closed", s.Network(), s.Address())
   318  				}
   319  				w.Exit(1)
   320  			}
   321  		}
   322  	}()
   323  	for _, s := range w.ss {
   324  		w.log.Info().Msgf("fd to %s:%s gracefully closed ", s.Network(), s.Address())
   325  		err := s.GracefulStop()
   326  		if err != nil {
   327  			w.log.Error().Err(err).Msg("error stopping server")
   328  			w.log.Info().Msg("exit with error code 1")
   329  
   330  			w.Exit(1)
   331  		}
   332  	}
   333  	w.log.Info().Msg("exit with error code 0")
   334  	w.Exit(0)
   335  }
   336  
   337  // TODO: Ideally this would call exit() but properly return an error. The
   338  // exit() is problematic (i.e. racey) especiaily when orchestrating multiple
   339  // reva services from some external runtime (like in the "ocis server" case
   340  func hardShutdown(w *Watcher) {
   341  	w.log.Info().Msg("preparing for hard shutdown, aborting all conns")
   342  	for _, s := range w.ss {
   343  		w.log.Info().Msgf("fd to %s:%s abruptly closed", s.Network(), s.Address())
   344  		err := s.Stop()
   345  		if err != nil {
   346  			w.log.Error().Err(err).Msg("error stopping server")
   347  		}
   348  	}
   349  	w.Exit(0)
   350  }
   351  
   352  func getListenerFile(ln net.Listener) (*os.File, error) {
   353  	switch t := ln.(type) {
   354  	case *net.TCPListener:
   355  		return t.File()
   356  	case *net.UnixListener:
   357  		return t.File()
   358  	}
   359  	return nil, fmt.Errorf("unsupported listener: %T", ln)
   360  }
   361  
   362  func forkChild(lns map[string]net.Listener) (*os.Process, error) {
   363  	// Get the file descriptor for the listener and marshal the metadata to pass
   364  	// to the child in the environment.
   365  	fds := map[string]*os.File{}
   366  	for k, ln := range lns {
   367  		fd, err := getListenerFile(ln)
   368  		if err != nil {
   369  			return nil, err
   370  		}
   371  		fds[k] = fd
   372  	}
   373  
   374  	// Pass stdin, stdout, and stderr along with the listener file to the child
   375  	files := []*os.File{
   376  		os.Stdin,
   377  		os.Stdout,
   378  		os.Stderr,
   379  	}
   380  
   381  	// Get current environment and add in the listener to it.
   382  	environment := append(os.Environ(), "GRACEFUL=true")
   383  	var counter = 3
   384  	for k, fd := range fds {
   385  		k = strings.ToUpper(k)
   386  		environment = append(environment, k+"FD="+fmt.Sprintf("%d", counter))
   387  		files = append(files, fd)
   388  		counter++
   389  	}
   390  
   391  	// Get current process name and directory.
   392  	execName, err := os.Executable()
   393  	if err != nil {
   394  		return nil, err
   395  	}
   396  	execDir := filepath.Dir(execName)
   397  
   398  	// Spawn child process.
   399  	p, err := os.StartProcess(execName, os.Args, &os.ProcAttr{
   400  		Dir:   execDir,
   401  		Env:   environment,
   402  		Files: files,
   403  		Sys:   &syscall.SysProcAttr{},
   404  	})
   405  
   406  	// TODO(labkode): if the process dies (because config changed and is wrong
   407  	// we need to return an error
   408  	if err != nil {
   409  		return nil, err
   410  	}
   411  
   412  	return p, nil
   413  }