github.com/Cloud-Foundations/Dominator@v0.3.4/lib/srpc/client.go (about) 1 package srpc 2 3 import ( 4 "bufio" 5 "crypto/tls" 6 "errors" 7 "io" 8 "net" 9 "net/http" 10 "os" 11 "strings" 12 "sync" 13 "time" 14 15 libnet "github.com/Cloud-Foundations/Dominator/lib/net" 16 "github.com/Cloud-Foundations/tricorder/go/tricorder" 17 "github.com/Cloud-Foundations/tricorder/go/tricorder/units" 18 ) 19 20 type endpointType struct { 21 coderMaker coderMaker 22 path string 23 tls bool 24 } 25 26 var ( 27 attemptTransportUpgrade = true // Changed by tests. 28 clientMetricsDir *tricorder.DirectorySpec 29 clientMetricsMutex sync.Mutex 30 numInUseClientConnections uint64 31 numOpenClientConnections uint64 32 ) 33 34 func init() { 35 registerClientMetrics() 36 } 37 38 func registerClientMetrics() { 39 var err error 40 clientMetricsDir, err = tricorder.RegisterDirectory("srpc/client") 41 if err != nil { 42 panic(err) 43 } 44 err = clientMetricsDir.RegisterMetric("num-in-use-connections", 45 &numInUseClientConnections, units.None, 46 "number of connections in use") 47 if err != nil { 48 panic(err) 49 } 50 err = clientMetricsDir.RegisterMetric("num-open-connections", 51 &numOpenClientConnections, units.None, "number of open connections") 52 if err != nil { 53 panic(err) 54 } 55 } 56 57 func dial(network, address string, dialer Dialer) (net.Conn, error) { 58 hostPort := strings.SplitN(address, ":", 2) 59 address = strings.SplitN(hostPort[0], "*", 2)[0] + ":" + hostPort[1] 60 conn, err := dialer.Dial(network, address) 61 if err != nil { 62 if strings.Contains(err.Error(), ErrorConnectionRefused.Error()) { 63 return nil, ErrorConnectionRefused 64 } 65 if strings.Contains(err.Error(), ErrorNoRouteToHost.Error()) { 66 return nil, ErrorNoRouteToHost 67 } 68 return nil, err 69 } 70 if tcpConn, ok := conn.(libnet.TCPConn); ok { 71 if err := tcpConn.SetKeepAlive(true); err != nil { 72 conn.Close() 73 return nil, err 74 } 75 if err := tcpConn.SetKeepAlivePeriod(time.Minute * 5); err != nil { 76 conn.Close() 77 return nil, err 78 } 79 } 80 return conn, nil 81 } 82 83 func dialHTTP(network, address string, tlsConfig *tls.Config, 84 dialer Dialer) (*Client, error) { 85 if *srpcProxy == "" { 86 return dialHTTPDirect(network, address, tlsConfig, dialer) 87 } 88 var err error 89 if d, ok := dialer.(*net.Dialer); ok { 90 dialer, err = newProxyDialer(*srpcProxy, d) 91 } else { 92 dialer, err = newProxyDialer(*srpcProxy, &net.Dialer{}) 93 } 94 if err != nil { 95 return nil, err 96 } 97 return dialHTTPDirect(network, address, tlsConfig, dialer) 98 } 99 100 func dialHTTPDirect(network, address string, tlsConfig *tls.Config, 101 dialer Dialer) (*Client, error) { 102 insecureEndpoints := []endpointType{ 103 {&gobCoder{}, rpcPath, false}, 104 {&jsonCoder{}, jsonRpcPath, false}, 105 } 106 secureEndpoints := []endpointType{ 107 {&gobCoder{}, tlsRpcPath, true}, 108 {&jsonCoder{}, jsonTlsRpcPath, true}, 109 } 110 if tlsConfig == nil { 111 return dialHTTPEndpoints(network, address, nil, false, dialer, 112 insecureEndpoints) 113 } else { 114 var endpoints []endpointType 115 endpoints = append(endpoints, secureEndpoints...) 116 if tlsConfig.InsecureSkipVerify { // Don't have to trust server. 117 endpoints = append(endpoints, insecureEndpoints...) 118 } 119 client, err := dialHTTPEndpoints(network, address, tlsConfig, false, 120 dialer, endpoints) 121 if err != nil && 122 strings.Contains(err.Error(), "malformed HTTP response") { 123 // The server may do TLS on all connections: try that. 124 return dialHTTPEndpoints(network, address, tlsConfig, true, dialer, 125 secureEndpoints) 126 } 127 return client, err 128 } 129 } 130 131 func dialHTTPEndpoint(network, address string, tlsConfig *tls.Config, 132 fullTLS bool, dialer Dialer, endpoint endpointType) (*Client, error) { 133 unsecuredConn, err := dial(network, address, dialer) 134 if err != nil { 135 return nil, err 136 } 137 dataConn := unsecuredConn 138 doClose := true 139 defer func() { 140 if doClose { 141 dataConn.Close() 142 } 143 }() 144 if fullTLS { 145 tlsConn := tls.Client(unsecuredConn, tlsConfig) 146 if err := tlsConn.Handshake(); err != nil { 147 if strings.Contains(err.Error(), ErrorBadCertificate.Error()) { 148 return nil, ErrorBadCertificate 149 } 150 return nil, err 151 } 152 dataConn = tlsConn 153 } 154 if err := doHTTPConnect(dataConn, endpoint.path); err != nil { 155 return nil, err 156 } 157 if endpoint.tls && !fullTLS { 158 tlsConn := tls.Client(unsecuredConn, tlsConfig) 159 if err := tlsConn.Handshake(); err != nil { 160 if strings.Contains(err.Error(), ErrorBadCertificate.Error()) { 161 return nil, ErrorBadCertificate 162 } 163 return nil, err 164 } 165 dataConn = tlsConn 166 } 167 doClose = false 168 return newClient(unsecuredConn, dataConn, endpoint.tls, endpoint.coderMaker) 169 } 170 171 func dialHTTPEndpoints(network, address string, tlsConfig *tls.Config, 172 fullTLS bool, dialer Dialer, endpoints []endpointType) (*Client, error) { 173 for _, endpoint := range endpoints { 174 client, err := dialHTTPEndpoint(network, address, tlsConfig, fullTLS, 175 dialer, endpoint) 176 if err == nil { 177 return client, nil 178 } 179 if err != ErrorNoSrpcEndpoint { 180 return nil, err 181 } 182 } 183 return nil, ErrorNoSrpcEndpoint 184 } 185 186 func doHTTPConnect(conn net.Conn, path string) error { 187 var query string 188 if *srpcClientDoNotUseMethodPowers { 189 query = "?" + doNotUseMethodPowers + "=true" 190 } 191 io.WriteString(conn, "CONNECT "+path+query+" HTTP/1.0\n\n") 192 // Require successful HTTP response before switching to SRPC protocol. 193 resp, err := http.ReadResponse(bufio.NewReader(conn), 194 &http.Request{Method: "CONNECT"}) 195 if err != nil { 196 return err 197 } 198 if resp.StatusCode == http.StatusNotFound { 199 return ErrorNoSrpcEndpoint 200 } 201 if resp.StatusCode == http.StatusUnauthorized { 202 return ErrorBadCertificate 203 } 204 if resp.StatusCode == http.StatusMethodNotAllowed { 205 return ErrorMissingCertificate 206 } 207 if resp.StatusCode != http.StatusOK || resp.Status != connectString { 208 return errors.New("unexpected HTTP response: " + resp.Status) 209 } 210 return nil 211 } 212 213 func getEarliestClientCertExpiration() time.Time { 214 var earliest time.Time 215 if clientTlsConfig == nil { 216 return earliest 217 } 218 for _, cert := range clientTlsConfig.Certificates { 219 if cert.Leaf != nil && !cert.Leaf.NotAfter.IsZero() { 220 if earliest.IsZero() { 221 earliest = cert.Leaf.NotAfter 222 } else if cert.Leaf.NotAfter.Before(earliest) { 223 earliest = cert.Leaf.NotAfter 224 } 225 } 226 } 227 return earliest 228 } 229 230 func newClient(rawConn, dataConn net.Conn, isEncrypted bool, 231 makeCoder coderMaker) (*Client, error) { 232 clientMetricsMutex.Lock() 233 numOpenClientConnections++ 234 clientMetricsMutex.Unlock() 235 client := &Client{ 236 bufrw: bufio.NewReadWriter(bufio.NewReader(dataConn), 237 bufio.NewWriter(dataConn)), 238 conn: dataConn, 239 connType: "unknown", 240 localAddr: rawConn.LocalAddr().String(), 241 isEncrypted: isEncrypted, 242 makeCoder: makeCoder, 243 remoteAddr: rawConn.RemoteAddr().String(), 244 } 245 if tcpConn, ok := rawConn.(libnet.TCPConn); ok { 246 client.tcpConn = tcpConn 247 client.connType = "TCP" 248 } 249 if isEncrypted { 250 client.connType += "/TLS" 251 } 252 if attemptTransportUpgrade && *srpcProxy == "" { 253 oldBufrw := client.bufrw 254 if _, err := client.localAttemptUpgradeToUnix(); err != nil { 255 client.Close() 256 return nil, err 257 } 258 if client.conn != dataConn && client.bufrw == oldBufrw { 259 logger.Debugf(0, 260 "transport type: %s did not replace buffer, fixing\n", 261 client.connType) 262 client.bufrw = bufio.NewReadWriter(bufio.NewReader(client.conn), 263 bufio.NewWriter(client.conn)) 264 } 265 } 266 logger.Debugf(0, "made %s connection to: %s\n", 267 client.connType, client.remoteAddr) 268 return client, nil 269 } 270 271 func newFakeClient(options FakeClientOptions) *Client { 272 return &Client{fakeClientOptions: &options} 273 } 274 275 func (client *Client) call(serviceMethod string) (*Conn, error) { 276 if client.conn == nil { 277 panic("cannot call Client after Close()") 278 } 279 if client.resource != nil && !client.resource.inUse { 280 panic("cannot call Client after Close() or Put()") 281 } 282 client.callLock.Lock() 283 conn, err := client.callWithLock(serviceMethod) 284 if err != nil { 285 client.callLock.Unlock() 286 } 287 return conn, err 288 } 289 290 func (client *Client) callWithLock(serviceMethod string) (*Conn, error) { 291 _, err := client.bufrw.WriteString(serviceMethod + "\n") 292 if err != nil { 293 return nil, err 294 } 295 if err = client.bufrw.Flush(); err != nil { 296 return nil, err 297 } 298 resp, err := client.bufrw.ReadString('\n') 299 if err != nil { 300 return nil, err 301 } 302 if resp != "\n" { 303 resp := resp[:len(resp)-1] 304 if resp == ErrorAccessToMethodDenied.Error() { 305 return nil, ErrorAccessToMethodDenied 306 } 307 return nil, errors.New(resp) 308 } 309 conn := &Conn{ 310 Decoder: client.makeCoder.MakeDecoder(client.bufrw), 311 Encoder: client.makeCoder.MakeEncoder(client.bufrw), 312 parent: client, 313 isEncrypted: client.isEncrypted, 314 ReadWriter: client.bufrw, 315 } 316 return conn, nil 317 } 318 319 func (client *Client) close() error { 320 if client.fakeClientOptions != nil { 321 return nil 322 } 323 if client.conn == nil { 324 return os.ErrClosed 325 } 326 client.bufrw.Flush() 327 if client.resource == nil { 328 clientMetricsMutex.Lock() 329 numOpenClientConnections-- 330 clientMetricsMutex.Unlock() 331 conn := client.conn 332 client.conn = nil 333 return conn.Close() 334 } 335 client.resource.resource.Release() 336 client.conn = nil 337 clientMetricsMutex.Lock() 338 if client.resource.inUse { 339 numInUseClientConnections-- 340 client.resource.inUse = false 341 } 342 numOpenClientConnections-- 343 clientMetricsMutex.Unlock() 344 return client.resource.closeError 345 } 346 347 func (client *Client) ping() error { 348 conn, err := client.call("") 349 if err != nil { 350 return err 351 } 352 conn.Close() 353 return nil 354 } 355 356 func (client *Client) requestReply(serviceMethod string, request interface{}, 357 reply interface{}) error { 358 conn, err := client.Call(serviceMethod) 359 if err != nil { 360 return err 361 } 362 defer conn.Close() 363 return conn.requestReply(request, reply) 364 } 365 366 func (conn *Conn) requestReply(request interface{}, reply interface{}) error { 367 if err := conn.Encode(request); err != nil { 368 return err 369 } 370 if err := conn.Flush(); err != nil { 371 return err 372 } 373 str, err := conn.ReadString('\n') 374 if err != nil { 375 return err 376 } 377 if str != "\n" { 378 return errors.New(str[:len(str)-1]) 379 } 380 return conn.Decode(reply) 381 } 382 383 func (client *Client) setKeepAlive(keepalive bool) error { 384 if client.tcpConn == nil { 385 return nil 386 } 387 return client.tcpConn.SetKeepAlive(keepalive) 388 } 389 390 func (client *Client) setKeepAlivePeriod(d time.Duration) error { 391 if client.tcpConn == nil { 392 return nil 393 } 394 return client.tcpConn.SetKeepAlivePeriod(d) 395 }