github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/applicationproxy/http/http.go (about)

     1  package httpproxy
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/tls"
     7  	"crypto/x509"
     8  	"encoding/json"
     9  	"encoding/pem"
    10  	"fmt"
    11  	"net"
    12  	"net/http"
    13  	"net/url"
    14  	"strings"
    15  	"sync"
    16  	"time"
    17  
    18  	"github.com/blang/semver"
    19  	jwt "github.com/dgrijalva/jwt-go"
    20  	"github.com/vulcand/oxy/forward"
    21  	"go.aporeto.io/enforcerd/trireme-lib/collector"
    22  	"go.aporeto.io/enforcerd/trireme-lib/common"
    23  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/apiauth"
    24  	pcommon "go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/common"
    25  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/markedconn"
    26  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/protomux"
    27  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/serviceregistry"
    28  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/tlshelper"
    29  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/flowstats"
    30  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/metadata"
    31  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/ephemeralkeys"
    32  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/bufferpool"
    33  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/secrets"
    34  	"go.aporeto.io/enforcerd/trireme-lib/policy"
    35  	"go.aporeto.io/gaia"
    36  	"go.aporeto.io/gaia/x509extensions"
    37  	"go.uber.org/zap"
    38  )
    39  
    40  type statsContextKeyType string
    41  
    42  const (
    43  	statsContextKey = statsContextKeyType("statsContext")
    44  
    45  	// TriremeOIDCCallbackURI is the callback URI that must be presented by
    46  	// any OIDC provider.
    47  	TriremeOIDCCallbackURI = "/aporeto/oidc/callback"
    48  	typeCertificate        = "CERTIFICATE"
    49  )
    50  
    51  // JWTClaims is the structure of the claims we are sending on the wire.
    52  type JWTClaims struct {
    53  	jwt.StandardClaims
    54  	SourceID string
    55  	Scopes   []string
    56  	Profile  []string
    57  }
    58  
// hookFunc is the signature of the handlers registered in Config.hooks.
// processAppRequest stops processing the request (it is not forwarded) when
// the hook returns true or a non-nil error; on a non-nil error the flow is
// additionally marked as rejected.
type hookFunc func(w http.ResponseWriter, r *http.Request) (bool, error)
    60  
    61  // Config maintains state for proxies connections from listen to backend.
    62  type Config struct {
    63  	cert             *tls.Certificate
    64  	ca               *x509.CertPool
    65  	keyPEM           string
    66  	certPEM          string
    67  	secrets          secrets.Secrets
    68  	datapathKeyPair  ephemeralkeys.KeyAccessor
    69  	collector        collector.EventCollector
    70  	puContext        string
    71  	localIPs         map[string]struct{}
    72  	applicationProxy bool
    73  	mark             int
    74  	server           *http.Server
    75  	fwd              *forward.Forwarder
    76  	fwdTLS           *forward.Forwarder
    77  	tlsClientConfig  *tls.Config
    78  	auth             *apiauth.Processor
    79  	metadata         *metadata.Client
    80  	tokenIssuer      common.ServiceTokenIssuer
    81  	hooks            map[string]hookFunc
    82  	agentVersion     semver.Version
    83  
    84  	sync.RWMutex
    85  }
    86  
    87  // NewHTTPProxy creates a new instance of proxy reate a new instance of Proxy
    88  func NewHTTPProxy(
    89  	c collector.EventCollector,
    90  	puContext string,
    91  	caPool *x509.CertPool,
    92  	applicationProxy bool,
    93  	mark int,
    94  	secrets secrets.Secrets,
    95  	tokenIssuer common.ServiceTokenIssuer,
    96  	datapathKeyPair ephemeralkeys.KeyAccessor,
    97  	agentVersion semver.Version,
    98  ) *Config {
    99  
   100  	h := &Config{
   101  		collector:        c,
   102  		puContext:        puContext,
   103  		ca:               caPool,
   104  		applicationProxy: applicationProxy,
   105  		mark:             mark,
   106  		secrets:          secrets,
   107  		localIPs:         markedconn.GetInterfaces(),
   108  		tlsClientConfig: &tls.Config{
   109  			RootCAs: caPool,
   110  		},
   111  		auth:            apiauth.New(puContext, secrets),
   112  		metadata:        metadata.NewClient(puContext, tokenIssuer),
   113  		tokenIssuer:     tokenIssuer,
   114  		datapathKeyPair: datapathKeyPair,
   115  		agentVersion:    agentVersion,
   116  	}
   117  
   118  	hooks := map[string]hookFunc{
   119  		common.MetadataHookPolicy:      h.policyHook,
   120  		common.MetadataHookHealth:      h.healthHook,
   121  		common.MetadataHookCertificate: h.certificateHook,
   122  		common.MetadataHookKey:         h.keyHook,
   123  		common.MetadataHookToken:       h.tokenHook,
   124  		common.AWSHookInfo:             h.awsInfoHook,
   125  		common.AWSHookRole:             h.awsTokenHook,
   126  	}
   127  
   128  	h.hooks = hooks
   129  
   130  	return h
   131  }
   132  
   133  // clientTLSConfiguration calculates the right certificates and requests to the clients.
   134  func (p *Config) clientTLSConfiguration(conn net.Conn, originalConfig *tls.Config) (*tls.Config, error) {
   135  	if mconn, ok := conn.(*markedconn.ProxiedConnection); ok {
   136  		ip, port := mconn.GetOriginalDestination()
   137  		portContext, err := serviceregistry.Instance().RetrieveExposedServiceContext(ip, port, "")
   138  		if err != nil {
   139  			return nil, fmt.Errorf("Unknown service: %s", err)
   140  		}
   141  		if portContext.Service.UserAuthorizationType == policy.UserAuthorizationMutualTLS || portContext.Service.UserAuthorizationType == policy.UserAuthorizationJWT {
   142  			clientCAs := p.ca
   143  			// now append the User given CA certPool
   144  			if portContext.ClientTrustedRoots != nil {
   145  				// append only when the certpool is given
   146  				if len(portContext.Service.MutualTLSTrustedRoots) > 0 {
   147  					if !clientCAs.AppendCertsFromPEM(portContext.Service.MutualTLSTrustedRoots) {
   148  						return nil, fmt.Errorf("Unable to process client CAs")
   149  					}
   150  				}
   151  			}
   152  			config := p.newBaseTLSConfig()
   153  			config.ClientAuth = tls.VerifyClientCertIfGiven
   154  			config.ClientCAs = clientCAs
   155  			return config, nil
   156  		}
   157  		return originalConfig, nil
   158  	}
   159  	return nil, fmt.Errorf("Invalid connection")
   160  }
   161  
   162  // newBaseTLSConfig creates the new basic TLS configuration for the server.
   163  func (p *Config) newBaseTLSConfig() *tls.Config {
   164  	c := tlshelper.NewBaseTLSServerConfig()
   165  	c.NextProtos = []string{"h2"}
   166  	c.GetCertificate = p.GetCertificateFunc
   167  	c.ClientCAs = p.ca
   168  	return c
   169  }
   170  
   171  // newBaseTLSClientConfig creates the new basic TLS configuration for the client.
   172  func (p *Config) newBaseTLSClientConfig() *tls.Config {
   173  	c := tlshelper.NewBaseTLSClientConfig()
   174  	c.NextProtos = []string{"h2"}
   175  	c.GetCertificate = p.GetCertificateFunc
   176  	c.GetClientCertificate = p.GetClientCertificateFunc
   177  	return c
   178  }
   179  
   180  // GetClientCertificateFunc returns the certificate that will be used by the Proxy as a client during the TLS
   181  func (p *Config) GetClientCertificateFunc(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
   182  	p.RLock()
   183  	defer p.RUnlock()
   184  	if p.cert != nil {
   185  		cert, err := x509.ParseCertificate(p.cert.Certificate[0])
   186  		if err != nil {
   187  			zap.L().Error("http: Cannot build the cert chain")
   188  		}
   189  		if cert != nil {
   190  			by, _ := x509CertToPem(cert)
   191  			pemCert, err := buildCertChain(by, p.secrets.CertAuthority())
   192  			if err != nil {
   193  				zap.L().Error("http: Cannot build the cert chain")
   194  			}
   195  			var certChain tls.Certificate
   196  			var certDERBlock *pem.Block
   197  			for {
   198  				certDERBlock, pemCert = pem.Decode(pemCert)
   199  				if certDERBlock == nil {
   200  					break
   201  				}
   202  				if certDERBlock.Type == typeCertificate {
   203  					certChain.Certificate = append(certChain.Certificate, certDERBlock.Bytes)
   204  				}
   205  			}
   206  			certChain.PrivateKey = p.cert.PrivateKey
   207  			return &certChain, nil
   208  		}
   209  		return p.cert, nil
   210  	}
   211  	return nil, nil
   212  }
   213  
   214  // RunNetworkServer runs an HTTP network server. If TLS is needed, the
   215  // listener should be already a TLS listener.
   216  func (p *Config) RunNetworkServer(ctx context.Context, l net.Listener, encrypted bool) error {
   217  	p.Lock()
   218  	defer p.Unlock()
   219  
   220  	if p.server != nil {
   221  		return fmt.Errorf("Server already running")
   222  	}
   223  
   224  	// for usage by callbacks below
   225  	protoListener, _ := l.(*protomux.ProtoListener)
   226  
   227  	// If its an encrypted, wrap the listener in a TLS context. This is activated
   228  	// for the listener from the network, but not for the listener from a PU.
   229  	if encrypted {
   230  		config := p.newBaseTLSConfig()
   231  		config.GetConfigForClient = func(helloMsg *tls.ClientHelloInfo) (*tls.Config, error) {
   232  			p.RLock()
   233  			defer p.RUnlock()
   234  			return p.clientTLSConfiguration(helloMsg.Conn, config)
   235  		}
   236  		config.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
   237  			p.RLock()
   238  			defer p.RUnlock()
   239  			return p.cert, nil
   240  		}
   241  		l = tls.NewListener(l, config)
   242  	}
   243  	// now create a client config, this is required if Aporeto is a client.
   244  	p.tlsClientConfig = p.newBaseTLSClientConfig()
   245  
   246  	reportStats := func(ctx context.Context) {
   247  		if state := ctx.Value(statsContextKey); state != nil {
   248  			if r, ok := state.(*flowstats.ConnectionState); ok {
   249  				r.Stats.Action = policy.Reject | policy.Log
   250  				r.Stats.DropReason = collector.UnableToDial
   251  				r.Stats.PolicyID = collector.DefaultEndPoint
   252  				p.collector.CollectFlowEvent(r.Stats)
   253  			}
   254  		}
   255  	}
   256  
   257  	networkDialerWithContext := func(ctx context.Context, network, _ string) (net.Conn, error) {
   258  		raddr, ok := ctx.Value(http.LocalAddrContextKey).(*net.TCPAddr)
   259  		if !ok {
   260  			reportStats(ctx)
   261  			return nil, fmt.Errorf("invalid destination address")
   262  		}
   263  		var platformData *markedconn.PlatformData
   264  		if protoListener != nil {
   265  			platformData = markedconn.TakePlatformData(protoListener.Listener, raddr.IP, raddr.Port)
   266  		}
   267  		conn, err := markedconn.DialMarkedWithContext(ctx, "tcp", raddr.String(), platformData, p.mark)
   268  		if err != nil {
   269  			reportStats(ctx)
   270  			return nil, fmt.Errorf("Failed to dial remote: %s", err)
   271  		}
   272  		return conn, nil
   273  	}
   274  
   275  	appDialerWithContext := func(ctx context.Context, network, _ string) (net.Conn, error) {
   276  		raddr, ok := ctx.Value(http.LocalAddrContextKey).(*net.TCPAddr)
   277  		if !ok {
   278  			reportStats(ctx)
   279  			return nil, fmt.Errorf("invalid destination address")
   280  		}
   281  		pctx, err := serviceregistry.Instance().RetrieveExposedServiceContext(raddr.IP, raddr.Port, "")
   282  		if err != nil {
   283  			return nil, err
   284  		}
   285  		raddr.Port = pctx.TargetPort
   286  		var platformData *markedconn.PlatformData
   287  		if protoListener != nil {
   288  			platformData = markedconn.TakePlatformData(protoListener.Listener, raddr.IP, raddr.Port)
   289  		}
   290  		conn, err := markedconn.DialMarkedWithContext(ctx, "tcp", raddr.String(), platformData, p.mark)
   291  		if err != nil {
   292  			reportStats(ctx)
   293  			return nil, fmt.Errorf("Failed to dial remote: %s", err)
   294  		}
   295  		return conn, nil
   296  	}
   297  
   298  	// Dial functions for the websockets.
   299  	netDial := func(network, addr string) (net.Conn, error) {
   300  		raddr, err := net.ResolveTCPAddr(network, addr)
   301  		if err != nil {
   302  			reportStats(ctx)
   303  			return nil, err
   304  		}
   305  		var platformData *markedconn.PlatformData
   306  		if protoListener != nil {
   307  			platformData = markedconn.TakePlatformData(protoListener.Listener, raddr.IP, raddr.Port)
   308  		}
   309  		conn, err := markedconn.DialMarkedWithContext(ctx, "tcp", raddr.String(), platformData, p.mark)
   310  		if err != nil {
   311  			reportStats(ctx)
   312  			return nil, fmt.Errorf("Failed to dial remote: %s", err)
   313  		}
   314  		return conn, nil
   315  	}
   316  
   317  	appDial := func(network, addr string) (net.Conn, error) {
   318  		raddr, err := net.ResolveTCPAddr(network, addr)
   319  		if err != nil {
   320  			reportStats(ctx)
   321  			return nil, err
   322  		}
   323  		pctx, err := serviceregistry.Instance().RetrieveExposedServiceContext(raddr.IP, raddr.Port, "")
   324  		if err != nil {
   325  			return nil, err
   326  		}
   327  		raddr.Port = pctx.TargetPort
   328  		var platformData *markedconn.PlatformData
   329  		if protoListener != nil {
   330  			platformData = markedconn.TakePlatformData(protoListener.Listener, raddr.IP, raddr.Port)
   331  		}
   332  		conn, err := markedconn.DialMarkedWithContext(ctx, "tcp", raddr.String(), platformData, p.mark)
   333  		if err != nil {
   334  			reportStats(ctx)
   335  			return nil, fmt.Errorf("Failed to dial remote: %s", err)
   336  		}
   337  		return conn, nil
   338  	}
   339  
   340  	// Create an encrypted downstream transport. We will mark the downstream connection
   341  	// to let the iptables rule capture it.
   342  	encryptedTransport := &http.Transport{
   343  		TLSClientConfig:     p.tlsClientConfig,
   344  		DialContext:         networkDialerWithContext,
   345  		MaxIdleConnsPerHost: 2000,
   346  		MaxIdleConns:        2000,
   347  		ForceAttemptHTTP2:   true,
   348  	}
   349  
   350  	// Create an unencrypted transport for talking to the application. If encryption
   351  	// is selected do not verify the certificates. This is supposed to be inside the
   352  	// same system. TODO: use pinned certificates.
   353  	transport := &http.Transport{
   354  		TLSClientConfig: &tls.Config{
   355  			InsecureSkipVerify: true,
   356  			GetClientCertificate: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) { // nolint
   357  				p.RLock()
   358  				defer p.RUnlock()
   359  				return p.cert, nil
   360  			},
   361  		},
   362  		DialContext:         appDialerWithContext,
   363  		MaxIdleConns:        2000,
   364  		MaxIdleConnsPerHost: 2000,
   365  	}
   366  
   367  	// Create the proxies downwards the network and the application.
   368  	var err error
   369  	p.fwdTLS, err = forward.New(
   370  		forward.RoundTripper(encryptedTransport),
   371  		forward.WebsocketTLSClientConfig(&tls.Config{RootCAs: p.ca}),
   372  		forward.WebSocketNetDial(netDial),
   373  		forward.BufferPool(bufferpool.NewPool(32*1204)),
   374  		forward.ErrorHandler(TriremeHTTPErrHandler{}),
   375  	)
   376  	if err != nil {
   377  		return fmt.Errorf("Cannot initialize encrypted transport: %s", err)
   378  	}
   379  
   380  	p.fwd, err = forward.New(
   381  		forward.RoundTripper(NewTriremeRoundTripper(transport)),
   382  		forward.WebsocketTLSClientConfig(&tls.Config{InsecureSkipVerify: true}),
   383  		forward.WebSocketNetDial(appDial),
   384  		forward.BufferPool(bufferpool.NewPool(32*1204)),
   385  		forward.ErrorHandler(TriremeHTTPErrHandler{}),
   386  	)
   387  	if err != nil {
   388  		return fmt.Errorf("Cannot initialize unencrypted transport: %s", err)
   389  	}
   390  
   391  	processor := p.processAppRequest
   392  	if !p.applicationProxy {
   393  		processor = p.processNetRequest
   394  	}
   395  
   396  	p.server = &http.Server{
   397  		Handler: http.HandlerFunc(processor),
   398  	}
   399  
   400  	go func() {
   401  		<-ctx.Done()
   402  		p.server.Close() // nolint
   403  	}()
   404  	go p.server.Serve(l) // nolint
   405  
   406  	return nil
   407  }
   408  
   409  // ShutDown terminates the server.
   410  func (p *Config) ShutDown() error {
   411  	return p.server.Close()
   412  }
   413  
   414  // UpdateSecrets updates the secrets
   415  func (p *Config) UpdateSecrets(cert *tls.Certificate, caPool *x509.CertPool, s secrets.Secrets, certPEM, keyPEM string) {
   416  	p.Lock()
   417  	p.cert = cert
   418  	p.ca = caPool
   419  	p.secrets = s
   420  	p.certPEM = certPEM
   421  	p.keyPEM = keyPEM
   422  	p.tlsClientConfig.RootCAs = caPool
   423  	p.Unlock()
   424  
   425  	p.metadata.UpdateSecrets([]byte(certPEM), []byte(keyPEM))
   426  	p.auth.UpdateSecrets(s)
   427  }
   428  
   429  // GetCertificateFunc implements the TLS interface for getting the certificate. This
   430  // allows us to update the certificates of the connection on the fly.
   431  func (p *Config) GetCertificateFunc(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
   432  	p.RLock()
   433  	defer p.RUnlock()
   434  	// First we check if this is a direct access to the public port. In this case
   435  	// we will use the service public certificate. Otherwise, we will return the
   436  	// enforcer certificate since this is internal access.
   437  	if mconn, ok := clientHello.Conn.(*markedconn.ProxiedConnection); ok {
   438  		ip, port := mconn.GetOriginalDestination()
   439  		portContext, err := serviceregistry.Instance().RetrieveExposedServiceContext(ip, port, "")
   440  		if err != nil {
   441  			return nil, fmt.Errorf("service not available: %s %d", ip.String(), port)
   442  		}
   443  		service := portContext.Service
   444  		if service.PublicNetworkInfo != nil && service.PublicNetworkInfo.Ports.Min == uint16(port) && len(service.PublicServiceCertificate) > 0 {
   445  			tlsCert, err := tls.X509KeyPair(service.PublicServiceCertificate, service.PublicServiceCertificateKey)
   446  			if err != nil {
   447  				return nil, fmt.Errorf("failed to parse server certificate: %s", err)
   448  			}
   449  			return &tlsCert, nil
   450  		}
   451  		if p.cert != nil {
   452  
   453  			cert, err := x509.ParseCertificate(p.cert.Certificate[0])
   454  			if err != nil {
   455  				return nil, fmt.Errorf("Leaf cert is missing")
   456  			}
   457  			if cert != nil {
   458  				by, _ := x509CertToPem(cert)
   459  				pemCert, err := buildCertChain(by, p.secrets.CertAuthority())
   460  				if err != nil {
   461  					zap.L().Error("http: Cannot build the cert chain")
   462  					return nil, fmt.Errorf("Cannot build the cert chain")
   463  				}
   464  				var certChain tls.Certificate
   465  				//certPEMBlock := []byte(rootcaBundle)
   466  				var certDERBlock *pem.Block
   467  				for {
   468  					certDERBlock, pemCert = pem.Decode(pemCert)
   469  					if certDERBlock == nil {
   470  						break
   471  					}
   472  					if certDERBlock.Type == typeCertificate {
   473  						certChain.Certificate = append(certChain.Certificate, certDERBlock.Bytes)
   474  					}
   475  				}
   476  				certChain.PrivateKey = p.cert.PrivateKey
   477  				//certChain.Certificate
   478  				return &certChain, nil
   479  			}
   480  			return p.cert, nil
   481  		}
   482  		return nil, fmt.Errorf("no cert available - cert is nil")
   483  	}
   484  	if p.cert != nil {
   485  		return p.cert, nil
   486  	}
   487  	return nil, fmt.Errorf("no cert available - cert is nil")
   488  }
   489  
   490  func buildCertChain(certPEM, caPEM []byte) ([]byte, error) {
   491  	zap.L().Debug("http: BEFORE in buildCertChain certPEM", zap.String("certPEM", string(certPEM)), zap.String("caPEM", string(caPEM)))
   492  	certChain := []*x509.Certificate{}
   493  	clientPEMBlock := certPEM
   494  
   495  	derBlock, _ := pem.Decode(clientPEMBlock)
   496  	if derBlock != nil {
   497  		if derBlock.Type == typeCertificate {
   498  			cert, err := x509.ParseCertificate(derBlock.Bytes)
   499  			if err != nil {
   500  				return nil, err
   501  			}
   502  			certChain = append(certChain, cert)
   503  		} else {
   504  			return nil, fmt.Errorf("invalid pem block type: %s", derBlock.Type)
   505  		}
   506  	}
   507  	var certDERBlock *pem.Block
   508  	for {
   509  		certDERBlock, caPEM = pem.Decode(caPEM)
   510  		if certDERBlock == nil {
   511  			break
   512  		}
   513  		if certDERBlock.Type == typeCertificate {
   514  			cert, err := x509.ParseCertificate(certDERBlock.Bytes)
   515  			if err != nil {
   516  				return nil, err
   517  			}
   518  			certChain = append(certChain, cert)
   519  		} else {
   520  			return nil, fmt.Errorf("invalid pem block type: %s", certDERBlock.Type)
   521  		}
   522  	}
   523  	by, _ := x509CertChainToPem(certChain)
   524  	zap.L().Debug("http: After building the cert chain", zap.String("certChain", string(by)))
   525  	return x509CertChainToPem(certChain)
   526  }
   527  
   528  // x509CertChainToPem converts chain of x509 certs to byte.
   529  func x509CertChainToPem(certChain []*x509.Certificate) ([]byte, error) {
   530  	var pemBytes bytes.Buffer
   531  	for _, cert := range certChain {
   532  		if err := pem.Encode(&pemBytes, &pem.Block{Type: typeCertificate, Bytes: cert.Raw}); err != nil {
   533  			return nil, err
   534  		}
   535  	}
   536  	return pemBytes.Bytes(), nil
   537  }
   538  
   539  // x509CertToPem converts x509 to byte.
   540  func x509CertToPem(cert *x509.Certificate) ([]byte, error) {
   541  	var pemBytes bytes.Buffer
   542  	if err := pem.Encode(&pemBytes, &pem.Block{Type: typeCertificate, Bytes: cert.Raw}); err != nil {
   543  		return nil, err
   544  	}
   545  	return pemBytes.Bytes(), nil
   546  }
   547  func (p *Config) processAppRequest(w http.ResponseWriter, r *http.Request) {
   548  
   549  	zap.L().Debug("Processing Application Request", zap.String("URI", r.RequestURI), zap.String("Host", r.Host))
   550  	originalDestination := r.Context().Value(http.LocalAddrContextKey).(*net.TCPAddr)
   551  
   552  	// Authorize the request by calling the authorizer library.
   553  	authRequest := &apiauth.Request{
   554  		OriginalDestination: originalDestination,
   555  		Method:              r.Method,
   556  		URL:                 r.URL,
   557  		RequestURI:          r.RequestURI,
   558  	}
   559  
   560  	resp, err := p.auth.ApplicationRequest(authRequest)
   561  	if err != nil {
   562  		if resp.PUContext != nil {
   563  			state := flowstats.NewAppConnectionState(p.puContext, r, authRequest, resp)
   564  			state.Stats.Action = resp.Action
   565  			state.Stats.PolicyID = resp.NetworkPolicyID
   566  			p.collector.CollectFlowEvent(state.Stats)
   567  		}
   568  		http.Error(w, err.Error(), err.(*apiauth.AuthError).Status())
   569  		return
   570  	}
   571  
   572  	state := flowstats.NewAppConnectionState(p.puContext, r, authRequest, resp)
   573  	if resp.External {
   574  		defer p.collector.CollectFlowEvent(state.Stats)
   575  	}
   576  
   577  	if resp.HookMethod != "" {
   578  		if hook, ok := p.hooks[resp.HookMethod]; ok {
   579  			if isHook, err := hook(w, r); err != nil || isHook {
   580  				if err != nil {
   581  					state.Stats.Action = policy.Reject
   582  					state.Stats.DropReason = collector.PolicyDrop
   583  				}
   584  				return
   585  			}
   586  		} else {
   587  			http.Error(w, "Invalid hook configuration", http.StatusInternalServerError)
   588  			return
   589  		}
   590  	}
   591  
   592  	httpScheme := "http://"
   593  	if resp.TLSListener {
   594  		httpScheme = "https://"
   595  	}
   596  
   597  	// Create the new target URL based on the Host parameter that we had.
   598  	r.URL, err = url.ParseRequestURI(httpScheme + r.Host)
   599  	if err != nil {
   600  		http.Error(w, "Invalid destination host name", http.StatusUnprocessableEntity)
   601  		return
   602  	}
   603  
   604  	// Add the headers with the authorization parameters and public key. The other side
   605  	// must validate our public key.
   606  	p.RLock()
   607  	r.Header.Add("X-APORETO-KEY", string(p.secrets.TransmittedKey()))
   608  	p.RUnlock()
   609  	r.Header.Add("X-APORETO-AUTH", resp.Token)
   610  
   611  	contextWithStats := context.WithValue(r.Context(), statsContextKey, state)
   612  	// Forward the request.
   613  	p.fwdTLS.ServeHTTP(w, r.WithContext(contextWithStats))
   614  }
   615  
   616  func (p *Config) processNetRequest(w http.ResponseWriter, r *http.Request) {
   617  
   618  	zap.L().Debug("Processing Network Request", zap.String("URI", r.RequestURI), zap.String("Host", r.Host))
   619  	originalDestination := r.Context().Value(http.LocalAddrContextKey).(*net.TCPAddr)
   620  
   621  	sourceAddress, err := net.ResolveTCPAddr("tcp", r.RemoteAddr)
   622  	if err != nil {
   623  		zap.L().Error("Internal server error - cannot determine source address information", zap.Error(err))
   624  		http.Error(w, "Invalid network information", http.StatusForbidden)
   625  		return
   626  	}
   627  
   628  	requestCookie, _ := r.Cookie("X-APORETO-AUTH") // nolint errcheck
   629  
   630  	pr := &collector.PingReport{}
   631  
   632  	request := &apiauth.Request{
   633  		OriginalDestination: originalDestination,
   634  		SourceAddress:       sourceAddress,
   635  		Header:              r.Header,
   636  		URL:                 r.URL,
   637  		Method:              r.Method,
   638  		RequestURI:          r.RequestURI,
   639  		Cookie:              requestCookie,
   640  		TLS:                 r.TLS,
   641  	}
   642  
   643  	response, err := p.auth.NetworkRequest(r.Context(), request)
   644  
   645  	var userID string
   646  	if response != nil && len(response.UserAttributes) > 0 {
   647  		userData := &collector.UserRecord{
   648  			Namespace: response.Namespace,
   649  			Claims:    response.UserAttributes,
   650  		}
   651  		p.collector.CollectUserEvent(userData)
   652  		userID = userData.ID
   653  	}
   654  
   655  	state := flowstats.NewNetworkConnectionState(p.puContext, userID, request, response)
   656  	defer func() {
   657  		if response != nil && response.PingConfig != nil {
   658  			pr.PingID = response.PingConfig.PingID
   659  			pr.IterationID = response.PingConfig.IterationID
   660  			pr.Type = gaia.PingProbeTypeRequest
   661  			pr.RemotePUID = response.SourcePUID
   662  			pr.PUID = response.PUContext.ManagementID()
   663  			pr.Namespace = response.Namespace
   664  			pr.PayloadSize = response.PingConfig.PayloadSize
   665  			pr.PayloadSizeType = gaia.PingProbePayloadSizeTypeReceived
   666  			pr.Protocol = 6
   667  			pr.ServiceType = "L7"
   668  			pr.FourTuple = fmt.Sprintf("%s:%s:%d:%d",
   669  				sourceAddress.IP.String(),
   670  				originalDestination.IP.String(),
   671  				sourceAddress.Port,
   672  				originalDestination.Port)
   673  			pr.PolicyID = response.NetworkPolicyID
   674  			pr.PolicyAction = response.Action
   675  			pr.ServiceID = response.ServiceID
   676  			pr.AgentVersion = p.agentVersion.String()
   677  			pr.RemoteEndpointType = collector.EndPointTypePU
   678  			pr.IsServer = true
   679  			pr.Claims = response.PingConfig.Claims
   680  			pr.ClaimsType = gaia.PingProbeClaimsTypeReceived
   681  			pr.RemoteNamespaceType = gaia.PingProbeRemoteNamespaceTypePlain
   682  			pr.TargetTCPNetworks = true
   683  			pr.ExcludedNetworks = false
   684  
   685  			if len(r.TLS.PeerCertificates) > 0 {
   686  				if len(r.TLS.PeerCertificates[0].Subject.Organization) > 0 {
   687  					pr.RemoteNamespace = r.TLS.PeerCertificates[0].Subject.Organization[0]
   688  				}
   689  				pr.PeerCertIssuer = r.TLS.PeerCertificates[0].Issuer.String()
   690  				pr.PeerCertSubject = r.TLS.PeerCertificates[0].Subject.String()
   691  				pr.PeerCertExpiry = r.TLS.PeerCertificates[0].NotAfter
   692  
   693  				if found, controller := pcommon.ExtractExtension(x509extensions.Controller(), r.TLS.PeerCertificates[0].Extensions); found {
   694  					pr.RemoteController = string(controller)
   695  				}
   696  			}
   697  
   698  			p.collector.CollectPingEvent(pr)
   699  		} else {
   700  			p.collector.CollectFlowEvent(state.Stats)
   701  		}
   702  	}()
   703  
   704  	if err != nil {
   705  
   706  		zap.L().Debug("Authorization error",
   707  			zap.Error(err),
   708  			zap.String("URI", r.RequestURI),
   709  			zap.String("Host", r.Host),
   710  		)
   711  		authError, ok := err.(*apiauth.AuthError)
   712  		if !ok {
   713  			http.Error(w, "Internal type error", http.StatusInternalServerError)
   714  			return
   715  		}
   716  
   717  		if response == nil {
   718  			// Basic errors are captured here.
   719  			http.Error(w, authError.Message(), authError.Status())
   720  			return
   721  		}
   722  
   723  		if response.PingConfig != nil {
   724  			pr.Error = response.DropReason
   725  		}
   726  
   727  		if !response.Redirect {
   728  			// If there is no redirect, we also return an error.
   729  			http.Error(w, authError.Message(), authError.Status())
   730  			return
   731  		}
   732  
   733  		// Redirect logic. Populate information here. This is forcing a
   734  		// redirect rather than an error.
   735  		if response.Cookie != nil {
   736  			http.SetCookie(w, response.Cookie)
   737  		}
   738  		w.Header().Add("Location", response.RedirectURI)
   739  		http.Error(w, response.Data, authError.Status())
   740  
   741  		return
   742  	}
   743  
   744  	// Select as http or https for communication with listening service.
   745  	httpPrefix := "http://"
   746  	if response.TLSListener {
   747  		httpPrefix = "https://"
   748  	}
   749  
   750  	// Create the target URI. Websocket Gorilla proxy takes it from the URL. For normal
   751  	// connections we don't want that.
   752  	if forward.IsWebsocketRequest(r) {
   753  		r.URL, err = url.ParseRequestURI(httpPrefix + originalDestination.String())
   754  	} else {
   755  		r.URL, err = url.ParseRequestURI(httpPrefix + r.Host)
   756  	}
   757  	if err != nil {
   758  		state.Stats.DropReason = collector.InvalidFormat
   759  		http.Error(w, fmt.Sprintf("Invalid HTTP Host parameter: %s", err), http.StatusBadRequest)
   760  		return
   761  	}
   762  
   763  	// Update the request headers with the user attributes as defined by the mappings
   764  	r.Header = response.Header
   765  
   766  	// Update the statistics and forward the request. We always encrypt downstream
   767  	state.Stats.Action = policy.Accept | policy.Encrypt | policy.Log
   768  
   769  	// // Treat the remote proxy scenario where the destination IPs are in a remote
   770  	// // host. Check of network rules that allow this transfer and report the corresponding
   771  	// // flows.
   772  	// if _, ok := p.localIPs[originalDestination.IP.String()]; !ok {
   773  	// 	_, action, err := pctx.PUContext.ApplicationACLPolicyFromAddr(originalDestination.IP, uint16(originalDestination.Port))
   774  	// 	if err != nil || action.Action.Rejected() {
   775  	// 		defer p.collector.CollectFlowEvent(reportDownStream(state.stats, action))
   776  	// 		http.Error(w, fmt.Sprintf("Access denied by network policy to downstream IP: %s", originalDestination.IP.String()), http.StatusNetworkAuthenticationRequired)
   777  	// 		return
   778  	// 	}
   779  	// 	if action.Action.Accepted() {
   780  	// 		defer p.collector.CollectFlowEvent(reportDownStream(state.stats, action))
   781  	// 	}
   782  	// }
   783  
   784  	contextWithStats := context.WithValue(r.Context(), statsContextKey, state)
   785  	p.fwd.ServeHTTP(w, r.WithContext(contextWithStats))
   786  	zap.L().Debug("Forwarding Request", zap.String("URI", r.RequestURI), zap.String("Host", r.Host))
   787  }
   788  
   789  func (p *Config) policyHook(w http.ResponseWriter, r *http.Request) (bool, error) {
   790  	if r.Header.Get(common.MetadataKey) != common.MetadataValue {
   791  		http.Error(w, "unauthorized request for policy", http.StatusForbidden)
   792  		return true, fmt.Errorf("unauthorized")
   793  	}
   794  
   795  	data, _, err := p.metadata.GetCurrentPolicy()
   796  	if err != nil {
   797  		http.Error(w, "Unable to retrieve current policy", http.StatusInternalServerError)
   798  		return true, err
   799  	}
   800  	if _, err := w.Write(data); err != nil {
   801  		zap.L().Error("Unable to write policy response")
   802  	}
   803  
   804  	return true, nil
   805  }
   806  
   807  func (p *Config) certificateHook(w http.ResponseWriter, r *http.Request) (bool, error) {
   808  	if r.Header.Get(common.MetadataKey) != common.MetadataValue {
   809  		http.Error(w, "unauthorized request for certificate", http.StatusForbidden)
   810  		return true, fmt.Errorf("unauthorized")
   811  	}
   812  
   813  	if _, err := w.Write(p.metadata.GetCertificate()); err != nil {
   814  		zap.L().Error("Unable to write response")
   815  	}
   816  
   817  	return true, nil
   818  }
   819  
   820  func (p *Config) keyHook(w http.ResponseWriter, r *http.Request) (bool, error) {
   821  	if r.Header.Get(common.MetadataKey) != common.MetadataValue {
   822  		http.Error(w, "unauthorized request for private key", http.StatusForbidden)
   823  		return true, fmt.Errorf("unauthorized")
   824  	}
   825  
   826  	if _, err := w.Write(p.metadata.GetPrivateKey()); err != nil {
   827  		zap.L().Error("Unable to write response")
   828  	}
   829  
   830  	return true, nil
   831  }
   832  
   833  func (p *Config) healthHook(w http.ResponseWriter, r *http.Request) (bool, error) {
   834  
   835  	// Health hook will only return ok if the current policy is already populated.
   836  	plc, _, err := p.metadata.GetCurrentPolicy()
   837  	if err != nil || plc == nil {
   838  		http.Error(w, "Unable to retrieve current policy", http.StatusInternalServerError)
   839  		return true, err
   840  	}
   841  
   842  	if _, err := w.Write([]byte("OK\n")); err != nil {
   843  		zap.L().Error("Unable to write response to health API")
   844  	}
   845  	return true, nil
   846  }
   847  
   848  func (p *Config) tokenHook(w http.ResponseWriter, r *http.Request) (bool, error) {
   849  
   850  	if r.Header.Get(common.MetadataKey) != common.MetadataValue {
   851  		http.Error(w, "unauthorized request for token", http.StatusForbidden)
   852  		return true, fmt.Errorf("unauthorized")
   853  	}
   854  
   855  	audience := r.URL.Query().Get("audience")
   856  	validityString := r.URL.Query().Get("validity")
   857  
   858  	validity := time.Minute * 60
   859  	var err error
   860  	if validityString != "" {
   861  		validity, err = time.ParseDuration(validityString)
   862  		if err != nil {
   863  			http.Error(w, "Invalid validity time requested. Please use notation of number+unit. Example: `10m`", http.StatusUnprocessableEntity)
   864  			return true, nil
   865  		}
   866  	}
   867  
   868  	token, err := p.tokenIssuer.Issue(r.Context(), p.puContext, common.ServiceTokenTypeOAUTH, audience, validity)
   869  	if err != nil {
   870  		http.Error(w, fmt.Sprintf("Unable to issue token: %s", err), http.StatusBadRequest)
   871  		return true, nil
   872  	}
   873  
   874  	if _, err := w.Write([]byte(token)); err != nil {
   875  		zap.L().Error("Unable to write response on token API")
   876  	}
   877  	return true, nil
   878  }
   879  
   880  func (p *Config) awsInfoHook(w http.ResponseWriter, r *http.Request) (bool, error) {
   881  
   882  	if err := validateAWSHeaders(r); err != nil {
   883  		http.Error(w, fmt.Sprintf("invalid user agent: %s", err), http.StatusForbidden)
   884  		return true, err
   885  	}
   886  
   887  	awsRole, id, err := p.awsRole()
   888  	if err != nil {
   889  		return true, err
   890  	}
   891  
   892  	type info struct {
   893  		Code               string    `json:"Code,omitempty"`
   894  		LastUpdated        time.Time `json:"LastUpdated,omitempty"`
   895  		InstanceProfileArn string    `json:"InstanceProfileArn,omitempty"`
   896  		InstanceProfileID  string    `json:"InstanceProfileId,omitempty"`
   897  	}
   898  
   899  	out := &info{
   900  		Code:               "Success",
   901  		LastUpdated:        time.Now(),
   902  		InstanceProfileArn: awsRole,
   903  		InstanceProfileID:  id,
   904  	}
   905  
   906  	data, err := json.MarshalIndent(out, " ", " ")
   907  	if err != nil {
   908  		return true, fmt.Errorf("error in marshall of info: %s", err)
   909  	}
   910  
   911  	if _, err = w.Write(data); err != nil {
   912  		return true, fmt.Errorf("unable to write data response: %s", err)
   913  	}
   914  
   915  	return true, nil
   916  }
   917  
   918  func (p *Config) awsTokenHook(w http.ResponseWriter, r *http.Request) (bool, error) {
   919  
   920  	if err := validateAWSHeaders(r); err != nil {
   921  		http.Error(w, fmt.Sprintf("invalid user agent: %s", err), http.StatusForbidden)
   922  		return true, err
   923  	}
   924  
   925  	awsRole, id, err := p.awsRole()
   926  	if err != nil {
   927  		return true, err
   928  	}
   929  
   930  	awsRoleParts := strings.Split(awsRole, "/")
   931  	if len(awsRoleParts) == 0 {
   932  		http.Error(w, fmt.Sprintf("invalid role: %s", err), http.StatusNotFound)
   933  		return true, fmt.Errorf("invalid role: %s", awsRole)
   934  	}
   935  
   936  	awsRoleName := awsRoleParts[len(awsRoleParts)-1]
   937  
   938  	if strings.HasSuffix(r.RequestURI, "security-credentials/") {
   939  		if _, err := w.Write([]byte(awsRoleName)); err != nil {
   940  			return true, err
   941  		}
   942  		return true, nil
   943  	}
   944  
   945  	if !strings.HasSuffix(r.RequestURI, "security-credentials/"+awsRoleName) {
   946  		http.Error(w, "not found", http.StatusNotFound)
   947  		return true, fmt.Errorf("not found")
   948  	}
   949  
   950  	token, err := p.tokenIssuer.Issue(r.Context(), id, common.ServiceTokenTypeAWS, awsRole, time.Hour)
   951  	if err != nil {
   952  		http.Error(w, fmt.Sprintf("Unable to issue token: %s", err), http.StatusBadRequest)
   953  		return true, nil
   954  	}
   955  
   956  	if _, err := w.Write([]byte(token)); err != nil {
   957  		zap.L().Error("Unable to write response on token API")
   958  	}
   959  	return true, nil
   960  }
   961  
   962  func (p *Config) awsRole() (string, string, error) {
   963  
   964  	_, plc, err := p.metadata.GetCurrentPolicy()
   965  	if err != nil {
   966  		return "", "", err
   967  	}
   968  
   969  	awsRole := ""
   970  	for _, scope := range plc.Scopes {
   971  		if strings.HasPrefix(scope, common.AWSRoleARNPrefix) {
   972  			if awsRole != "" && awsRole != scope[len(common.AWSRolePrefix):] {
   973  				return "", "", fmt.Errorf("overlapping roles detected")
   974  			}
   975  			awsRole = scope[len(common.AWSRolePrefix):]
   976  		}
   977  	}
   978  
   979  	if awsRole == "" {
   980  		return "", "", fmt.Errorf("role not found")
   981  	}
   982  
   983  	return awsRole, plc.ManagementID, nil
   984  }
   985  
   986  var (
   987  	allowedAgents = []string{"aws-cli/", "aws-chalice/", "Boto3/", "Botocore/", "aws-sdk-"}
   988  )
   989  
   990  func validateAWSHeaders(r *http.Request) error {
   991  
   992  	userAgent, ok := r.Header["User-Agent"]
   993  	if !ok {
   994  		return fmt.Errorf("no user-agent provided")
   995  	}
   996  
   997  	for _, u := range userAgent {
   998  		for _, t := range allowedAgents {
   999  			if strings.HasPrefix(u, t) {
  1000  				return nil
  1001  			}
  1002  		}
  1003  	}
  1004  
  1005  	return fmt.Errorf("invalid user agent: %v", userAgent)
  1006  }
  1007  
  1008  // func reportDownStream(record *collector.FlowRecord, action *policy.FlowPolicy) *collector.FlowRecord {
  1009  // 	return &collector.FlowRecord{
  1010  // 		ContextID: record.ContextID,
  1011  // 		Destination: &collector.EndPoint{
  1012  // 			URI:        record.Destination.URI,
  1013  // 			HTTPMethod: record.Destination.HTTPMethod,
  1014  // 			Type:       collector.EndPointTypeExternalIP,
  1015  // 			Port:       record.Destination.Port,
  1016  // 			IP:         record.Destination.IP,
  1017  // 			ID:         action.ServiceID,
  1018  // 		},
  1019  // 		Source: &collector.EndPoint{
  1020  // 			Type: record.Destination.Type,
  1021  // 			ID:   record.Destination.ID,
  1022  // 			IP:   "0.0.0.0",
  1023  // 		},
  1024  // 		Action:      action.Action,
  1025  // 		L4Protocol:  record.L4Protocol,
  1026  // 		ServiceType: record.ServiceType,
  1027  // 		ServiceID:   record.ServiceID,
  1028  // 		Tags:        record.Tags,
  1029  // 		PolicyID:    action.PolicyID,
  1030  // 		Count:       1,
  1031  // 	}
  1032  // }