github.com/Cloud-Foundations/Dominator@v0.3.4/lib/net/reverseconnection/listener.go (about) 1 package reverseconnection 2 3 import ( 4 "bufio" 5 "encoding/json" 6 "errors" 7 "fmt" 8 "io" 9 "math/rand" 10 "net" 11 "net/http" 12 "os" 13 "path" 14 "time" 15 16 libjson "github.com/Cloud-Foundations/Dominator/lib/json" 17 "github.com/Cloud-Foundations/Dominator/lib/log" 18 "github.com/Cloud-Foundations/Dominator/lib/log/prefixlogger" 19 libnet "github.com/Cloud-Foundations/Dominator/lib/net" 20 ) 21 22 const ( 23 configDirectory = "/etc/reverse-listeners" 24 ) 25 26 var ( 27 errorNotFound = errors.New("HTTP method not found") 28 errorLoopback = errors.New("loopback address") 29 ) 30 31 func getIp4Address(conn net.Conn) (ip4Address, error) { 32 remoteAddr := conn.RemoteAddr() 33 var zero ip4Address 34 if remoteAddr.Network() != "tcp" { 35 return zero, errors.New("wrong network type: " + remoteAddr.Network()) 36 } 37 remoteHost, _, err := net.SplitHostPort(remoteAddr.String()) 38 if err != nil { 39 return zero, err 40 } 41 return getIp4AddressFromAddress(remoteHost) 42 } 43 44 func getIp4AddressFromAddress(address string) (ip4Address, error) { 45 ip := net.ParseIP(address) 46 if ip == nil { 47 return ip4Address{}, errors.New("failed to parse: " + address) 48 } 49 if ip.IsLoopback() { 50 return ip4Address{}, errorLoopback 51 } 52 ip = ip.To4() 53 if ip == nil { 54 return ip4Address{}, errors.New(address + " is not IPv4") 55 } 56 return ip4Address{ip[0], ip[1], ip[2], ip[3]}, nil 57 } 58 59 func listen(network string, portNumber uint, logger log.DebugLogger) ( 60 *Listener, error) { 61 rListener, err := libnet.ListenWithReuse(network, 62 fmt.Sprintf(":%d", portNumber)) 63 if err != nil { 64 return nil, fmt.Errorf("error creating %s listener: %s", network, err) 65 } 66 acceptChannel := make(chan acceptEvent, 1) 67 listener := &Listener{ 68 listener: rListener, 69 portNumber: portNumber, 70 logger: logger, 71 acceptChannel: acceptChannel, 72 connectionMap: make(map[ip4Address]uint), 73 } 74 go listener.listen(acceptChannel) 75 return listener, nil 76 } 77 78 func sleep(minInterval, maxInterval time.Duration) { 79 jit := (maxInterval - minInterval) * time.Duration((rand.Intn(1000))) / 1000 80 time.Sleep(minInterval + jit) 81 } 82 83 func (conn *listenerConn) Close() error { 84 if ip, err := getIp4Address(conn); err != nil { 85 if err != errorLoopback { 86 conn.listener.logger.Println(err) 87 } 88 } else { 89 conn.listener.forget(conn.RemoteAddr().String(), ip) 90 } 91 return conn.TCPConn.Close() 92 } 93 94 func (l *Listener) accept() (*listenerConn, error) { 95 if l.isClosed() { 96 return nil, errors.New("listener is closed") 97 } 98 event := <-l.acceptChannel 99 return event.conn, event.error 100 } 101 102 func (l *Listener) close() error { 103 l.closedLock.Lock() 104 l.closed = true 105 l.closedLock.Unlock() 106 return l.listener.Close() 107 } 108 109 func (l *Listener) forget(remoteHost string, ip ip4Address) { 110 l.logger.Debugf(1, "reverse listener: forget(%s)\n", remoteHost) 111 l.connectionMapLock.Lock() 112 defer l.connectionMapLock.Unlock() 113 if numConn := l.connectionMap[ip]; numConn < 1 { 114 panic("unknown connection from: " + remoteHost) 115 } else { 116 l.connectionMap[ip] = numConn - 1 117 } 118 } 119 120 func (l *Listener) isClosed() bool { 121 l.closedLock.Lock() 122 defer l.closedLock.Unlock() 123 return l.closed 124 } 125 126 func (l *Listener) listen(acceptChannel chan<- acceptEvent) { 127 for { 128 if l.isClosed() { 129 break 130 } 131 conn, err := l.listener.Accept() 132 if err != nil { 133 l.logger.Printf( 134 "error accepting connection on reverse listener: %s\n", err) 135 continue 136 } 137 tcpConn, ok := conn.(libnet.TCPConn) 138 if !ok { 139 conn.Close() 140 l.logger.Println("rejecting non-TCP connection") 141 continue 142 } 143 l.remember(conn) 144 acceptChannel <- acceptEvent{ 145 &listenerConn{TCPConn: tcpConn, listener: l}, err} 146 } 147 } 148 149 func (l *Listener) remember(conn net.Conn) { 150 l.logger.Debugf(1, "reverse listener: remember(%s): %p\n", 151 conn.RemoteAddr(), conn) 152 if ip, err := getIp4Address(conn); err == nil { 153 l.connectionMapLock.Lock() 154 defer l.connectionMapLock.Unlock() 155 l.connectionMap[ip]++ 156 } 157 } 158 159 func (l *Listener) requestConnections(serviceName string) error { 160 var config ReverseListenerConfig 161 filename := path.Join(configDirectory, serviceName) 162 if err := libjson.ReadFromFile(filename, &config); err != nil { 163 if os.IsNotExist(err) { 164 return nil 165 } 166 return err 167 } 168 if config.Network == "" { 169 config.Network = "tcp" 170 } 171 if config.MinimumInterval < time.Minute { 172 config.MinimumInterval = time.Minute 173 } 174 if config.MaximumInterval <= config.MinimumInterval { 175 config.MaximumInterval = config.MinimumInterval * 11 / 10 176 } 177 serverHost, _, err := net.SplitHostPort(config.ServerAddress) 178 if err != nil { 179 return err 180 } 181 go l.connectLoop(config, serverHost) 182 return nil 183 } 184 185 func (l *Listener) connectLoop(config ReverseListenerConfig, 186 serverHost string) { 187 logger := prefixlogger.New("reverse listener: "+config.ServerAddress+": ", 188 l.logger) 189 logger.Debugf(0, "starting loop, min interval: %s, max interval: %s\n", 190 config.MinimumInterval, config.MaximumInterval) 191 for { 192 sleep(config.MinimumInterval, config.MaximumInterval) 193 addrs, err := net.LookupHost(serverHost) 194 if err != nil { 195 logger.Println(err) 196 continue 197 } 198 foundExisting := false 199 for _, addr := range addrs { 200 if ip, err := getIp4AddressFromAddress(addr); err != nil { 201 continue 202 } else { 203 l.connectionMapLock.Lock() 204 if l.connectionMap[ip] > 0 { 205 foundExisting = true 206 } 207 l.connectionMapLock.Unlock() 208 } 209 if foundExisting { 210 break 211 } 212 } 213 if foundExisting { 214 continue 215 } 216 message, err := l.connect(config.Network, config.ServerAddress, 217 config.MinimumInterval>>1, logger) 218 if err != nil { 219 if err != errorNotFound { 220 logger.Println(err) 221 } 222 continue 223 } 224 if message.MinimumInterval >= time.Second { 225 newMaximumInterval := message.MaximumInterval 226 if newMaximumInterval <= message.MinimumInterval { 227 newMaximumInterval = message.MinimumInterval * 11 / 10 228 } 229 if message.MinimumInterval != config.MinimumInterval || 230 newMaximumInterval != config.MaximumInterval { 231 logger.Debugf(0, 232 "min interval: %s -> %s, max interval: %s -> %s\n", 233 config.MinimumInterval, message.MinimumInterval, 234 config.MaximumInterval, newMaximumInterval) 235 } 236 config.MinimumInterval = message.MinimumInterval 237 config.MaximumInterval = newMaximumInterval 238 } 239 } 240 } 241 242 func (l *Listener) connect(network, serverAddress string, timeout time.Duration, 243 logger log.DebugLogger) (*reverseDialerMessage, error) { 244 logger.Debugln(0, "dialing") 245 localAddr := fmt.Sprintf(":%d", l.portNumber) 246 deadline := time.Now().Add(timeout) 247 rawConn, err := libnet.BindAndDial(network, localAddr, serverAddress, 248 timeout) 249 if err != nil { 250 return nil, err 251 } 252 defer func() { 253 if rawConn != nil { 254 rawConn.Close() 255 } 256 }() 257 tcpConn, ok := rawConn.(libnet.TCPConn) 258 if !ok { 259 return nil, errors.New("rejecting non-TCP connection") 260 } 261 if err := rawConn.SetDeadline(deadline); err != nil { 262 return nil, errors.New("error setting deadline: " + err.Error()) 263 } 264 logger.Debugln(0, "sending HTTP CONNECT") 265 _, err = io.WriteString(rawConn, "CONNECT "+urlPath+" HTTP/1.0\n\n") 266 if err != nil { 267 return nil, errors.New("error writing CONNECT: " + err.Error()) 268 } 269 reader := bufio.NewReader(rawConn) 270 resp, err := http.ReadResponse(reader, &http.Request{Method: "CONNECT"}) 271 if err != nil { 272 return nil, errors.New("error reading HTTP response: " + err.Error()) 273 } 274 if resp.StatusCode == http.StatusNotFound { 275 return nil, errorNotFound 276 } 277 if resp.StatusCode != http.StatusOK || resp.Status != connectString { 278 return nil, errors.New("unexpected HTTP response: " + resp.Status) 279 } 280 decoder := json.NewDecoder(reader) 281 var message reverseDialerMessage 282 if err := decoder.Decode(&message); err != nil { 283 return nil, errors.New("error decoding message: " + err.Error()) 284 } 285 // Send all-clear to other side to ensure nothing further is buffered. 286 buffer := make([]byte, 1) 287 if _, err := rawConn.Write(buffer); err != nil { 288 return nil, errors.New("error writing sync byte: " + err.Error()) 289 } 290 if err := rawConn.SetDeadline(time.Time{}); err != nil { 291 return nil, errors.New("error resetting deadline: " + err.Error()) 292 } 293 logger.Println("made connection, waiting for remote consumption") 294 // Wait for other side to consume. 295 if _, err := rawConn.Read(buffer); err != nil { 296 return nil, errors.New("error reading sync byte: " + err.Error()) 297 } 298 logger.Println("remote has consumed, injecting to local listener") 299 l.remember(rawConn) 300 l.acceptChannel <- acceptEvent{ 301 &listenerConn{TCPConn: tcpConn, listener: l}, nil} 302 rawConn = nil // Prevent Close on return. 303 return &message, nil 304 }