github.com/cloud-foundations/dominator@v0.0.0-20221004181915-6e4fee580046/lib/net/reverseconnection/dialer.go (about) 1 package reverseconnection 2 3 import ( 4 "encoding/json" 5 "io" 6 "net" 7 "net/http" 8 "time" 9 10 "github.com/Cloud-Foundations/Dominator/lib/log" 11 "github.com/Cloud-Foundations/Dominator/lib/log/nulllogger" 12 ) 13 14 func newDialer(rawDialer *net.Dialer, serveMux *http.ServeMux, 15 minimumInterval, maximumInterval time.Duration, 16 logger log.DebugLogger) *Dialer { 17 if rawDialer == nil { 18 rawDialer = &net.Dialer{} 19 } 20 if serveMux == nil { 21 serveMux = http.DefaultServeMux 22 } 23 if minimumInterval < time.Second { 24 minimumInterval = time.Second 25 } 26 if maximumInterval <= minimumInterval { 27 maximumInterval = 0 28 } 29 if logger == nil { 30 logger = nulllogger.New() 31 } 32 dialer := &Dialer{ 33 dialer: rawDialer, 34 minimumInterval: minimumInterval, 35 maximumInterval: maximumInterval, 36 logger: logger, 37 connectionMap: make(map[string]net.Conn), 38 } 39 serveMux.HandleFunc(urlPath, dialer.connectHandler) 40 return dialer 41 } 42 43 // Add a connection to the map. Returns true if added, false if duplicate. 44 func (d *Dialer) add(address string, conn net.Conn) bool { 45 d.connectionMapLock.Lock() 46 defer d.connectionMapLock.Unlock() 47 if _, ok := d.connectionMap[address]; ok { 48 return false 49 } else { 50 d.connectionMap[address] = conn 51 return true 52 } 53 } 54 55 func (d *Dialer) dial(network, address string) (net.Conn, error) { 56 if network != "tcp" || len(d.connectionMap) < 1 { 57 return d.dialer.Dial(network, address) 58 } 59 if conn, err := d.lookupDial(address); err != nil { 60 return nil, err 61 } else if conn != nil { 62 return conn, nil 63 } 64 return d.dialer.Dial(network, address) 65 } 66 67 func (d *Dialer) lookupDial(address string) (net.Conn, error) { 68 host, port, err := net.SplitHostPort(address) 69 if err != nil { 70 return nil, err 71 } 72 addrs, err := net.LookupHost(host) 73 if err != nil { 74 return nil, err 75 } 76 if len(addrs) < 1 { 77 return nil, nil 78 } 79 for _, addr := range addrs { 80 oneAddress := net.JoinHostPort(addr, port) 81 if conn := d.lookup(oneAddress); conn != nil { 82 d.logger.Debugf(0, "Consuming reverse dialer connection from: %s\n", 83 oneAddress) 84 // Tell other side we are ready for them to accept. 85 buffer := make([]byte, 1) 86 if _, err := conn.Write(buffer); err != nil { 87 d.logger.Printf("error sending please-accept message: %s\n", 88 err) 89 return nil, nil 90 } 91 return conn, nil 92 } 93 } 94 return nil, nil 95 } 96 97 // Lookup a connection and remove it from the map. Caller must consume. 98 func (d *Dialer) lookup(address string) net.Conn { 99 d.connectionMapLock.Lock() 100 defer d.connectionMapLock.Unlock() 101 if conn, ok := d.connectionMap[address]; ok { 102 delete(d.connectionMap, address) 103 return conn 104 } 105 return nil 106 } 107 108 func (d *Dialer) connectHandler(w http.ResponseWriter, req *http.Request) { 109 d.logger.Debugf(1, "%s request from remote: %s\n", 110 req.Method, req.RemoteAddr) 111 if req.Method != "CONNECT" { 112 w.Header().Set("Content-Type", "text/plain; charset=utf-8") 113 w.WriteHeader(http.StatusMethodNotAllowed) 114 d.logger.Debugf(0, "rejecting method=%s from remote: %s\n", 115 req.Method, req.RemoteAddr) 116 return 117 } 118 hijacker, ok := w.(http.Hijacker) 119 if !ok { 120 w.Header().Set("Content-Type", "text/plain; charset=utf-8") 121 w.WriteHeader(http.StatusInternalServerError) 122 d.logger.Println("not a hijacker ", req.RemoteAddr) 123 return 124 } 125 d.connectionMapLock.Lock() 126 if conn, ok := d.connectionMap[req.RemoteAddr]; ok { 127 // We have nothing to detect if the remote closed, so assume the remote 128 // is retrying and close the old (unused) connection. 129 delete(d.connectionMap, req.RemoteAddr) 130 d.connectionMapLock.Unlock() 131 conn.Close() 132 d.logger.Debugf(0, "closed unused duplicate remote: %s\n", 133 req.RemoteAddr) 134 } else { 135 d.connectionMapLock.Unlock() 136 } 137 conn, _, err := hijacker.Hijack() 138 if err != nil { 139 w.Header().Set("Content-Type", "text/plain; charset=utf-8") 140 w.WriteHeader(http.StatusInternalServerError) 141 d.logger.Printf("rpc hijacking %s: %s\n", req.RemoteAddr, err) 142 return 143 } 144 defer func() { 145 if conn != nil { 146 conn.Close() 147 } 148 }() 149 _, err = io.WriteString(conn, "HTTP/1.0 "+connectString+"\n\n") 150 if err != nil { 151 d.logger.Println("error writing connect message: ", err.Error()) 152 return 153 } 154 message := reverseDialerMessage{ 155 MinimumInterval: d.minimumInterval, 156 MaximumInterval: d.maximumInterval, 157 } 158 encoder := json.NewEncoder(conn) 159 encoder.SetIndent("", " ") 160 if err := encoder.Encode(message); err != nil { 161 d.logger.Printf("error writing ReverseDialerMessage: %s\n", err) 162 return 163 } 164 // Ensure we don't write anything else until the other end has drained its 165 // buffer. 166 buffer := make([]byte, 1) 167 d.logger.Debugf(1, "waiting for sync byte from remote: %s\n", 168 req.RemoteAddr) 169 if _, err := conn.Read(buffer); err != nil { 170 d.logger.Printf("error reading sync byte from: %s: %s\n", 171 req.RemoteAddr, err) 172 return 173 } 174 if d.add(req.RemoteAddr, conn) { 175 d.logger.Debugf(0, "Registered reverse dialer connection from: %s\n", 176 req.RemoteAddr) 177 } else { 178 d.logger.Printf( 179 "Closing duplicate reverse dialer connection from: %s\n", 180 req.RemoteAddr) 181 } 182 conn = nil 183 }