// Source: code.gitea.io/gitea@v1.19.3/modules/graceful/net_unix.go

     1  // Copyright 2019 The Gitea Authors. All rights reserved.
     2  // SPDX-License-Identifier: MIT
     3  
     4  // This code is heavily inspired by the archived gofacebook/gracenet/net.go handler
     5  
     6  //go:build !windows
     7  
     8  package graceful
     9  
    10  import (
    11  	"fmt"
    12  	"net"
    13  	"os"
    14  	"strconv"
    15  	"strings"
    16  	"sync"
    17  
    18  	"code.gitea.io/gitea/modules/log"
    19  	"code.gitea.io/gitea/modules/setting"
    20  	"code.gitea.io/gitea/modules/util"
    21  )
    22  
    23  const (
    24  	listenFDs = "LISTEN_FDS"
    25  	startFD   = 3
    26  	unlinkFDs = "GITEA_UNLINK_FDS"
    27  )
    28  
    29  // In order to keep the working directory the same as when we started we record
    30  // it at startup.
    31  var originalWD, _ = os.Getwd()
    32  
    33  var (
    34  	once  = sync.Once{}
    35  	mutex = sync.Mutex{}
    36  
    37  	providedListenersToUnlink = []bool{}
    38  	activeListenersToUnlink   = []bool{}
    39  	providedListeners         = []net.Listener{}
    40  	activeListeners           = []net.Listener{}
    41  )
    42  
    43  func getProvidedFDs() (savedErr error) {
    44  	// Only inherit the provided FDS once but we will save the error so that repeated calls to this function will return the same error
    45  	once.Do(func() {
    46  		mutex.Lock()
    47  		defer mutex.Unlock()
    48  
    49  		numFDs := os.Getenv(listenFDs)
    50  		if numFDs == "" {
    51  			return
    52  		}
    53  		n, err := strconv.Atoi(numFDs)
    54  		if err != nil {
    55  			savedErr = fmt.Errorf("%s is not a number: %s. Err: %w", listenFDs, numFDs, err)
    56  			return
    57  		}
    58  
    59  		fdsToUnlinkStr := strings.Split(os.Getenv(unlinkFDs), ",")
    60  		providedListenersToUnlink = make([]bool, n)
    61  		for _, fdStr := range fdsToUnlinkStr {
    62  			i, err := strconv.Atoi(fdStr)
    63  			if err != nil || i < 0 || i >= n {
    64  				continue
    65  			}
    66  			providedListenersToUnlink[i] = true
    67  		}
    68  
    69  		for i := startFD; i < n+startFD; i++ {
    70  			file := os.NewFile(uintptr(i), fmt.Sprintf("listener_FD%d", i))
    71  
    72  			l, err := net.FileListener(file)
    73  			if err == nil {
    74  				// Close the inherited file if it's a listener
    75  				if err = file.Close(); err != nil {
    76  					savedErr = fmt.Errorf("error closing provided socket fd %d: %s", i, err)
    77  					return
    78  				}
    79  				providedListeners = append(providedListeners, l)
    80  				continue
    81  			}
    82  
    83  			// If needed we can handle packetconns here.
    84  			savedErr = fmt.Errorf("Error getting provided socket fd %d: %w", i, err)
    85  			return
    86  		}
    87  	})
    88  	return savedErr
    89  }
    90  
    91  // CloseProvidedListeners closes all unused provided listeners.
    92  func CloseProvidedListeners() error {
    93  	mutex.Lock()
    94  	defer mutex.Unlock()
    95  	var returnableError error
    96  	for _, l := range providedListeners {
    97  		err := l.Close()
    98  		if err != nil {
    99  			log.Error("Error in closing unused provided listener: %v", err)
   100  			if returnableError != nil {
   101  				returnableError = fmt.Errorf("%v & %w", returnableError, err)
   102  			} else {
   103  				returnableError = err
   104  			}
   105  		}
   106  	}
   107  	providedListeners = []net.Listener{}
   108  
   109  	return returnableError
   110  }
   111  
   112  // GetListener obtains a listener for the local network address. The network must be
   113  // a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket". It
   114  // returns an provided net.Listener for the matching network and address, or
   115  // creates a new one using net.Listen.
   116  func GetListener(network, address string) (net.Listener, error) {
   117  	// Add a deferral to say that we've tried to grab a listener
   118  	defer GetManager().InformCleanup()
   119  	switch network {
   120  	case "tcp", "tcp4", "tcp6":
   121  		tcpAddr, err := net.ResolveTCPAddr(network, address)
   122  		if err != nil {
   123  			return nil, err
   124  		}
   125  		return GetListenerTCP(network, tcpAddr)
   126  	case "unix", "unixpacket":
   127  		unixAddr, err := net.ResolveUnixAddr(network, address)
   128  		if err != nil {
   129  			return nil, err
   130  		}
   131  		return GetListenerUnix(network, unixAddr)
   132  	default:
   133  		return nil, net.UnknownNetworkError(network)
   134  	}
   135  }
   136  
   137  // GetListenerTCP announces on the local network address. The network must be:
   138  // "tcp", "tcp4" or "tcp6". It returns a provided net.Listener for the
   139  // matching network and address, or creates a new one using net.ListenTCP.
   140  func GetListenerTCP(network string, address *net.TCPAddr) (*net.TCPListener, error) {
   141  	if err := getProvidedFDs(); err != nil {
   142  		return nil, err
   143  	}
   144  
   145  	mutex.Lock()
   146  	defer mutex.Unlock()
   147  
   148  	// look for a provided listener
   149  	for i, l := range providedListeners {
   150  		if isSameAddr(l.Addr(), address) {
   151  			providedListeners = append(providedListeners[:i], providedListeners[i+1:]...)
   152  			needsUnlink := providedListenersToUnlink[i]
   153  			providedListenersToUnlink = append(providedListenersToUnlink[:i], providedListenersToUnlink[i+1:]...)
   154  
   155  			activeListeners = append(activeListeners, l)
   156  			activeListenersToUnlink = append(activeListenersToUnlink, needsUnlink)
   157  			return l.(*net.TCPListener), nil
   158  		}
   159  	}
   160  
   161  	// no provided listener for this address -> make a fresh listener
   162  	l, err := net.ListenTCP(network, address)
   163  	if err != nil {
   164  		return nil, err
   165  	}
   166  	activeListeners = append(activeListeners, l)
   167  	activeListenersToUnlink = append(activeListenersToUnlink, false)
   168  	return l, nil
   169  }
   170  
   171  // GetListenerUnix announces on the local network address. The network must be:
   172  // "unix" or "unixpacket". It returns a provided net.Listener for the
   173  // matching network and address, or creates a new one using net.ListenUnix.
   174  func GetListenerUnix(network string, address *net.UnixAddr) (*net.UnixListener, error) {
   175  	if err := getProvidedFDs(); err != nil {
   176  		return nil, err
   177  	}
   178  
   179  	mutex.Lock()
   180  	defer mutex.Unlock()
   181  
   182  	// look for a provided listener
   183  	for i, l := range providedListeners {
   184  		if isSameAddr(l.Addr(), address) {
   185  			providedListeners = append(providedListeners[:i], providedListeners[i+1:]...)
   186  			needsUnlink := providedListenersToUnlink[i]
   187  			providedListenersToUnlink = append(providedListenersToUnlink[:i], providedListenersToUnlink[i+1:]...)
   188  
   189  			activeListenersToUnlink = append(activeListenersToUnlink, needsUnlink)
   190  			activeListeners = append(activeListeners, l)
   191  			unixListener := l.(*net.UnixListener)
   192  			if needsUnlink {
   193  				unixListener.SetUnlinkOnClose(true)
   194  			}
   195  			return unixListener, nil
   196  		}
   197  	}
   198  
   199  	// make a fresh listener
   200  	if err := util.Remove(address.Name); err != nil && !os.IsNotExist(err) {
   201  		return nil, fmt.Errorf("Failed to remove unix socket %s: %w", address.Name, err)
   202  	}
   203  
   204  	l, err := net.ListenUnix(network, address)
   205  	if err != nil {
   206  		return nil, err
   207  	}
   208  
   209  	fileMode := os.FileMode(setting.UnixSocketPermission)
   210  	if err = os.Chmod(address.Name, fileMode); err != nil {
   211  		return nil, fmt.Errorf("Failed to set permission of unix socket to %s: %w", fileMode.String(), err)
   212  	}
   213  
   214  	activeListeners = append(activeListeners, l)
   215  	activeListenersToUnlink = append(activeListenersToUnlink, true)
   216  	return l, nil
   217  }
   218  
   219  func isSameAddr(a1, a2 net.Addr) bool {
   220  	// If the addresses are not on the same network fail.
   221  	if a1.Network() != a2.Network() {
   222  		return false
   223  	}
   224  
   225  	// If the two addresses have the same string representation they're equal
   226  	a1s := a1.String()
   227  	a2s := a2.String()
   228  	if a1s == a2s {
   229  		return true
   230  	}
   231  
   232  	// This allows for ipv6 vs ipv4 local addresses to compare as equal. This
   233  	// scenario is common when listening on localhost.
   234  	const ipv6prefix = "[::]"
   235  	a1s = strings.TrimPrefix(a1s, ipv6prefix)
   236  	a2s = strings.TrimPrefix(a2s, ipv6prefix)
   237  	const ipv4prefix = "0.0.0.0"
   238  	a1s = strings.TrimPrefix(a1s, ipv4prefix)
   239  	a2s = strings.TrimPrefix(a2s, ipv4prefix)
   240  	return a1s == a2s
   241  }
   242  
   243  func getActiveListeners() []net.Listener {
   244  	mutex.Lock()
   245  	defer mutex.Unlock()
   246  	listeners := make([]net.Listener, len(activeListeners))
   247  	copy(listeners, activeListeners)
   248  	return listeners
   249  }
   250  
   251  func getActiveListenersToUnlink() []bool {
   252  	mutex.Lock()
   253  	defer mutex.Unlock()
   254  	listenersToUnlink := make([]bool, len(activeListenersToUnlink))
   255  	copy(listenersToUnlink, activeListenersToUnlink)
   256  	return listenersToUnlink
   257  }