github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/applicationproxy/http/http.go (about) 1 package httpproxy 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "crypto/x509" 8 "encoding/json" 9 "encoding/pem" 10 "fmt" 11 "net" 12 "net/http" 13 "net/url" 14 "strings" 15 "sync" 16 "time" 17 18 "github.com/blang/semver" 19 jwt "github.com/dgrijalva/jwt-go" 20 "github.com/vulcand/oxy/forward" 21 "go.aporeto.io/enforcerd/trireme-lib/collector" 22 "go.aporeto.io/enforcerd/trireme-lib/common" 23 "go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/apiauth" 24 pcommon "go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/common" 25 "go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/markedconn" 26 "go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/protomux" 27 "go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/serviceregistry" 28 "go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/tlshelper" 29 "go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/flowstats" 30 "go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/metadata" 31 "go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/ephemeralkeys" 32 "go.aporeto.io/enforcerd/trireme-lib/controller/pkg/bufferpool" 33 "go.aporeto.io/enforcerd/trireme-lib/controller/pkg/secrets" 34 "go.aporeto.io/enforcerd/trireme-lib/policy" 35 "go.aporeto.io/gaia" 36 "go.aporeto.io/gaia/x509extensions" 37 "go.uber.org/zap" 38 ) 39 40 type statsContextKeyType string 41 42 const ( 43 statsContextKey = statsContextKeyType("statsContext") 44 45 // TriremeOIDCCallbackURI is the callback URI that must be presented by 46 // any OIDC provider. 47 TriremeOIDCCallbackURI = "/aporeto/oidc/callback" 48 typeCertificate = "CERTIFICATE" 49 ) 50 51 // JWTClaims is the structure of the claims we are sending on the wire. 52 type JWTClaims struct { 53 jwt.StandardClaims 54 SourceID string 55 Scopes []string 56 Profile []string 57 } 58 59 type hookFunc func(w http.ResponseWriter, r *http.Request) (bool, error) 60 61 // Config maintains state for proxies connections from listen to backend. 62 type Config struct { 63 cert *tls.Certificate 64 ca *x509.CertPool 65 keyPEM string 66 certPEM string 67 secrets secrets.Secrets 68 datapathKeyPair ephemeralkeys.KeyAccessor 69 collector collector.EventCollector 70 puContext string 71 localIPs map[string]struct{} 72 applicationProxy bool 73 mark int 74 server *http.Server 75 fwd *forward.Forwarder 76 fwdTLS *forward.Forwarder 77 tlsClientConfig *tls.Config 78 auth *apiauth.Processor 79 metadata *metadata.Client 80 tokenIssuer common.ServiceTokenIssuer 81 hooks map[string]hookFunc 82 agentVersion semver.Version 83 84 sync.RWMutex 85 } 86 87 // NewHTTPProxy creates a new instance of proxy reate a new instance of Proxy 88 func NewHTTPProxy( 89 c collector.EventCollector, 90 puContext string, 91 caPool *x509.CertPool, 92 applicationProxy bool, 93 mark int, 94 secrets secrets.Secrets, 95 tokenIssuer common.ServiceTokenIssuer, 96 datapathKeyPair ephemeralkeys.KeyAccessor, 97 agentVersion semver.Version, 98 ) *Config { 99 100 h := &Config{ 101 collector: c, 102 puContext: puContext, 103 ca: caPool, 104 applicationProxy: applicationProxy, 105 mark: mark, 106 secrets: secrets, 107 localIPs: markedconn.GetInterfaces(), 108 tlsClientConfig: &tls.Config{ 109 RootCAs: caPool, 110 }, 111 auth: apiauth.New(puContext, secrets), 112 metadata: metadata.NewClient(puContext, tokenIssuer), 113 tokenIssuer: tokenIssuer, 114 datapathKeyPair: datapathKeyPair, 115 agentVersion: agentVersion, 116 } 117 118 hooks := map[string]hookFunc{ 119 common.MetadataHookPolicy: h.policyHook, 120 common.MetadataHookHealth: h.healthHook, 121 common.MetadataHookCertificate: h.certificateHook, 122 common.MetadataHookKey: h.keyHook, 123 common.MetadataHookToken: h.tokenHook, 124 common.AWSHookInfo: h.awsInfoHook, 125 common.AWSHookRole: h.awsTokenHook, 126 } 127 128 h.hooks = hooks 129 130 return h 131 } 132 133 // clientTLSConfiguration calculates the right certificates and requests to the clients. 134 func (p *Config) clientTLSConfiguration(conn net.Conn, originalConfig *tls.Config) (*tls.Config, error) { 135 if mconn, ok := conn.(*markedconn.ProxiedConnection); ok { 136 ip, port := mconn.GetOriginalDestination() 137 portContext, err := serviceregistry.Instance().RetrieveExposedServiceContext(ip, port, "") 138 if err != nil { 139 return nil, fmt.Errorf("Unknown service: %s", err) 140 } 141 if portContext.Service.UserAuthorizationType == policy.UserAuthorizationMutualTLS || portContext.Service.UserAuthorizationType == policy.UserAuthorizationJWT { 142 clientCAs := p.ca 143 // now append the User given CA certPool 144 if portContext.ClientTrustedRoots != nil { 145 // append only when the certpool is given 146 if len(portContext.Service.MutualTLSTrustedRoots) > 0 { 147 if !clientCAs.AppendCertsFromPEM(portContext.Service.MutualTLSTrustedRoots) { 148 return nil, fmt.Errorf("Unable to process client CAs") 149 } 150 } 151 } 152 config := p.newBaseTLSConfig() 153 config.ClientAuth = tls.VerifyClientCertIfGiven 154 config.ClientCAs = clientCAs 155 return config, nil 156 } 157 return originalConfig, nil 158 } 159 return nil, fmt.Errorf("Invalid connection") 160 } 161 162 // newBaseTLSConfig creates the new basic TLS configuration for the server. 163 func (p *Config) newBaseTLSConfig() *tls.Config { 164 c := tlshelper.NewBaseTLSServerConfig() 165 c.NextProtos = []string{"h2"} 166 c.GetCertificate = p.GetCertificateFunc 167 c.ClientCAs = p.ca 168 return c 169 } 170 171 // newBaseTLSClientConfig creates the new basic TLS configuration for the client. 172 func (p *Config) newBaseTLSClientConfig() *tls.Config { 173 c := tlshelper.NewBaseTLSClientConfig() 174 c.NextProtos = []string{"h2"} 175 c.GetCertificate = p.GetCertificateFunc 176 c.GetClientCertificate = p.GetClientCertificateFunc 177 return c 178 } 179 180 // GetClientCertificateFunc returns the certificate that will be used by the Proxy as a client during the TLS 181 func (p *Config) GetClientCertificateFunc(*tls.CertificateRequestInfo) (*tls.Certificate, error) { 182 p.RLock() 183 defer p.RUnlock() 184 if p.cert != nil { 185 cert, err := x509.ParseCertificate(p.cert.Certificate[0]) 186 if err != nil { 187 zap.L().Error("http: Cannot build the cert chain") 188 } 189 if cert != nil { 190 by, _ := x509CertToPem(cert) 191 pemCert, err := buildCertChain(by, p.secrets.CertAuthority()) 192 if err != nil { 193 zap.L().Error("http: Cannot build the cert chain") 194 } 195 var certChain tls.Certificate 196 var certDERBlock *pem.Block 197 for { 198 certDERBlock, pemCert = pem.Decode(pemCert) 199 if certDERBlock == nil { 200 break 201 } 202 if certDERBlock.Type == typeCertificate { 203 certChain.Certificate = append(certChain.Certificate, certDERBlock.Bytes) 204 } 205 } 206 certChain.PrivateKey = p.cert.PrivateKey 207 return &certChain, nil 208 } 209 return p.cert, nil 210 } 211 return nil, nil 212 } 213 214 // RunNetworkServer runs an HTTP network server. If TLS is needed, the 215 // listener should be already a TLS listener. 216 func (p *Config) RunNetworkServer(ctx context.Context, l net.Listener, encrypted bool) error { 217 p.Lock() 218 defer p.Unlock() 219 220 if p.server != nil { 221 return fmt.Errorf("Server already running") 222 } 223 224 // for usage by callbacks below 225 protoListener, _ := l.(*protomux.ProtoListener) 226 227 // If its an encrypted, wrap the listener in a TLS context. This is activated 228 // for the listener from the network, but not for the listener from a PU. 229 if encrypted { 230 config := p.newBaseTLSConfig() 231 config.GetConfigForClient = func(helloMsg *tls.ClientHelloInfo) (*tls.Config, error) { 232 p.RLock() 233 defer p.RUnlock() 234 return p.clientTLSConfiguration(helloMsg.Conn, config) 235 } 236 config.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { 237 p.RLock() 238 defer p.RUnlock() 239 return p.cert, nil 240 } 241 l = tls.NewListener(l, config) 242 } 243 // now create a client config, this is required if Aporeto is a client. 244 p.tlsClientConfig = p.newBaseTLSClientConfig() 245 246 reportStats := func(ctx context.Context) { 247 if state := ctx.Value(statsContextKey); state != nil { 248 if r, ok := state.(*flowstats.ConnectionState); ok { 249 r.Stats.Action = policy.Reject | policy.Log 250 r.Stats.DropReason = collector.UnableToDial 251 r.Stats.PolicyID = collector.DefaultEndPoint 252 p.collector.CollectFlowEvent(r.Stats) 253 } 254 } 255 } 256 257 networkDialerWithContext := func(ctx context.Context, network, _ string) (net.Conn, error) { 258 raddr, ok := ctx.Value(http.LocalAddrContextKey).(*net.TCPAddr) 259 if !ok { 260 reportStats(ctx) 261 return nil, fmt.Errorf("invalid destination address") 262 } 263 var platformData *markedconn.PlatformData 264 if protoListener != nil { 265 platformData = markedconn.TakePlatformData(protoListener.Listener, raddr.IP, raddr.Port) 266 } 267 conn, err := markedconn.DialMarkedWithContext(ctx, "tcp", raddr.String(), platformData, p.mark) 268 if err != nil { 269 reportStats(ctx) 270 return nil, fmt.Errorf("Failed to dial remote: %s", err) 271 } 272 return conn, nil 273 } 274 275 appDialerWithContext := func(ctx context.Context, network, _ string) (net.Conn, error) { 276 raddr, ok := ctx.Value(http.LocalAddrContextKey).(*net.TCPAddr) 277 if !ok { 278 reportStats(ctx) 279 return nil, fmt.Errorf("invalid destination address") 280 } 281 pctx, err := serviceregistry.Instance().RetrieveExposedServiceContext(raddr.IP, raddr.Port, "") 282 if err != nil { 283 return nil, err 284 } 285 raddr.Port = pctx.TargetPort 286 var platformData *markedconn.PlatformData 287 if protoListener != nil { 288 platformData = markedconn.TakePlatformData(protoListener.Listener, raddr.IP, raddr.Port) 289 } 290 conn, err := markedconn.DialMarkedWithContext(ctx, "tcp", raddr.String(), platformData, p.mark) 291 if err != nil { 292 reportStats(ctx) 293 return nil, fmt.Errorf("Failed to dial remote: %s", err) 294 } 295 return conn, nil 296 } 297 298 // Dial functions for the websockets. 299 netDial := func(network, addr string) (net.Conn, error) { 300 raddr, err := net.ResolveTCPAddr(network, addr) 301 if err != nil { 302 reportStats(ctx) 303 return nil, err 304 } 305 var platformData *markedconn.PlatformData 306 if protoListener != nil { 307 platformData = markedconn.TakePlatformData(protoListener.Listener, raddr.IP, raddr.Port) 308 } 309 conn, err := markedconn.DialMarkedWithContext(ctx, "tcp", raddr.String(), platformData, p.mark) 310 if err != nil { 311 reportStats(ctx) 312 return nil, fmt.Errorf("Failed to dial remote: %s", err) 313 } 314 return conn, nil 315 } 316 317 appDial := func(network, addr string) (net.Conn, error) { 318 raddr, err := net.ResolveTCPAddr(network, addr) 319 if err != nil { 320 reportStats(ctx) 321 return nil, err 322 } 323 pctx, err := serviceregistry.Instance().RetrieveExposedServiceContext(raddr.IP, raddr.Port, "") 324 if err != nil { 325 return nil, err 326 } 327 raddr.Port = pctx.TargetPort 328 var platformData *markedconn.PlatformData 329 if protoListener != nil { 330 platformData = markedconn.TakePlatformData(protoListener.Listener, raddr.IP, raddr.Port) 331 } 332 conn, err := markedconn.DialMarkedWithContext(ctx, "tcp", raddr.String(), platformData, p.mark) 333 if err != nil { 334 reportStats(ctx) 335 return nil, fmt.Errorf("Failed to dial remote: %s", err) 336 } 337 return conn, nil 338 } 339 340 // Create an encrypted downstream transport. We will mark the downstream connection 341 // to let the iptables rule capture it. 342 encryptedTransport := &http.Transport{ 343 TLSClientConfig: p.tlsClientConfig, 344 DialContext: networkDialerWithContext, 345 MaxIdleConnsPerHost: 2000, 346 MaxIdleConns: 2000, 347 ForceAttemptHTTP2: true, 348 } 349 350 // Create an unencrypted transport for talking to the application. If encryption 351 // is selected do not verify the certificates. This is supposed to be inside the 352 // same system. TODO: use pinned certificates. 353 transport := &http.Transport{ 354 TLSClientConfig: &tls.Config{ 355 InsecureSkipVerify: true, 356 GetClientCertificate: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { // nolint 357 p.RLock() 358 defer p.RUnlock() 359 return p.cert, nil 360 }, 361 }, 362 DialContext: appDialerWithContext, 363 MaxIdleConns: 2000, 364 MaxIdleConnsPerHost: 2000, 365 } 366 367 // Create the proxies downwards the network and the application. 368 var err error 369 p.fwdTLS, err = forward.New( 370 forward.RoundTripper(encryptedTransport), 371 forward.WebsocketTLSClientConfig(&tls.Config{RootCAs: p.ca}), 372 forward.WebSocketNetDial(netDial), 373 forward.BufferPool(bufferpool.NewPool(32*1204)), 374 forward.ErrorHandler(TriremeHTTPErrHandler{}), 375 ) 376 if err != nil { 377 return fmt.Errorf("Cannot initialize encrypted transport: %s", err) 378 } 379 380 p.fwd, err = forward.New( 381 forward.RoundTripper(NewTriremeRoundTripper(transport)), 382 forward.WebsocketTLSClientConfig(&tls.Config{InsecureSkipVerify: true}), 383 forward.WebSocketNetDial(appDial), 384 forward.BufferPool(bufferpool.NewPool(32*1204)), 385 forward.ErrorHandler(TriremeHTTPErrHandler{}), 386 ) 387 if err != nil { 388 return fmt.Errorf("Cannot initialize unencrypted transport: %s", err) 389 } 390 391 processor := p.processAppRequest 392 if !p.applicationProxy { 393 processor = p.processNetRequest 394 } 395 396 p.server = &http.Server{ 397 Handler: http.HandlerFunc(processor), 398 } 399 400 go func() { 401 <-ctx.Done() 402 p.server.Close() // nolint 403 }() 404 go p.server.Serve(l) // nolint 405 406 return nil 407 } 408 409 // ShutDown terminates the server. 410 func (p *Config) ShutDown() error { 411 return p.server.Close() 412 } 413 414 // UpdateSecrets updates the secrets 415 func (p *Config) UpdateSecrets(cert *tls.Certificate, caPool *x509.CertPool, s secrets.Secrets, certPEM, keyPEM string) { 416 p.Lock() 417 p.cert = cert 418 p.ca = caPool 419 p.secrets = s 420 p.certPEM = certPEM 421 p.keyPEM = keyPEM 422 p.tlsClientConfig.RootCAs = caPool 423 p.Unlock() 424 425 p.metadata.UpdateSecrets([]byte(certPEM), []byte(keyPEM)) 426 p.auth.UpdateSecrets(s) 427 } 428 429 // GetCertificateFunc implements the TLS interface for getting the certificate. This 430 // allows us to update the certificates of the connection on the fly. 431 func (p *Config) GetCertificateFunc(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { 432 p.RLock() 433 defer p.RUnlock() 434 // First we check if this is a direct access to the public port. In this case 435 // we will use the service public certificate. Otherwise, we will return the 436 // enforcer certificate since this is internal access. 437 if mconn, ok := clientHello.Conn.(*markedconn.ProxiedConnection); ok { 438 ip, port := mconn.GetOriginalDestination() 439 portContext, err := serviceregistry.Instance().RetrieveExposedServiceContext(ip, port, "") 440 if err != nil { 441 return nil, fmt.Errorf("service not available: %s %d", ip.String(), port) 442 } 443 service := portContext.Service 444 if service.PublicNetworkInfo != nil && service.PublicNetworkInfo.Ports.Min == uint16(port) && len(service.PublicServiceCertificate) > 0 { 445 tlsCert, err := tls.X509KeyPair(service.PublicServiceCertificate, service.PublicServiceCertificateKey) 446 if err != nil { 447 return nil, fmt.Errorf("failed to parse server certificate: %s", err) 448 } 449 return &tlsCert, nil 450 } 451 if p.cert != nil { 452 453 cert, err := x509.ParseCertificate(p.cert.Certificate[0]) 454 if err != nil { 455 return nil, fmt.Errorf("Leaf cert is missing") 456 } 457 if cert != nil { 458 by, _ := x509CertToPem(cert) 459 pemCert, err := buildCertChain(by, p.secrets.CertAuthority()) 460 if err != nil { 461 zap.L().Error("http: Cannot build the cert chain") 462 return nil, fmt.Errorf("Cannot build the cert chain") 463 } 464 var certChain tls.Certificate 465 //certPEMBlock := []byte(rootcaBundle) 466 var certDERBlock *pem.Block 467 for { 468 certDERBlock, pemCert = pem.Decode(pemCert) 469 if certDERBlock == nil { 470 break 471 } 472 if certDERBlock.Type == typeCertificate { 473 certChain.Certificate = append(certChain.Certificate, certDERBlock.Bytes) 474 } 475 } 476 certChain.PrivateKey = p.cert.PrivateKey 477 //certChain.Certificate 478 return &certChain, nil 479 } 480 return p.cert, nil 481 } 482 return nil, fmt.Errorf("no cert available - cert is nil") 483 } 484 if p.cert != nil { 485 return p.cert, nil 486 } 487 return nil, fmt.Errorf("no cert available - cert is nil") 488 } 489 490 func buildCertChain(certPEM, caPEM []byte) ([]byte, error) { 491 zap.L().Debug("http: BEFORE in buildCertChain certPEM", zap.String("certPEM", string(certPEM)), zap.String("caPEM", string(caPEM))) 492 certChain := []*x509.Certificate{} 493 clientPEMBlock := certPEM 494 495 derBlock, _ := pem.Decode(clientPEMBlock) 496 if derBlock != nil { 497 if derBlock.Type == typeCertificate { 498 cert, err := x509.ParseCertificate(derBlock.Bytes) 499 if err != nil { 500 return nil, err 501 } 502 certChain = append(certChain, cert) 503 } else { 504 return nil, fmt.Errorf("invalid pem block type: %s", derBlock.Type) 505 } 506 } 507 var certDERBlock *pem.Block 508 for { 509 certDERBlock, caPEM = pem.Decode(caPEM) 510 if certDERBlock == nil { 511 break 512 } 513 if certDERBlock.Type == typeCertificate { 514 cert, err := x509.ParseCertificate(certDERBlock.Bytes) 515 if err != nil { 516 return nil, err 517 } 518 certChain = append(certChain, cert) 519 } else { 520 return nil, fmt.Errorf("invalid pem block type: %s", certDERBlock.Type) 521 } 522 } 523 by, _ := x509CertChainToPem(certChain) 524 zap.L().Debug("http: After building the cert chain", zap.String("certChain", string(by))) 525 return x509CertChainToPem(certChain) 526 } 527 528 // x509CertChainToPem converts chain of x509 certs to byte. 529 func x509CertChainToPem(certChain []*x509.Certificate) ([]byte, error) { 530 var pemBytes bytes.Buffer 531 for _, cert := range certChain { 532 if err := pem.Encode(&pemBytes, &pem.Block{Type: typeCertificate, Bytes: cert.Raw}); err != nil { 533 return nil, err 534 } 535 } 536 return pemBytes.Bytes(), nil 537 } 538 539 // x509CertToPem converts x509 to byte. 540 func x509CertToPem(cert *x509.Certificate) ([]byte, error) { 541 var pemBytes bytes.Buffer 542 if err := pem.Encode(&pemBytes, &pem.Block{Type: typeCertificate, Bytes: cert.Raw}); err != nil { 543 return nil, err 544 } 545 return pemBytes.Bytes(), nil 546 } 547 func (p *Config) processAppRequest(w http.ResponseWriter, r *http.Request) { 548 549 zap.L().Debug("Processing Application Request", zap.String("URI", r.RequestURI), zap.String("Host", r.Host)) 550 originalDestination := r.Context().Value(http.LocalAddrContextKey).(*net.TCPAddr) 551 552 // Authorize the request by calling the authorizer library. 553 authRequest := &apiauth.Request{ 554 OriginalDestination: originalDestination, 555 Method: r.Method, 556 URL: r.URL, 557 RequestURI: r.RequestURI, 558 } 559 560 resp, err := p.auth.ApplicationRequest(authRequest) 561 if err != nil { 562 if resp.PUContext != nil { 563 state := flowstats.NewAppConnectionState(p.puContext, r, authRequest, resp) 564 state.Stats.Action = resp.Action 565 state.Stats.PolicyID = resp.NetworkPolicyID 566 p.collector.CollectFlowEvent(state.Stats) 567 } 568 http.Error(w, err.Error(), err.(*apiauth.AuthError).Status()) 569 return 570 } 571 572 state := flowstats.NewAppConnectionState(p.puContext, r, authRequest, resp) 573 if resp.External { 574 defer p.collector.CollectFlowEvent(state.Stats) 575 } 576 577 if resp.HookMethod != "" { 578 if hook, ok := p.hooks[resp.HookMethod]; ok { 579 if isHook, err := hook(w, r); err != nil || isHook { 580 if err != nil { 581 state.Stats.Action = policy.Reject 582 state.Stats.DropReason = collector.PolicyDrop 583 } 584 return 585 } 586 } else { 587 http.Error(w, "Invalid hook configuration", http.StatusInternalServerError) 588 return 589 } 590 } 591 592 httpScheme := "http://" 593 if resp.TLSListener { 594 httpScheme = "https://" 595 } 596 597 // Create the new target URL based on the Host parameter that we had. 598 r.URL, err = url.ParseRequestURI(httpScheme + r.Host) 599 if err != nil { 600 http.Error(w, "Invalid destination host name", http.StatusUnprocessableEntity) 601 return 602 } 603 604 // Add the headers with the authorization parameters and public key. The other side 605 // must validate our public key. 606 p.RLock() 607 r.Header.Add("X-APORETO-KEY", string(p.secrets.TransmittedKey())) 608 p.RUnlock() 609 r.Header.Add("X-APORETO-AUTH", resp.Token) 610 611 contextWithStats := context.WithValue(r.Context(), statsContextKey, state) 612 // Forward the request. 613 p.fwdTLS.ServeHTTP(w, r.WithContext(contextWithStats)) 614 } 615 616 func (p *Config) processNetRequest(w http.ResponseWriter, r *http.Request) { 617 618 zap.L().Debug("Processing Network Request", zap.String("URI", r.RequestURI), zap.String("Host", r.Host)) 619 originalDestination := r.Context().Value(http.LocalAddrContextKey).(*net.TCPAddr) 620 621 sourceAddress, err := net.ResolveTCPAddr("tcp", r.RemoteAddr) 622 if err != nil { 623 zap.L().Error("Internal server error - cannot determine source address information", zap.Error(err)) 624 http.Error(w, "Invalid network information", http.StatusForbidden) 625 return 626 } 627 628 requestCookie, _ := r.Cookie("X-APORETO-AUTH") // nolint errcheck 629 630 pr := &collector.PingReport{} 631 632 request := &apiauth.Request{ 633 OriginalDestination: originalDestination, 634 SourceAddress: sourceAddress, 635 Header: r.Header, 636 URL: r.URL, 637 Method: r.Method, 638 RequestURI: r.RequestURI, 639 Cookie: requestCookie, 640 TLS: r.TLS, 641 } 642 643 response, err := p.auth.NetworkRequest(r.Context(), request) 644 645 var userID string 646 if response != nil && len(response.UserAttributes) > 0 { 647 userData := &collector.UserRecord{ 648 Namespace: response.Namespace, 649 Claims: response.UserAttributes, 650 } 651 p.collector.CollectUserEvent(userData) 652 userID = userData.ID 653 } 654 655 state := flowstats.NewNetworkConnectionState(p.puContext, userID, request, response) 656 defer func() { 657 if response != nil && response.PingConfig != nil { 658 pr.PingID = response.PingConfig.PingID 659 pr.IterationID = response.PingConfig.IterationID 660 pr.Type = gaia.PingProbeTypeRequest 661 pr.RemotePUID = response.SourcePUID 662 pr.PUID = response.PUContext.ManagementID() 663 pr.Namespace = response.Namespace 664 pr.PayloadSize = response.PingConfig.PayloadSize 665 pr.PayloadSizeType = gaia.PingProbePayloadSizeTypeReceived 666 pr.Protocol = 6 667 pr.ServiceType = "L7" 668 pr.FourTuple = fmt.Sprintf("%s:%s:%d:%d", 669 sourceAddress.IP.String(), 670 originalDestination.IP.String(), 671 sourceAddress.Port, 672 originalDestination.Port) 673 pr.PolicyID = response.NetworkPolicyID 674 pr.PolicyAction = response.Action 675 pr.ServiceID = response.ServiceID 676 pr.AgentVersion = p.agentVersion.String() 677 pr.RemoteEndpointType = collector.EndPointTypePU 678 pr.IsServer = true 679 pr.Claims = response.PingConfig.Claims 680 pr.ClaimsType = gaia.PingProbeClaimsTypeReceived 681 pr.RemoteNamespaceType = gaia.PingProbeRemoteNamespaceTypePlain 682 pr.TargetTCPNetworks = true 683 pr.ExcludedNetworks = false 684 685 if len(r.TLS.PeerCertificates) > 0 { 686 if len(r.TLS.PeerCertificates[0].Subject.Organization) > 0 { 687 pr.RemoteNamespace = r.TLS.PeerCertificates[0].Subject.Organization[0] 688 } 689 pr.PeerCertIssuer = r.TLS.PeerCertificates[0].Issuer.String() 690 pr.PeerCertSubject = r.TLS.PeerCertificates[0].Subject.String() 691 pr.PeerCertExpiry = r.TLS.PeerCertificates[0].NotAfter 692 693 if found, controller := pcommon.ExtractExtension(x509extensions.Controller(), r.TLS.PeerCertificates[0].Extensions); found { 694 pr.RemoteController = string(controller) 695 } 696 } 697 698 p.collector.CollectPingEvent(pr) 699 } else { 700 p.collector.CollectFlowEvent(state.Stats) 701 } 702 }() 703 704 if err != nil { 705 706 zap.L().Debug("Authorization error", 707 zap.Error(err), 708 zap.String("URI", r.RequestURI), 709 zap.String("Host", r.Host), 710 ) 711 authError, ok := err.(*apiauth.AuthError) 712 if !ok { 713 http.Error(w, "Internal type error", http.StatusInternalServerError) 714 return 715 } 716 717 if response == nil { 718 // Basic errors are captured here. 719 http.Error(w, authError.Message(), authError.Status()) 720 return 721 } 722 723 if response.PingConfig != nil { 724 pr.Error = response.DropReason 725 } 726 727 if !response.Redirect { 728 // If there is no redirect, we also return an error. 729 http.Error(w, authError.Message(), authError.Status()) 730 return 731 } 732 733 // Redirect logic. Populate information here. This is forcing a 734 // redirect rather than an error. 735 if response.Cookie != nil { 736 http.SetCookie(w, response.Cookie) 737 } 738 w.Header().Add("Location", response.RedirectURI) 739 http.Error(w, response.Data, authError.Status()) 740 741 return 742 } 743 744 // Select as http or https for communication with listening service. 745 httpPrefix := "http://" 746 if response.TLSListener { 747 httpPrefix = "https://" 748 } 749 750 // Create the target URI. Websocket Gorilla proxy takes it from the URL. For normal 751 // connections we don't want that. 752 if forward.IsWebsocketRequest(r) { 753 r.URL, err = url.ParseRequestURI(httpPrefix + originalDestination.String()) 754 } else { 755 r.URL, err = url.ParseRequestURI(httpPrefix + r.Host) 756 } 757 if err != nil { 758 state.Stats.DropReason = collector.InvalidFormat 759 http.Error(w, fmt.Sprintf("Invalid HTTP Host parameter: %s", err), http.StatusBadRequest) 760 return 761 } 762 763 // Update the request headers with the user attributes as defined by the mappings 764 r.Header = response.Header 765 766 // Update the statistics and forward the request. We always encrypt downstream 767 state.Stats.Action = policy.Accept | policy.Encrypt | policy.Log 768 769 // // Treat the remote proxy scenario where the destination IPs are in a remote 770 // // host. Check of network rules that allow this transfer and report the corresponding 771 // // flows. 772 // if _, ok := p.localIPs[originalDestination.IP.String()]; !ok { 773 // _, action, err := pctx.PUContext.ApplicationACLPolicyFromAddr(originalDestination.IP, uint16(originalDestination.Port)) 774 // if err != nil || action.Action.Rejected() { 775 // defer p.collector.CollectFlowEvent(reportDownStream(state.stats, action)) 776 // http.Error(w, fmt.Sprintf("Access denied by network policy to downstream IP: %s", originalDestination.IP.String()), http.StatusNetworkAuthenticationRequired) 777 // return 778 // } 779 // if action.Action.Accepted() { 780 // defer p.collector.CollectFlowEvent(reportDownStream(state.stats, action)) 781 // } 782 // } 783 784 contextWithStats := context.WithValue(r.Context(), statsContextKey, state) 785 p.fwd.ServeHTTP(w, r.WithContext(contextWithStats)) 786 zap.L().Debug("Forwarding Request", zap.String("URI", r.RequestURI), zap.String("Host", r.Host)) 787 } 788 789 func (p *Config) policyHook(w http.ResponseWriter, r *http.Request) (bool, error) { 790 if r.Header.Get(common.MetadataKey) != common.MetadataValue { 791 http.Error(w, "unauthorized request for policy", http.StatusForbidden) 792 return true, fmt.Errorf("unauthorized") 793 } 794 795 data, _, err := p.metadata.GetCurrentPolicy() 796 if err != nil { 797 http.Error(w, "Unable to retrieve current policy", http.StatusInternalServerError) 798 return true, err 799 } 800 if _, err := w.Write(data); err != nil { 801 zap.L().Error("Unable to write policy response") 802 } 803 804 return true, nil 805 } 806 807 func (p *Config) certificateHook(w http.ResponseWriter, r *http.Request) (bool, error) { 808 if r.Header.Get(common.MetadataKey) != common.MetadataValue { 809 http.Error(w, "unauthorized request for certificate", http.StatusForbidden) 810 return true, fmt.Errorf("unauthorized") 811 } 812 813 if _, err := w.Write(p.metadata.GetCertificate()); err != nil { 814 zap.L().Error("Unable to write response") 815 } 816 817 return true, nil 818 } 819 820 func (p *Config) keyHook(w http.ResponseWriter, r *http.Request) (bool, error) { 821 if r.Header.Get(common.MetadataKey) != common.MetadataValue { 822 http.Error(w, "unauthorized request for private key", http.StatusForbidden) 823 return true, fmt.Errorf("unauthorized") 824 } 825 826 if _, err := w.Write(p.metadata.GetPrivateKey()); err != nil { 827 zap.L().Error("Unable to write response") 828 } 829 830 return true, nil 831 } 832 833 func (p *Config) healthHook(w http.ResponseWriter, r *http.Request) (bool, error) { 834 835 // Health hook will only return ok if the current policy is already populated. 836 plc, _, err := p.metadata.GetCurrentPolicy() 837 if err != nil || plc == nil { 838 http.Error(w, "Unable to retrieve current policy", http.StatusInternalServerError) 839 return true, err 840 } 841 842 if _, err := w.Write([]byte("OK\n")); err != nil { 843 zap.L().Error("Unable to write response to health API") 844 } 845 return true, nil 846 } 847 848 func (p *Config) tokenHook(w http.ResponseWriter, r *http.Request) (bool, error) { 849 850 if r.Header.Get(common.MetadataKey) != common.MetadataValue { 851 http.Error(w, "unauthorized request for token", http.StatusForbidden) 852 return true, fmt.Errorf("unauthorized") 853 } 854 855 audience := r.URL.Query().Get("audience") 856 validityString := r.URL.Query().Get("validity") 857 858 validity := time.Minute * 60 859 var err error 860 if validityString != "" { 861 validity, err = time.ParseDuration(validityString) 862 if err != nil { 863 http.Error(w, "Invalid validity time requested. Please use notation of number+unit. Example: `10m`", http.StatusUnprocessableEntity) 864 return true, nil 865 } 866 } 867 868 token, err := p.tokenIssuer.Issue(r.Context(), p.puContext, common.ServiceTokenTypeOAUTH, audience, validity) 869 if err != nil { 870 http.Error(w, fmt.Sprintf("Unable to issue token: %s", err), http.StatusBadRequest) 871 return true, nil 872 } 873 874 if _, err := w.Write([]byte(token)); err != nil { 875 zap.L().Error("Unable to write response on token API") 876 } 877 return true, nil 878 } 879 880 func (p *Config) awsInfoHook(w http.ResponseWriter, r *http.Request) (bool, error) { 881 882 if err := validateAWSHeaders(r); err != nil { 883 http.Error(w, fmt.Sprintf("invalid user agent: %s", err), http.StatusForbidden) 884 return true, err 885 } 886 887 awsRole, id, err := p.awsRole() 888 if err != nil { 889 return true, err 890 } 891 892 type info struct { 893 Code string `json:"Code,omitempty"` 894 LastUpdated time.Time `json:"LastUpdated,omitempty"` 895 InstanceProfileArn string `json:"InstanceProfileArn,omitempty"` 896 InstanceProfileID string `json:"InstanceProfileId,omitempty"` 897 } 898 899 out := &info{ 900 Code: "Success", 901 LastUpdated: time.Now(), 902 InstanceProfileArn: awsRole, 903 InstanceProfileID: id, 904 } 905 906 data, err := json.MarshalIndent(out, " ", " ") 907 if err != nil { 908 return true, fmt.Errorf("error in marshall of info: %s", err) 909 } 910 911 if _, err = w.Write(data); err != nil { 912 return true, fmt.Errorf("unable to write data response: %s", err) 913 } 914 915 return true, nil 916 } 917 918 func (p *Config) awsTokenHook(w http.ResponseWriter, r *http.Request) (bool, error) { 919 920 if err := validateAWSHeaders(r); err != nil { 921 http.Error(w, fmt.Sprintf("invalid user agent: %s", err), http.StatusForbidden) 922 return true, err 923 } 924 925 awsRole, id, err := p.awsRole() 926 if err != nil { 927 return true, err 928 } 929 930 awsRoleParts := strings.Split(awsRole, "/") 931 if len(awsRoleParts) == 0 { 932 http.Error(w, fmt.Sprintf("invalid role: %s", err), http.StatusNotFound) 933 return true, fmt.Errorf("invalid role: %s", awsRole) 934 } 935 936 awsRoleName := awsRoleParts[len(awsRoleParts)-1] 937 938 if strings.HasSuffix(r.RequestURI, "security-credentials/") { 939 if _, err := w.Write([]byte(awsRoleName)); err != nil { 940 return true, err 941 } 942 return true, nil 943 } 944 945 if !strings.HasSuffix(r.RequestURI, "security-credentials/"+awsRoleName) { 946 http.Error(w, "not found", http.StatusNotFound) 947 return true, fmt.Errorf("not found") 948 } 949 950 token, err := p.tokenIssuer.Issue(r.Context(), id, common.ServiceTokenTypeAWS, awsRole, time.Hour) 951 if err != nil { 952 http.Error(w, fmt.Sprintf("Unable to issue token: %s", err), http.StatusBadRequest) 953 return true, nil 954 } 955 956 if _, err := w.Write([]byte(token)); err != nil { 957 zap.L().Error("Unable to write response on token API") 958 } 959 return true, nil 960 } 961 962 func (p *Config) awsRole() (string, string, error) { 963 964 _, plc, err := p.metadata.GetCurrentPolicy() 965 if err != nil { 966 return "", "", err 967 } 968 969 awsRole := "" 970 for _, scope := range plc.Scopes { 971 if strings.HasPrefix(scope, common.AWSRoleARNPrefix) { 972 if awsRole != "" && awsRole != scope[len(common.AWSRolePrefix):] { 973 return "", "", fmt.Errorf("overlapping roles detected") 974 } 975 awsRole = scope[len(common.AWSRolePrefix):] 976 } 977 } 978 979 if awsRole == "" { 980 return "", "", fmt.Errorf("role not found") 981 } 982 983 return awsRole, plc.ManagementID, nil 984 } 985 986 var ( 987 allowedAgents = []string{"aws-cli/", "aws-chalice/", "Boto3/", "Botocore/", "aws-sdk-"} 988 ) 989 990 func validateAWSHeaders(r *http.Request) error { 991 992 userAgent, ok := r.Header["User-Agent"] 993 if !ok { 994 return fmt.Errorf("no user-agent provided") 995 } 996 997 for _, u := range userAgent { 998 for _, t := range allowedAgents { 999 if strings.HasPrefix(u, t) { 1000 return nil 1001 } 1002 } 1003 } 1004 1005 return fmt.Errorf("invalid user agent: %v", userAgent) 1006 } 1007 1008 // func reportDownStream(record *collector.FlowRecord, action *policy.FlowPolicy) *collector.FlowRecord { 1009 // return &collector.FlowRecord{ 1010 // ContextID: record.ContextID, 1011 // Destination: &collector.EndPoint{ 1012 // URI: record.Destination.URI, 1013 // HTTPMethod: record.Destination.HTTPMethod, 1014 // Type: collector.EndPointTypeExternalIP, 1015 // Port: record.Destination.Port, 1016 // IP: record.Destination.IP, 1017 // ID: action.ServiceID, 1018 // }, 1019 // Source: &collector.EndPoint{ 1020 // Type: record.Destination.Type, 1021 // ID: record.Destination.ID, 1022 // IP: "0.0.0.0", 1023 // }, 1024 // Action: action.Action, 1025 // L4Protocol: record.L4Protocol, 1026 // ServiceType: record.ServiceType, 1027 // ServiceID: record.ServiceID, 1028 // Tags: record.Tags, 1029 // PolicyID: action.PolicyID, 1030 // Count: 1, 1031 // } 1032 // }