github.com/cloud-foundations/dominator@v0.0.0-20221004181915-6e4fee580046/lib/net/reverseconnection/listener.go (about) 1 package reverseconnection 2 3 import ( 4 "bufio" 5 "encoding/json" 6 "errors" 7 "fmt" 8 "io" 9 "math/rand" 10 "net" 11 "net/http" 12 "os" 13 "path" 14 "time" 15 16 libjson "github.com/Cloud-Foundations/Dominator/lib/json" 17 "github.com/Cloud-Foundations/Dominator/lib/log" 18 "github.com/Cloud-Foundations/Dominator/lib/log/prefixlogger" 19 libnet "github.com/Cloud-Foundations/Dominator/lib/net" 20 ) 21 22 const ( 23 configDirectory = "/etc/reverse-listeners" 24 ) 25 26 var ( 27 errorNotFound = errors.New("HTTP method not found") 28 errorLoopback = errors.New("loopback address") 29 ) 30 31 func getIp4Address(conn net.Conn) (ip4Address, error) { 32 remoteAddr := conn.RemoteAddr() 33 var zero ip4Address 34 if remoteAddr.Network() != "tcp" { 35 return zero, errors.New("wrong network type: " + remoteAddr.Network()) 36 } 37 remoteHost, _, err := net.SplitHostPort(remoteAddr.String()) 38 if err != nil { 39 return zero, err 40 } 41 return getIp4AddressFromAddress(remoteHost) 42 } 43 44 func getIp4AddressFromAddress(address string) (ip4Address, error) { 45 ip := net.ParseIP(address) 46 if ip == nil { 47 return ip4Address{}, errors.New("failed to parse: " + address) 48 } 49 if ip.IsLoopback() { 50 return ip4Address{}, errorLoopback 51 } 52 ip = ip.To4() 53 if ip == nil { 54 return ip4Address{}, errors.New(address + " is not IPv4") 55 } 56 return ip4Address{ip[0], ip[1], ip[2], ip[3]}, nil 57 } 58 59 func listen(network string, portNumber uint, logger log.DebugLogger) ( 60 *Listener, error) { 61 rListener, err := libnet.ListenWithReuse(network, 62 fmt.Sprintf(":%d", portNumber)) 63 if err != nil { 64 return nil, fmt.Errorf("error creating %s listener: %s", network, err) 65 } 66 acceptChannel := make(chan acceptEvent, 1) 67 listener := &Listener{ 68 listener: rListener, 69 portNumber: portNumber, 70 logger: logger, 71 acceptChannel: acceptChannel, 72 connectionMap: make(map[ip4Address]uint), 73 } 74 go listener.listen(acceptChannel) 75 return listener, nil 76 } 77 78 func sleep(minInterval, maxInterval time.Duration) { 79 jit := (maxInterval - minInterval) * time.Duration((rand.Intn(1000))) / 1000 80 time.Sleep(minInterval + jit) 81 } 82 83 func (conn *listenerConn) Close() error { 84 if ip, err := getIp4Address(conn); err != nil { 85 if err != errorLoopback { 86 conn.listener.logger.Println(err) 87 } 88 } else { 89 conn.listener.forget(conn.RemoteAddr().String(), ip) 90 } 91 return conn.TCPConn.Close() 92 } 93 94 func (l *Listener) accept() (*listenerConn, error) { 95 if l.closed { 96 return nil, errors.New("listener is closed") 97 } 98 event := <-l.acceptChannel 99 return event.conn, event.error 100 } 101 102 func (l *Listener) close() error { 103 l.closed = true 104 return l.listener.Close() 105 } 106 107 func (l *Listener) forget(remoteHost string, ip ip4Address) { 108 l.logger.Debugf(1, "reverse listener: forget(%s)\n", remoteHost) 109 l.connectionMapLock.Lock() 110 defer l.connectionMapLock.Unlock() 111 if numConn := l.connectionMap[ip]; numConn < 1 { 112 panic("unknown connection from: " + remoteHost) 113 } else { 114 l.connectionMap[ip] = numConn - 1 115 } 116 } 117 118 func (l *Listener) listen(acceptChannel chan<- acceptEvent) { 119 for { 120 if l.closed { 121 break 122 } 123 conn, err := l.listener.Accept() 124 tcpConn, ok := conn.(libnet.TCPConn) 125 if !ok { 126 conn.Close() 127 l.logger.Println("rejecting non-TCP connection") 128 continue 129 } 130 l.remember(conn) 131 acceptChannel <- acceptEvent{ 132 &listenerConn{TCPConn: tcpConn, listener: l}, err} 133 } 134 } 135 136 func (l *Listener) remember(conn net.Conn) { 137 l.logger.Debugf(1, "reverse listener: remember(%s): %p\n", 138 conn.RemoteAddr(), conn) 139 if ip, err := getIp4Address(conn); err == nil { 140 l.connectionMapLock.Lock() 141 defer l.connectionMapLock.Unlock() 142 l.connectionMap[ip]++ 143 } 144 } 145 146 func (l *Listener) requestConnections(serviceName string) error { 147 var config ReverseListenerConfig 148 filename := path.Join(configDirectory, serviceName) 149 if err := libjson.ReadFromFile(filename, &config); err != nil { 150 if os.IsNotExist(err) { 151 return nil 152 } 153 return err 154 } 155 if config.Network == "" { 156 config.Network = "tcp" 157 } 158 if config.MinimumInterval < time.Minute { 159 config.MinimumInterval = time.Minute 160 } 161 if config.MaximumInterval <= config.MinimumInterval { 162 config.MaximumInterval = config.MinimumInterval * 11 / 10 163 } 164 serverHost, _, err := net.SplitHostPort(config.ServerAddress) 165 if err != nil { 166 return err 167 } 168 go l.connectLoop(config, serverHost) 169 return nil 170 } 171 172 func (l *Listener) connectLoop(config ReverseListenerConfig, 173 serverHost string) { 174 logger := prefixlogger.New("reverse listener: "+config.ServerAddress+": ", 175 l.logger) 176 logger.Debugf(0, "starting loop, min interval: %s, max interval: %s\n", 177 config.MinimumInterval, config.MaximumInterval) 178 for { 179 sleep(config.MinimumInterval, config.MaximumInterval) 180 addrs, err := net.LookupHost(serverHost) 181 if err != nil { 182 logger.Println(err) 183 continue 184 } 185 foundExisting := false 186 for _, addr := range addrs { 187 if ip, err := getIp4AddressFromAddress(addr); err != nil { 188 continue 189 } else { 190 l.connectionMapLock.Lock() 191 if l.connectionMap[ip] > 0 { 192 foundExisting = true 193 } 194 l.connectionMapLock.Unlock() 195 } 196 if foundExisting { 197 break 198 } 199 } 200 if foundExisting { 201 continue 202 } 203 message, err := l.connect(config.Network, config.ServerAddress, 204 config.MinimumInterval>>1, logger) 205 if err != nil { 206 if err != errorNotFound { 207 logger.Println(err) 208 } 209 continue 210 } 211 if message.MinimumInterval >= time.Second { 212 newMaximumInterval := message.MaximumInterval 213 if newMaximumInterval <= message.MinimumInterval { 214 newMaximumInterval = message.MinimumInterval * 11 / 10 215 } 216 if message.MinimumInterval != config.MinimumInterval || 217 newMaximumInterval != config.MaximumInterval { 218 logger.Debugf(0, 219 "min interval: %s -> %s, max interval: %s -> %s\n", 220 config.MinimumInterval, message.MinimumInterval, 221 config.MaximumInterval, newMaximumInterval) 222 } 223 config.MinimumInterval = message.MinimumInterval 224 config.MaximumInterval = newMaximumInterval 225 } 226 } 227 } 228 229 func (l *Listener) connect(network, serverAddress string, timeout time.Duration, 230 logger log.DebugLogger) (*reverseDialerMessage, error) { 231 logger.Debugln(0, "dialing") 232 localAddr := fmt.Sprintf(":%d", l.portNumber) 233 deadline := time.Now().Add(timeout) 234 rawConn, err := libnet.BindAndDial(network, localAddr, serverAddress, 235 timeout) 236 if err != nil { 237 return nil, err 238 } 239 defer func() { 240 if rawConn != nil { 241 rawConn.Close() 242 } 243 }() 244 tcpConn, ok := rawConn.(libnet.TCPConn) 245 if !ok { 246 return nil, errors.New("rejecting non-TCP connection") 247 } 248 if err := rawConn.SetDeadline(deadline); err != nil { 249 return nil, errors.New("error setting deadline: " + err.Error()) 250 } 251 logger.Debugln(0, "sending HTTP CONNECT") 252 _, err = io.WriteString(rawConn, "CONNECT "+urlPath+" HTTP/1.0\n\n") 253 if err != nil { 254 return nil, errors.New("error writing CONNECT: " + err.Error()) 255 } 256 reader := bufio.NewReader(rawConn) 257 resp, err := http.ReadResponse(reader, &http.Request{Method: "CONNECT"}) 258 if err != nil { 259 return nil, errors.New("error reading HTTP response: " + err.Error()) 260 } 261 if resp.StatusCode == http.StatusNotFound { 262 return nil, errorNotFound 263 } 264 if resp.StatusCode != http.StatusOK || resp.Status != connectString { 265 return nil, errors.New("unexpected HTTP response: " + resp.Status) 266 } 267 decoder := json.NewDecoder(reader) 268 var message reverseDialerMessage 269 if err := decoder.Decode(&message); err != nil { 270 return nil, errors.New("error decoding message: " + err.Error()) 271 } 272 // Send all-clear to other side to ensure nothing further is buffered. 273 buffer := make([]byte, 1) 274 if _, err := rawConn.Write(buffer); err != nil { 275 return nil, errors.New("error writing sync byte: " + err.Error()) 276 } 277 if err := rawConn.SetDeadline(time.Time{}); err != nil { 278 return nil, errors.New("error resetting deadline: " + err.Error()) 279 } 280 logger.Println("made connection, waiting for remote consumption") 281 // Wait for other side to consume. 282 if _, err := rawConn.Read(buffer); err != nil { 283 return nil, errors.New("error reading sync byte: " + err.Error()) 284 } 285 logger.Println("remote has consumed, injecting to local listener") 286 l.remember(rawConn) 287 l.acceptChannel <- acceptEvent{ 288 &listenerConn{TCPConn: tcpConn, listener: l}, nil} 289 rawConn = nil // Prevent Close on return. 290 return &message, nil 291 }