github.com/cloud-foundations/dominator@v0.0.0-20221004181915-6e4fee580046/lib/net/reverseconnection/dialer.go (about)

     1  package reverseconnection
     2  
     3  import (
     4  	"encoding/json"
     5  	"io"
     6  	"net"
     7  	"net/http"
     8  	"time"
     9  
    10  	"github.com/Cloud-Foundations/Dominator/lib/log"
    11  	"github.com/Cloud-Foundations/Dominator/lib/log/nulllogger"
    12  )
    13  
    14  func newDialer(rawDialer *net.Dialer, serveMux *http.ServeMux,
    15  	minimumInterval, maximumInterval time.Duration,
    16  	logger log.DebugLogger) *Dialer {
    17  	if rawDialer == nil {
    18  		rawDialer = &net.Dialer{}
    19  	}
    20  	if serveMux == nil {
    21  		serveMux = http.DefaultServeMux
    22  	}
    23  	if minimumInterval < time.Second {
    24  		minimumInterval = time.Second
    25  	}
    26  	if maximumInterval <= minimumInterval {
    27  		maximumInterval = 0
    28  	}
    29  	if logger == nil {
    30  		logger = nulllogger.New()
    31  	}
    32  	dialer := &Dialer{
    33  		dialer:          rawDialer,
    34  		minimumInterval: minimumInterval,
    35  		maximumInterval: maximumInterval,
    36  		logger:          logger,
    37  		connectionMap:   make(map[string]net.Conn),
    38  	}
    39  	serveMux.HandleFunc(urlPath, dialer.connectHandler)
    40  	return dialer
    41  }
    42  
    43  // Add a connection to the map. Returns true if added, false if duplicate.
    44  func (d *Dialer) add(address string, conn net.Conn) bool {
    45  	d.connectionMapLock.Lock()
    46  	defer d.connectionMapLock.Unlock()
    47  	if _, ok := d.connectionMap[address]; ok {
    48  		return false
    49  	} else {
    50  		d.connectionMap[address] = conn
    51  		return true
    52  	}
    53  }
    54  
    55  func (d *Dialer) dial(network, address string) (net.Conn, error) {
    56  	if network != "tcp" || len(d.connectionMap) < 1 {
    57  		return d.dialer.Dial(network, address)
    58  	}
    59  	if conn, err := d.lookupDial(address); err != nil {
    60  		return nil, err
    61  	} else if conn != nil {
    62  		return conn, nil
    63  	}
    64  	return d.dialer.Dial(network, address)
    65  }
    66  
    67  func (d *Dialer) lookupDial(address string) (net.Conn, error) {
    68  	host, port, err := net.SplitHostPort(address)
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  	addrs, err := net.LookupHost(host)
    73  	if err != nil {
    74  		return nil, err
    75  	}
    76  	if len(addrs) < 1 {
    77  		return nil, nil
    78  	}
    79  	for _, addr := range addrs {
    80  		oneAddress := net.JoinHostPort(addr, port)
    81  		if conn := d.lookup(oneAddress); conn != nil {
    82  			d.logger.Debugf(0, "Consuming reverse dialer connection from: %s\n",
    83  				oneAddress)
    84  			// Tell other side we are ready for them to accept.
    85  			buffer := make([]byte, 1)
    86  			if _, err := conn.Write(buffer); err != nil {
    87  				d.logger.Printf("error sending please-accept message: %s\n",
    88  					err)
    89  				return nil, nil
    90  			}
    91  			return conn, nil
    92  		}
    93  	}
    94  	return nil, nil
    95  }
    96  
    97  // Lookup a connection and remove it from the map. Caller must consume.
    98  func (d *Dialer) lookup(address string) net.Conn {
    99  	d.connectionMapLock.Lock()
   100  	defer d.connectionMapLock.Unlock()
   101  	if conn, ok := d.connectionMap[address]; ok {
   102  		delete(d.connectionMap, address)
   103  		return conn
   104  	}
   105  	return nil
   106  }
   107  
   108  func (d *Dialer) connectHandler(w http.ResponseWriter, req *http.Request) {
   109  	d.logger.Debugf(1, "%s request from remote: %s\n",
   110  		req.Method, req.RemoteAddr)
   111  	if req.Method != "CONNECT" {
   112  		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
   113  		w.WriteHeader(http.StatusMethodNotAllowed)
   114  		d.logger.Debugf(0, "rejecting method=%s from remote: %s\n",
   115  			req.Method, req.RemoteAddr)
   116  		return
   117  	}
   118  	hijacker, ok := w.(http.Hijacker)
   119  	if !ok {
   120  		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
   121  		w.WriteHeader(http.StatusInternalServerError)
   122  		d.logger.Println("not a hijacker ", req.RemoteAddr)
   123  		return
   124  	}
   125  	d.connectionMapLock.Lock()
   126  	if conn, ok := d.connectionMap[req.RemoteAddr]; ok {
   127  		// We have nothing to detect if the remote closed, so assume the remote
   128  		// is retrying and close the old (unused) connection.
   129  		delete(d.connectionMap, req.RemoteAddr)
   130  		d.connectionMapLock.Unlock()
   131  		conn.Close()
   132  		d.logger.Debugf(0, "closed unused duplicate remote: %s\n",
   133  			req.RemoteAddr)
   134  	} else {
   135  		d.connectionMapLock.Unlock()
   136  	}
   137  	conn, _, err := hijacker.Hijack()
   138  	if err != nil {
   139  		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
   140  		w.WriteHeader(http.StatusInternalServerError)
   141  		d.logger.Printf("rpc hijacking %s: %s\n", req.RemoteAddr, err)
   142  		return
   143  	}
   144  	defer func() {
   145  		if conn != nil {
   146  			conn.Close()
   147  		}
   148  	}()
   149  	_, err = io.WriteString(conn, "HTTP/1.0 "+connectString+"\n\n")
   150  	if err != nil {
   151  		d.logger.Println("error writing connect message: ", err.Error())
   152  		return
   153  	}
   154  	message := reverseDialerMessage{
   155  		MinimumInterval: d.minimumInterval,
   156  		MaximumInterval: d.maximumInterval,
   157  	}
   158  	encoder := json.NewEncoder(conn)
   159  	encoder.SetIndent("", "    ")
   160  	if err := encoder.Encode(message); err != nil {
   161  		d.logger.Printf("error writing ReverseDialerMessage: %s\n", err)
   162  		return
   163  	}
   164  	// Ensure we don't write anything else until the other end has drained its
   165  	// buffer.
   166  	buffer := make([]byte, 1)
   167  	d.logger.Debugf(1, "waiting for sync byte from remote: %s\n",
   168  		req.RemoteAddr)
   169  	if _, err := conn.Read(buffer); err != nil {
   170  		d.logger.Printf("error reading sync byte from: %s: %s\n",
   171  			req.RemoteAddr, err)
   172  		return
   173  	}
   174  	if d.add(req.RemoteAddr, conn) {
   175  		d.logger.Debugf(0, "Registered reverse dialer connection from: %s\n",
   176  			req.RemoteAddr)
   177  	} else {
   178  		d.logger.Printf(
   179  			"Closing duplicate reverse dialer connection from: %s\n",
   180  			req.RemoteAddr)
   181  	}
   182  	conn = nil
   183  }