code.gitea.io/gitea@v1.22.3/modules/graceful/net_unix.go (about)

     1  // Copyright 2019 The Gitea Authors. All rights reserved.
     2  // SPDX-License-Identifier: MIT
     3  
     4  // This code is heavily inspired by the archived gofacebook/gracenet/net.go handler
     5  
     6  //go:build !windows
     7  
     8  package graceful
     9  
    10  import (
    11  	"fmt"
    12  	"net"
    13  	"os"
    14  	"strconv"
    15  	"strings"
    16  	"sync"
    17  	"time"
    18  
    19  	"code.gitea.io/gitea/modules/log"
    20  	"code.gitea.io/gitea/modules/setting"
    21  	"code.gitea.io/gitea/modules/util"
    22  )
    23  
    24  const (
    25  	listenFDsEnv = "LISTEN_FDS"
    26  	startFD      = 3
    27  	unlinkFDsEnv = "GITEA_UNLINK_FDS"
    28  
    29  	notifySocketEnv    = "NOTIFY_SOCKET"
    30  	watchdogTimeoutEnv = "WATCHDOG_USEC"
    31  )
    32  
    33  // In order to keep the working directory the same as when we started we record
    34  // it at startup.
    35  var originalWD, _ = os.Getwd()
    36  
    37  var (
    38  	once  = sync.Once{}
    39  	mutex = sync.Mutex{}
    40  
    41  	providedListenersToUnlink = []bool{}
    42  	activeListenersToUnlink   = []bool{}
    43  	providedListeners         = []net.Listener{}
    44  	activeListeners           = []net.Listener{}
    45  
    46  	notifySocketAddr string
    47  	watchdogTimeout  time.Duration
    48  )
    49  
    50  func getProvidedFDs() (savedErr error) {
    51  	// Only inherit the provided FDS once but we will save the error so that repeated calls to this function will return the same error
    52  	once.Do(func() {
    53  		mutex.Lock()
    54  		defer mutex.Unlock()
    55  		// now handle some additional systemd provided things
    56  		notifySocketAddr = os.Getenv(notifySocketEnv)
    57  		if notifySocketAddr != "" {
    58  			log.Debug("Systemd Notify Socket provided: %s", notifySocketAddr)
    59  			savedErr = os.Unsetenv(notifySocketEnv)
    60  			if savedErr != nil {
    61  				log.Warn("Unable to Unset the NOTIFY_SOCKET environment variable: %v", savedErr)
    62  				return
    63  			}
    64  			// FIXME: We don't handle WATCHDOG_PID
    65  			timeoutStr := os.Getenv(watchdogTimeoutEnv)
    66  			if timeoutStr != "" {
    67  				savedErr = os.Unsetenv(watchdogTimeoutEnv)
    68  				if savedErr != nil {
    69  					log.Warn("Unable to Unset the WATCHDOG_USEC environment variable: %v", savedErr)
    70  					return
    71  				}
    72  
    73  				s, err := strconv.ParseInt(timeoutStr, 10, 64)
    74  				if err != nil {
    75  					log.Error("Unable to parse the provided WATCHDOG_USEC: %v", err)
    76  					savedErr = fmt.Errorf("unable to parse the provided WATCHDOG_USEC: %w", err)
    77  					return
    78  				}
    79  				if s <= 0 {
    80  					log.Error("Unable to parse the provided WATCHDOG_USEC: %s should be a positive number", timeoutStr)
    81  					savedErr = fmt.Errorf("unable to parse the provided WATCHDOG_USEC: %s should be a positive number", timeoutStr)
    82  					return
    83  				}
    84  				watchdogTimeout = time.Duration(s) * time.Microsecond
    85  			}
    86  		} else {
    87  			log.Trace("No Systemd Notify Socket provided")
    88  		}
    89  
    90  		numFDs := os.Getenv(listenFDsEnv)
    91  		if numFDs == "" {
    92  			return
    93  		}
    94  		n, err := strconv.Atoi(numFDs)
    95  		if err != nil {
    96  			savedErr = fmt.Errorf("%s is not a number: %s. Err: %w", listenFDsEnv, numFDs, err)
    97  			return
    98  		}
    99  
   100  		fdsToUnlinkStr := strings.Split(os.Getenv(unlinkFDsEnv), ",")
   101  		providedListenersToUnlink = make([]bool, n)
   102  		for _, fdStr := range fdsToUnlinkStr {
   103  			i, err := strconv.Atoi(fdStr)
   104  			if err != nil || i < 0 || i >= n {
   105  				continue
   106  			}
   107  			providedListenersToUnlink[i] = true
   108  		}
   109  
   110  		for i := startFD; i < n+startFD; i++ {
   111  			file := os.NewFile(uintptr(i), fmt.Sprintf("listener_FD%d", i))
   112  
   113  			l, err := net.FileListener(file)
   114  			if err == nil {
   115  				// Close the inherited file if it's a listener
   116  				if err = file.Close(); err != nil {
   117  					savedErr = fmt.Errorf("error closing provided socket fd %d: %w", i, err)
   118  					return
   119  				}
   120  				providedListeners = append(providedListeners, l)
   121  				continue
   122  			}
   123  
   124  			// If needed we can handle packetconns here.
   125  			savedErr = fmt.Errorf("Error getting provided socket fd %d: %w", i, err)
   126  			return
   127  		}
   128  	})
   129  	return savedErr
   130  }
   131  
   132  // closeProvidedListeners closes all unused provided listeners.
   133  func closeProvidedListeners() {
   134  	mutex.Lock()
   135  	defer mutex.Unlock()
   136  	for _, l := range providedListeners {
   137  		err := l.Close()
   138  		if err != nil {
   139  			log.Error("Error in closing unused provided listener: %v", err)
   140  		}
   141  	}
   142  	providedListeners = []net.Listener{}
   143  }
   144  
   145  // DefaultGetListener obtains a listener for the stream-oriented local network address:
   146  // "tcp", "tcp4", "tcp6", "unix" or "unixpacket".
   147  func DefaultGetListener(network, address string) (net.Listener, error) {
   148  	// Add a deferral to say that we've tried to grab a listener
   149  	defer GetManager().InformCleanup()
   150  	switch network {
   151  	case "tcp", "tcp4", "tcp6":
   152  		tcpAddr, err := net.ResolveTCPAddr(network, address)
   153  		if err != nil {
   154  			return nil, err
   155  		}
   156  		return GetListenerTCP(network, tcpAddr)
   157  	case "unix", "unixpacket":
   158  		unixAddr, err := net.ResolveUnixAddr(network, address)
   159  		if err != nil {
   160  			return nil, err
   161  		}
   162  		return GetListenerUnix(network, unixAddr)
   163  	default:
   164  		return nil, net.UnknownNetworkError(network)
   165  	}
   166  }
   167  
   168  // GetListenerTCP announces on the local network address. The network must be:
   169  // "tcp", "tcp4" or "tcp6". It returns a provided net.Listener for the
   170  // matching network and address, or creates a new one using net.ListenTCP.
   171  func GetListenerTCP(network string, address *net.TCPAddr) (*net.TCPListener, error) {
   172  	if err := getProvidedFDs(); err != nil {
   173  		return nil, err
   174  	}
   175  
   176  	mutex.Lock()
   177  	defer mutex.Unlock()
   178  
   179  	// look for a provided listener
   180  	for i, l := range providedListeners {
   181  		if isSameAddr(l.Addr(), address) {
   182  			providedListeners = append(providedListeners[:i], providedListeners[i+1:]...)
   183  			needsUnlink := providedListenersToUnlink[i]
   184  			providedListenersToUnlink = append(providedListenersToUnlink[:i], providedListenersToUnlink[i+1:]...)
   185  
   186  			activeListeners = append(activeListeners, l)
   187  			activeListenersToUnlink = append(activeListenersToUnlink, needsUnlink)
   188  			return l.(*net.TCPListener), nil
   189  		}
   190  	}
   191  
   192  	// no provided listener for this address -> make a fresh listener
   193  	l, err := net.ListenTCP(network, address)
   194  	if err != nil {
   195  		return nil, err
   196  	}
   197  	activeListeners = append(activeListeners, l)
   198  	activeListenersToUnlink = append(activeListenersToUnlink, false)
   199  	return l, nil
   200  }
   201  
   202  // GetListenerUnix announces on the local network address. The network must be:
   203  // "unix" or "unixpacket". It returns a provided net.Listener for the
   204  // matching network and address, or creates a new one using net.ListenUnix.
   205  func GetListenerUnix(network string, address *net.UnixAddr) (*net.UnixListener, error) {
   206  	if err := getProvidedFDs(); err != nil {
   207  		return nil, err
   208  	}
   209  
   210  	mutex.Lock()
   211  	defer mutex.Unlock()
   212  
   213  	// look for a provided listener
   214  	for i, l := range providedListeners {
   215  		if isSameAddr(l.Addr(), address) {
   216  			providedListeners = append(providedListeners[:i], providedListeners[i+1:]...)
   217  			needsUnlink := providedListenersToUnlink[i]
   218  			providedListenersToUnlink = append(providedListenersToUnlink[:i], providedListenersToUnlink[i+1:]...)
   219  
   220  			activeListenersToUnlink = append(activeListenersToUnlink, needsUnlink)
   221  			activeListeners = append(activeListeners, l)
   222  			unixListener := l.(*net.UnixListener)
   223  			if needsUnlink {
   224  				unixListener.SetUnlinkOnClose(true)
   225  			}
   226  			return unixListener, nil
   227  		}
   228  	}
   229  
   230  	// make a fresh listener
   231  	if err := util.Remove(address.Name); err != nil && !os.IsNotExist(err) {
   232  		return nil, fmt.Errorf("Failed to remove unix socket %s: %w", address.Name, err)
   233  	}
   234  
   235  	l, err := net.ListenUnix(network, address)
   236  	if err != nil {
   237  		return nil, err
   238  	}
   239  
   240  	fileMode := os.FileMode(setting.UnixSocketPermission)
   241  	if err = os.Chmod(address.Name, fileMode); err != nil {
   242  		return nil, fmt.Errorf("Failed to set permission of unix socket to %s: %w", fileMode.String(), err)
   243  	}
   244  
   245  	activeListeners = append(activeListeners, l)
   246  	activeListenersToUnlink = append(activeListenersToUnlink, true)
   247  	return l, nil
   248  }
   249  
   250  func isSameAddr(a1, a2 net.Addr) bool {
   251  	// If the addresses are not on the same network fail.
   252  	if a1.Network() != a2.Network() {
   253  		return false
   254  	}
   255  
   256  	// If the two addresses have the same string representation they're equal
   257  	a1s := a1.String()
   258  	a2s := a2.String()
   259  	if a1s == a2s {
   260  		return true
   261  	}
   262  
   263  	// This allows for ipv6 vs ipv4 local addresses to compare as equal. This
   264  	// scenario is common when listening on localhost.
   265  	const ipv6prefix = "[::]"
   266  	a1s = strings.TrimPrefix(a1s, ipv6prefix)
   267  	a2s = strings.TrimPrefix(a2s, ipv6prefix)
   268  	const ipv4prefix = "0.0.0.0"
   269  	a1s = strings.TrimPrefix(a1s, ipv4prefix)
   270  	a2s = strings.TrimPrefix(a2s, ipv4prefix)
   271  	return a1s == a2s
   272  }
   273  
   274  func getActiveListeners() []net.Listener {
   275  	mutex.Lock()
   276  	defer mutex.Unlock()
   277  	listeners := make([]net.Listener, len(activeListeners))
   278  	copy(listeners, activeListeners)
   279  	return listeners
   280  }
   281  
   282  func getActiveListenersToUnlink() []bool {
   283  	mutex.Lock()
   284  	defer mutex.Unlock()
   285  	listenersToUnlink := make([]bool, len(activeListenersToUnlink))
   286  	copy(listenersToUnlink, activeListenersToUnlink)
   287  	return listenersToUnlink
   288  }
   289  
   290  func getNotifySocket() (*net.UnixConn, error) {
   291  	if err := getProvidedFDs(); err != nil {
   292  		// This error will be logged elsewhere
   293  		return nil, nil
   294  	}
   295  
   296  	if notifySocketAddr == "" {
   297  		return nil, nil
   298  	}
   299  
   300  	socketAddr := &net.UnixAddr{
   301  		Name: notifySocketAddr,
   302  		Net:  "unixgram",
   303  	}
   304  
   305  	notifySocket, err := net.DialUnix(socketAddr.Net, nil, socketAddr)
   306  	if err != nil {
   307  		log.Warn("failed to dial NOTIFY_SOCKET %s: %v", socketAddr, err)
   308  		return nil, err
   309  	}
   310  
   311  	return notifySocket, nil
   312  }
   313  
   314  func getWatchdogTimeout() time.Duration {
   315  	if err := getProvidedFDs(); err != nil {
   316  		// This error will be logged elsewhere
   317  		return 0
   318  	}
   319  
   320  	return watchdogTimeout
   321  }