github.com/cloud-foundations/dominator@v0.0.0-20221004181915-6e4fee580046/lib/net/reverseconnection/listener.go (about)

     1  package reverseconnection
     2  
     3  import (
     4  	"bufio"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"math/rand"
    10  	"net"
    11  	"net/http"
    12  	"os"
    13  	"path"
    14  	"time"
    15  
    16  	libjson "github.com/Cloud-Foundations/Dominator/lib/json"
    17  	"github.com/Cloud-Foundations/Dominator/lib/log"
    18  	"github.com/Cloud-Foundations/Dominator/lib/log/prefixlogger"
    19  	libnet "github.com/Cloud-Foundations/Dominator/lib/net"
    20  )
    21  
    22  const (
    23  	configDirectory = "/etc/reverse-listeners"
    24  )
    25  
    26  var (
    27  	errorNotFound = errors.New("HTTP method not found")
    28  	errorLoopback = errors.New("loopback address")
    29  )
    30  
    31  func getIp4Address(conn net.Conn) (ip4Address, error) {
    32  	remoteAddr := conn.RemoteAddr()
    33  	var zero ip4Address
    34  	if remoteAddr.Network() != "tcp" {
    35  		return zero, errors.New("wrong network type: " + remoteAddr.Network())
    36  	}
    37  	remoteHost, _, err := net.SplitHostPort(remoteAddr.String())
    38  	if err != nil {
    39  		return zero, err
    40  	}
    41  	return getIp4AddressFromAddress(remoteHost)
    42  }
    43  
    44  func getIp4AddressFromAddress(address string) (ip4Address, error) {
    45  	ip := net.ParseIP(address)
    46  	if ip == nil {
    47  		return ip4Address{}, errors.New("failed to parse: " + address)
    48  	}
    49  	if ip.IsLoopback() {
    50  		return ip4Address{}, errorLoopback
    51  	}
    52  	ip = ip.To4()
    53  	if ip == nil {
    54  		return ip4Address{}, errors.New(address + " is not IPv4")
    55  	}
    56  	return ip4Address{ip[0], ip[1], ip[2], ip[3]}, nil
    57  }
    58  
    59  func listen(network string, portNumber uint, logger log.DebugLogger) (
    60  	*Listener, error) {
    61  	rListener, err := libnet.ListenWithReuse(network,
    62  		fmt.Sprintf(":%d", portNumber))
    63  	if err != nil {
    64  		return nil, fmt.Errorf("error creating %s listener: %s", network, err)
    65  	}
    66  	acceptChannel := make(chan acceptEvent, 1)
    67  	listener := &Listener{
    68  		listener:      rListener,
    69  		portNumber:    portNumber,
    70  		logger:        logger,
    71  		acceptChannel: acceptChannel,
    72  		connectionMap: make(map[ip4Address]uint),
    73  	}
    74  	go listener.listen(acceptChannel)
    75  	return listener, nil
    76  }
    77  
    78  func sleep(minInterval, maxInterval time.Duration) {
    79  	jit := (maxInterval - minInterval) * time.Duration((rand.Intn(1000))) / 1000
    80  	time.Sleep(minInterval + jit)
    81  }
    82  
    83  func (conn *listenerConn) Close() error {
    84  	if ip, err := getIp4Address(conn); err != nil {
    85  		if err != errorLoopback {
    86  			conn.listener.logger.Println(err)
    87  		}
    88  	} else {
    89  		conn.listener.forget(conn.RemoteAddr().String(), ip)
    90  	}
    91  	return conn.TCPConn.Close()
    92  }
    93  
    94  func (l *Listener) accept() (*listenerConn, error) {
    95  	if l.closed {
    96  		return nil, errors.New("listener is closed")
    97  	}
    98  	event := <-l.acceptChannel
    99  	return event.conn, event.error
   100  }
   101  
   102  func (l *Listener) close() error {
   103  	l.closed = true
   104  	return l.listener.Close()
   105  }
   106  
   107  func (l *Listener) forget(remoteHost string, ip ip4Address) {
   108  	l.logger.Debugf(1, "reverse listener: forget(%s)\n", remoteHost)
   109  	l.connectionMapLock.Lock()
   110  	defer l.connectionMapLock.Unlock()
   111  	if numConn := l.connectionMap[ip]; numConn < 1 {
   112  		panic("unknown connection from: " + remoteHost)
   113  	} else {
   114  		l.connectionMap[ip] = numConn - 1
   115  	}
   116  }
   117  
   118  func (l *Listener) listen(acceptChannel chan<- acceptEvent) {
   119  	for {
   120  		if l.closed {
   121  			break
   122  		}
   123  		conn, err := l.listener.Accept()
   124  		tcpConn, ok := conn.(libnet.TCPConn)
   125  		if !ok {
   126  			conn.Close()
   127  			l.logger.Println("rejecting non-TCP connection")
   128  			continue
   129  		}
   130  		l.remember(conn)
   131  		acceptChannel <- acceptEvent{
   132  			&listenerConn{TCPConn: tcpConn, listener: l}, err}
   133  	}
   134  }
   135  
   136  func (l *Listener) remember(conn net.Conn) {
   137  	l.logger.Debugf(1, "reverse listener: remember(%s): %p\n",
   138  		conn.RemoteAddr(), conn)
   139  	if ip, err := getIp4Address(conn); err == nil {
   140  		l.connectionMapLock.Lock()
   141  		defer l.connectionMapLock.Unlock()
   142  		l.connectionMap[ip]++
   143  	}
   144  }
   145  
   146  func (l *Listener) requestConnections(serviceName string) error {
   147  	var config ReverseListenerConfig
   148  	filename := path.Join(configDirectory, serviceName)
   149  	if err := libjson.ReadFromFile(filename, &config); err != nil {
   150  		if os.IsNotExist(err) {
   151  			return nil
   152  		}
   153  		return err
   154  	}
   155  	if config.Network == "" {
   156  		config.Network = "tcp"
   157  	}
   158  	if config.MinimumInterval < time.Minute {
   159  		config.MinimumInterval = time.Minute
   160  	}
   161  	if config.MaximumInterval <= config.MinimumInterval {
   162  		config.MaximumInterval = config.MinimumInterval * 11 / 10
   163  	}
   164  	serverHost, _, err := net.SplitHostPort(config.ServerAddress)
   165  	if err != nil {
   166  		return err
   167  	}
   168  	go l.connectLoop(config, serverHost)
   169  	return nil
   170  }
   171  
   172  func (l *Listener) connectLoop(config ReverseListenerConfig,
   173  	serverHost string) {
   174  	logger := prefixlogger.New("reverse listener: "+config.ServerAddress+": ",
   175  		l.logger)
   176  	logger.Debugf(0, "starting loop, min interval: %s, max interval: %s\n",
   177  		config.MinimumInterval, config.MaximumInterval)
   178  	for {
   179  		sleep(config.MinimumInterval, config.MaximumInterval)
   180  		addrs, err := net.LookupHost(serverHost)
   181  		if err != nil {
   182  			logger.Println(err)
   183  			continue
   184  		}
   185  		foundExisting := false
   186  		for _, addr := range addrs {
   187  			if ip, err := getIp4AddressFromAddress(addr); err != nil {
   188  				continue
   189  			} else {
   190  				l.connectionMapLock.Lock()
   191  				if l.connectionMap[ip] > 0 {
   192  					foundExisting = true
   193  				}
   194  				l.connectionMapLock.Unlock()
   195  			}
   196  			if foundExisting {
   197  				break
   198  			}
   199  		}
   200  		if foundExisting {
   201  			continue
   202  		}
   203  		message, err := l.connect(config.Network, config.ServerAddress,
   204  			config.MinimumInterval>>1, logger)
   205  		if err != nil {
   206  			if err != errorNotFound {
   207  				logger.Println(err)
   208  			}
   209  			continue
   210  		}
   211  		if message.MinimumInterval >= time.Second {
   212  			newMaximumInterval := message.MaximumInterval
   213  			if newMaximumInterval <= message.MinimumInterval {
   214  				newMaximumInterval = message.MinimumInterval * 11 / 10
   215  			}
   216  			if message.MinimumInterval != config.MinimumInterval ||
   217  				newMaximumInterval != config.MaximumInterval {
   218  				logger.Debugf(0,
   219  					"min interval: %s -> %s, max interval: %s -> %s\n",
   220  					config.MinimumInterval, message.MinimumInterval,
   221  					config.MaximumInterval, newMaximumInterval)
   222  			}
   223  			config.MinimumInterval = message.MinimumInterval
   224  			config.MaximumInterval = newMaximumInterval
   225  		}
   226  	}
   227  }
   228  
   229  func (l *Listener) connect(network, serverAddress string, timeout time.Duration,
   230  	logger log.DebugLogger) (*reverseDialerMessage, error) {
   231  	logger.Debugln(0, "dialing")
   232  	localAddr := fmt.Sprintf(":%d", l.portNumber)
   233  	deadline := time.Now().Add(timeout)
   234  	rawConn, err := libnet.BindAndDial(network, localAddr, serverAddress,
   235  		timeout)
   236  	if err != nil {
   237  		return nil, err
   238  	}
   239  	defer func() {
   240  		if rawConn != nil {
   241  			rawConn.Close()
   242  		}
   243  	}()
   244  	tcpConn, ok := rawConn.(libnet.TCPConn)
   245  	if !ok {
   246  		return nil, errors.New("rejecting non-TCP connection")
   247  	}
   248  	if err := rawConn.SetDeadline(deadline); err != nil {
   249  		return nil, errors.New("error setting deadline: " + err.Error())
   250  	}
   251  	logger.Debugln(0, "sending HTTP CONNECT")
   252  	_, err = io.WriteString(rawConn, "CONNECT "+urlPath+" HTTP/1.0\n\n")
   253  	if err != nil {
   254  		return nil, errors.New("error writing CONNECT: " + err.Error())
   255  	}
   256  	reader := bufio.NewReader(rawConn)
   257  	resp, err := http.ReadResponse(reader, &http.Request{Method: "CONNECT"})
   258  	if err != nil {
   259  		return nil, errors.New("error reading HTTP response: " + err.Error())
   260  	}
   261  	if resp.StatusCode == http.StatusNotFound {
   262  		return nil, errorNotFound
   263  	}
   264  	if resp.StatusCode != http.StatusOK || resp.Status != connectString {
   265  		return nil, errors.New("unexpected HTTP response: " + resp.Status)
   266  	}
   267  	decoder := json.NewDecoder(reader)
   268  	var message reverseDialerMessage
   269  	if err := decoder.Decode(&message); err != nil {
   270  		return nil, errors.New("error decoding message: " + err.Error())
   271  	}
   272  	// Send all-clear to other side to ensure nothing further is buffered.
   273  	buffer := make([]byte, 1)
   274  	if _, err := rawConn.Write(buffer); err != nil {
   275  		return nil, errors.New("error writing sync byte: " + err.Error())
   276  	}
   277  	if err := rawConn.SetDeadline(time.Time{}); err != nil {
   278  		return nil, errors.New("error resetting deadline: " + err.Error())
   279  	}
   280  	logger.Println("made connection, waiting for remote consumption")
   281  	// Wait for other side to consume.
   282  	if _, err := rawConn.Read(buffer); err != nil {
   283  		return nil, errors.New("error reading sync byte: " + err.Error())
   284  	}
   285  	logger.Println("remote has consumed, injecting to local listener")
   286  	l.remember(rawConn)
   287  	l.acceptChannel <- acceptEvent{
   288  		&listenerConn{TCPConn: tcpConn, listener: l}, nil}
   289  	rawConn = nil // Prevent Close on return.
   290  	return &message, nil
   291  }