github.com/aristanetworks/goarista@v0.0.0-20240514173732-cca2755bbd44/netns/nslistener.go (about)

     1  // Copyright (c) 2019 Arista Networks, Inc.
     2  // Use of this source code is governed by the Apache License 2.0
     3  // that can be found in the COPYING file.
     4  
     5  package netns
     6  
     7  import (
     8  	"bufio"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"net"
    13  	"os"
    14  	"path/filepath"
    15  	"strings"
    16  	"time"
    17  
    18  	"github.com/aristanetworks/fsnotify"
    19  	"github.com/aristanetworks/goarista/logger"
    20  )
    21  
    22  // ListenerCreator is the signature of a function which creates a listener,
    23  // for use in functions where custom listeners can be generated
    24  type ListenerCreator func() (net.Listener, error)
    25  
    26  var makeListener = func(nsName string, listenerCreator ListenerCreator) (net.Listener, error) {
    27  	var listener net.Listener
    28  	err := Do(nsName, func() error {
    29  		var err error
    30  		listener, err = listenerCreator()
    31  		return err
    32  	})
    33  	return listener, err
    34  }
    35  
    36  func accept(listener net.Listener, conns chan<- net.Conn, logger logger.Logger) {
    37  	for {
    38  		c, err := listener.Accept()
    39  		if err != nil {
    40  			logger.Infof("Accept error: %v", err)
    41  			return
    42  		}
    43  		conns <- c
    44  	}
    45  }
    46  
    47  func (l *nsListener) waitForMount() bool {
    48  	for !hasMount(l.nsFile, l.logger) {
    49  		time.Sleep(time.Second)
    50  		if _, err := os.Stat(l.nsFile); err != nil {
    51  			l.logger.Infof("error stating %s: %v", l.nsFile, err)
    52  			return false
    53  		}
    54  	}
    55  	return true
    56  }
    57  
    58  // nsListener is a net.Listener that binds to a specific network namespace when it becomes available
    59  // and in case it gets deleted and recreated it will automatically bind to the newly created
    60  // namespace.
    61  type nsListener struct {
    62  	listener        net.Listener
    63  	watcher         *fsnotify.Watcher
    64  	nsName          string
    65  	nsFile          string
    66  	addr            *net.TCPAddr
    67  	done            chan struct{}
    68  	conns           chan net.Conn
    69  	logger          logger.Logger
    70  	listenerCreator ListenerCreator
    71  }
    72  
    73  func (l *nsListener) tearDown() {
    74  	if l.listener != nil {
    75  		l.logger.Info("Destroying listener")
    76  		l.listener.Close()
    77  		l.listener = nil
    78  	}
    79  }
    80  
    81  func (l *nsListener) setUp() bool {
    82  	l.logger.Infof("Creating listener in namespace %v", l.nsName)
    83  	if err := l.watcher.Add(l.nsFile); err != nil {
    84  		l.logger.Infof("Can't watch the file (will try again): %v", err)
    85  		return false
    86  	}
    87  	listener, err := makeListener(l.nsName, l.listenerCreator)
    88  	if err != nil {
    89  		l.logger.Infof("Can't create TCP listener (will try again): %v", err)
    90  		return false
    91  	}
    92  	l.listener = listener
    93  	go accept(l.listener, l.conns, l.logger)
    94  	return true
    95  }
    96  
    97  func (l *nsListener) watch() {
    98  	var mounted bool
    99  	if hasMount(l.nsFile, l.logger) {
   100  		mounted = l.setUp()
   101  	}
   102  
   103  	for {
   104  		select {
   105  		case <-l.done:
   106  			l.tearDown()
   107  			go func() {
   108  				// Drain the events, otherwise closing the watcher will get stuck
   109  				for range l.watcher.Events {
   110  				}
   111  			}()
   112  			l.watcher.Close()
   113  			close(l.conns)
   114  			return
   115  		case ev := <-l.watcher.Events:
   116  			if ev.Name != l.nsFile {
   117  				continue
   118  			}
   119  			if ev.Op&fsnotify.Create == fsnotify.Create {
   120  				if mounted || !l.waitForMount() {
   121  					continue
   122  				}
   123  				mounted = l.setUp()
   124  			}
   125  			if ev.Op&fsnotify.Remove == fsnotify.Remove {
   126  				l.tearDown()
   127  				mounted = false
   128  			}
   129  		}
   130  	}
   131  }
   132  
   133  func (l *nsListener) setupWatch() error {
   134  	w, err := fsnotify.NewWatcher()
   135  	if err != nil {
   136  		return err
   137  	}
   138  	if err = w.Add(filepath.Dir(l.nsFile)); err != nil {
   139  		return err
   140  	}
   141  
   142  	l.watcher = w
   143  	go l.watch()
   144  	return nil
   145  }
   146  
   147  func newNSListenerWithDir(nsDir, nsName string, addr *net.TCPAddr, logger logger.Logger,
   148  	listenerCreator ListenerCreator) (net.Listener, error) {
   149  	if listenerCreator == nil {
   150  		return nil, fmt.Errorf("newNSListenerWithDir received nil listenerCreator")
   151  	}
   152  	l := &nsListener{
   153  		nsName:          nsName,
   154  		nsFile:          filepath.Join(nsDir, nsName),
   155  		addr:            addr,
   156  		done:            make(chan struct{}),
   157  		conns:           make(chan net.Conn),
   158  		logger:          logger,
   159  		listenerCreator: listenerCreator,
   160  	}
   161  	if err := l.setupWatch(); err != nil {
   162  		return nil, err
   163  	}
   164  
   165  	return l, nil
   166  }
   167  
   168  // Accept accepts a connection on the listener socket.
   169  func (l *nsListener) Accept() (net.Conn, error) {
   170  	if c, ok := <-l.conns; ok {
   171  		return c, nil
   172  	}
   173  	return nil, errors.New("listener closed")
   174  }
   175  
   176  // Close closes the listener.
   177  func (l *nsListener) Close() error {
   178  	close(l.done)
   179  	return nil
   180  }
   181  
   182  // Addr returns the local address of the listener.
   183  func (l *nsListener) Addr() net.Addr {
   184  	return l.addr
   185  }
   186  
   187  func hasMountInProcMounts(r io.Reader, mountPoint string) bool {
   188  	// Kernels up to 3.18 export the namespace via procfs and later ones via nsfs
   189  	fsTypes := map[string]bool{"proc": true, "nsfs": true}
   190  
   191  	scanner := bufio.NewScanner(r)
   192  	for scanner.Scan() {
   193  		l := scanner.Text()
   194  		comps := strings.SplitN(l, " ", 3)
   195  		if len(comps) != 3 || !fsTypes[comps[0]] {
   196  			continue
   197  		}
   198  		if comps[1] == mountPoint {
   199  			return true
   200  		}
   201  	}
   202  
   203  	return false
   204  }
   205  
   206  func getNsDirFromProcMounts(r io.Reader) (string, error) {
   207  	// Newer EOS versions mount netns under /run
   208  	dirs := map[string]bool{"/var/run/netns": true, "/run/netns": true}
   209  
   210  	scanner := bufio.NewScanner(r)
   211  	for scanner.Scan() {
   212  		l := scanner.Text()
   213  		comps := strings.SplitN(l, " ", 3)
   214  		if len(comps) != 3 || !dirs[comps[1]] {
   215  			continue
   216  		}
   217  		return comps[1], nil
   218  	}
   219  
   220  	return "", errors.New("can't find the netns mount dir")
   221  }