github.com/Cloud-Foundations/Dominator@v0.3.4/lib/net/reverseconnection/listener.go (about)

     1  package reverseconnection
     2  
     3  import (
     4  	"bufio"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"math/rand"
    10  	"net"
    11  	"net/http"
    12  	"os"
    13  	"path"
    14  	"time"
    15  
    16  	libjson "github.com/Cloud-Foundations/Dominator/lib/json"
    17  	"github.com/Cloud-Foundations/Dominator/lib/log"
    18  	"github.com/Cloud-Foundations/Dominator/lib/log/prefixlogger"
    19  	libnet "github.com/Cloud-Foundations/Dominator/lib/net"
    20  )
    21  
    22  const (
    23  	configDirectory = "/etc/reverse-listeners"
    24  )
    25  
    26  var (
    27  	errorNotFound = errors.New("HTTP method not found")
    28  	errorLoopback = errors.New("loopback address")
    29  )
    30  
    31  func getIp4Address(conn net.Conn) (ip4Address, error) {
    32  	remoteAddr := conn.RemoteAddr()
    33  	var zero ip4Address
    34  	if remoteAddr.Network() != "tcp" {
    35  		return zero, errors.New("wrong network type: " + remoteAddr.Network())
    36  	}
    37  	remoteHost, _, err := net.SplitHostPort(remoteAddr.String())
    38  	if err != nil {
    39  		return zero, err
    40  	}
    41  	return getIp4AddressFromAddress(remoteHost)
    42  }
    43  
    44  func getIp4AddressFromAddress(address string) (ip4Address, error) {
    45  	ip := net.ParseIP(address)
    46  	if ip == nil {
    47  		return ip4Address{}, errors.New("failed to parse: " + address)
    48  	}
    49  	if ip.IsLoopback() {
    50  		return ip4Address{}, errorLoopback
    51  	}
    52  	ip = ip.To4()
    53  	if ip == nil {
    54  		return ip4Address{}, errors.New(address + " is not IPv4")
    55  	}
    56  	return ip4Address{ip[0], ip[1], ip[2], ip[3]}, nil
    57  }
    58  
    59  func listen(network string, portNumber uint, logger log.DebugLogger) (
    60  	*Listener, error) {
    61  	rListener, err := libnet.ListenWithReuse(network,
    62  		fmt.Sprintf(":%d", portNumber))
    63  	if err != nil {
    64  		return nil, fmt.Errorf("error creating %s listener: %s", network, err)
    65  	}
    66  	acceptChannel := make(chan acceptEvent, 1)
    67  	listener := &Listener{
    68  		listener:      rListener,
    69  		portNumber:    portNumber,
    70  		logger:        logger,
    71  		acceptChannel: acceptChannel,
    72  		connectionMap: make(map[ip4Address]uint),
    73  	}
    74  	go listener.listen(acceptChannel)
    75  	return listener, nil
    76  }
    77  
    78  func sleep(minInterval, maxInterval time.Duration) {
    79  	jit := (maxInterval - minInterval) * time.Duration((rand.Intn(1000))) / 1000
    80  	time.Sleep(minInterval + jit)
    81  }
    82  
    83  func (conn *listenerConn) Close() error {
    84  	if ip, err := getIp4Address(conn); err != nil {
    85  		if err != errorLoopback {
    86  			conn.listener.logger.Println(err)
    87  		}
    88  	} else {
    89  		conn.listener.forget(conn.RemoteAddr().String(), ip)
    90  	}
    91  	return conn.TCPConn.Close()
    92  }
    93  
    94  func (l *Listener) accept() (*listenerConn, error) {
    95  	if l.isClosed() {
    96  		return nil, errors.New("listener is closed")
    97  	}
    98  	event := <-l.acceptChannel
    99  	return event.conn, event.error
   100  }
   101  
   102  func (l *Listener) close() error {
   103  	l.closedLock.Lock()
   104  	l.closed = true
   105  	l.closedLock.Unlock()
   106  	return l.listener.Close()
   107  }
   108  
   109  func (l *Listener) forget(remoteHost string, ip ip4Address) {
   110  	l.logger.Debugf(1, "reverse listener: forget(%s)\n", remoteHost)
   111  	l.connectionMapLock.Lock()
   112  	defer l.connectionMapLock.Unlock()
   113  	if numConn := l.connectionMap[ip]; numConn < 1 {
   114  		panic("unknown connection from: " + remoteHost)
   115  	} else {
   116  		l.connectionMap[ip] = numConn - 1
   117  	}
   118  }
   119  
   120  func (l *Listener) isClosed() bool {
   121  	l.closedLock.Lock()
   122  	defer l.closedLock.Unlock()
   123  	return l.closed
   124  }
   125  
   126  func (l *Listener) listen(acceptChannel chan<- acceptEvent) {
   127  	for {
   128  		if l.isClosed() {
   129  			break
   130  		}
   131  		conn, err := l.listener.Accept()
   132  		if err != nil {
   133  			l.logger.Printf(
   134  				"error accepting connection on reverse listener: %s\n", err)
   135  			continue
   136  		}
   137  		tcpConn, ok := conn.(libnet.TCPConn)
   138  		if !ok {
   139  			conn.Close()
   140  			l.logger.Println("rejecting non-TCP connection")
   141  			continue
   142  		}
   143  		l.remember(conn)
   144  		acceptChannel <- acceptEvent{
   145  			&listenerConn{TCPConn: tcpConn, listener: l}, err}
   146  	}
   147  }
   148  
   149  func (l *Listener) remember(conn net.Conn) {
   150  	l.logger.Debugf(1, "reverse listener: remember(%s): %p\n",
   151  		conn.RemoteAddr(), conn)
   152  	if ip, err := getIp4Address(conn); err == nil {
   153  		l.connectionMapLock.Lock()
   154  		defer l.connectionMapLock.Unlock()
   155  		l.connectionMap[ip]++
   156  	}
   157  }
   158  
   159  func (l *Listener) requestConnections(serviceName string) error {
   160  	var config ReverseListenerConfig
   161  	filename := path.Join(configDirectory, serviceName)
   162  	if err := libjson.ReadFromFile(filename, &config); err != nil {
   163  		if os.IsNotExist(err) {
   164  			return nil
   165  		}
   166  		return err
   167  	}
   168  	if config.Network == "" {
   169  		config.Network = "tcp"
   170  	}
   171  	if config.MinimumInterval < time.Minute {
   172  		config.MinimumInterval = time.Minute
   173  	}
   174  	if config.MaximumInterval <= config.MinimumInterval {
   175  		config.MaximumInterval = config.MinimumInterval * 11 / 10
   176  	}
   177  	serverHost, _, err := net.SplitHostPort(config.ServerAddress)
   178  	if err != nil {
   179  		return err
   180  	}
   181  	go l.connectLoop(config, serverHost)
   182  	return nil
   183  }
   184  
   185  func (l *Listener) connectLoop(config ReverseListenerConfig,
   186  	serverHost string) {
   187  	logger := prefixlogger.New("reverse listener: "+config.ServerAddress+": ",
   188  		l.logger)
   189  	logger.Debugf(0, "starting loop, min interval: %s, max interval: %s\n",
   190  		config.MinimumInterval, config.MaximumInterval)
   191  	for {
   192  		sleep(config.MinimumInterval, config.MaximumInterval)
   193  		addrs, err := net.LookupHost(serverHost)
   194  		if err != nil {
   195  			logger.Println(err)
   196  			continue
   197  		}
   198  		foundExisting := false
   199  		for _, addr := range addrs {
   200  			if ip, err := getIp4AddressFromAddress(addr); err != nil {
   201  				continue
   202  			} else {
   203  				l.connectionMapLock.Lock()
   204  				if l.connectionMap[ip] > 0 {
   205  					foundExisting = true
   206  				}
   207  				l.connectionMapLock.Unlock()
   208  			}
   209  			if foundExisting {
   210  				break
   211  			}
   212  		}
   213  		if foundExisting {
   214  			continue
   215  		}
   216  		message, err := l.connect(config.Network, config.ServerAddress,
   217  			config.MinimumInterval>>1, logger)
   218  		if err != nil {
   219  			if err != errorNotFound {
   220  				logger.Println(err)
   221  			}
   222  			continue
   223  		}
   224  		if message.MinimumInterval >= time.Second {
   225  			newMaximumInterval := message.MaximumInterval
   226  			if newMaximumInterval <= message.MinimumInterval {
   227  				newMaximumInterval = message.MinimumInterval * 11 / 10
   228  			}
   229  			if message.MinimumInterval != config.MinimumInterval ||
   230  				newMaximumInterval != config.MaximumInterval {
   231  				logger.Debugf(0,
   232  					"min interval: %s -> %s, max interval: %s -> %s\n",
   233  					config.MinimumInterval, message.MinimumInterval,
   234  					config.MaximumInterval, newMaximumInterval)
   235  			}
   236  			config.MinimumInterval = message.MinimumInterval
   237  			config.MaximumInterval = newMaximumInterval
   238  		}
   239  	}
   240  }
   241  
   242  func (l *Listener) connect(network, serverAddress string, timeout time.Duration,
   243  	logger log.DebugLogger) (*reverseDialerMessage, error) {
   244  	logger.Debugln(0, "dialing")
   245  	localAddr := fmt.Sprintf(":%d", l.portNumber)
   246  	deadline := time.Now().Add(timeout)
   247  	rawConn, err := libnet.BindAndDial(network, localAddr, serverAddress,
   248  		timeout)
   249  	if err != nil {
   250  		return nil, err
   251  	}
   252  	defer func() {
   253  		if rawConn != nil {
   254  			rawConn.Close()
   255  		}
   256  	}()
   257  	tcpConn, ok := rawConn.(libnet.TCPConn)
   258  	if !ok {
   259  		return nil, errors.New("rejecting non-TCP connection")
   260  	}
   261  	if err := rawConn.SetDeadline(deadline); err != nil {
   262  		return nil, errors.New("error setting deadline: " + err.Error())
   263  	}
   264  	logger.Debugln(0, "sending HTTP CONNECT")
   265  	_, err = io.WriteString(rawConn, "CONNECT "+urlPath+" HTTP/1.0\n\n")
   266  	if err != nil {
   267  		return nil, errors.New("error writing CONNECT: " + err.Error())
   268  	}
   269  	reader := bufio.NewReader(rawConn)
   270  	resp, err := http.ReadResponse(reader, &http.Request{Method: "CONNECT"})
   271  	if err != nil {
   272  		return nil, errors.New("error reading HTTP response: " + err.Error())
   273  	}
   274  	if resp.StatusCode == http.StatusNotFound {
   275  		return nil, errorNotFound
   276  	}
   277  	if resp.StatusCode != http.StatusOK || resp.Status != connectString {
   278  		return nil, errors.New("unexpected HTTP response: " + resp.Status)
   279  	}
   280  	decoder := json.NewDecoder(reader)
   281  	var message reverseDialerMessage
   282  	if err := decoder.Decode(&message); err != nil {
   283  		return nil, errors.New("error decoding message: " + err.Error())
   284  	}
   285  	// Send all-clear to other side to ensure nothing further is buffered.
   286  	buffer := make([]byte, 1)
   287  	if _, err := rawConn.Write(buffer); err != nil {
   288  		return nil, errors.New("error writing sync byte: " + err.Error())
   289  	}
   290  	if err := rawConn.SetDeadline(time.Time{}); err != nil {
   291  		return nil, errors.New("error resetting deadline: " + err.Error())
   292  	}
   293  	logger.Println("made connection, waiting for remote consumption")
   294  	// Wait for other side to consume.
   295  	if _, err := rawConn.Read(buffer); err != nil {
   296  		return nil, errors.New("error reading sync byte: " + err.Error())
   297  	}
   298  	logger.Println("remote has consumed, injecting to local listener")
   299  	l.remember(rawConn)
   300  	l.acceptChannel <- acceptEvent{
   301  		&listenerConn{TCPConn: tcpConn, listener: l}, nil}
   302  	rawConn = nil // Prevent Close on return.
   303  	return &message, nil
   304  }