github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/applicationproxy/tcp/tcp.go (about)

     1  package tcp
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"crypto/x509"
     7  	"fmt"
     8  	"io"
     9  	"net"
    10  	"sync"
    11  	"syscall"
    12  	"time"
    13  
    14  	"github.com/blang/semver"
    15  	"go.aporeto.io/enforcerd/trireme-lib/collector"
    16  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/common"
    17  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/markedconn"
    18  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/protomux"
    19  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/serviceregistry"
    20  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/tcp/verifier"
    21  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/tlshelper"
    22  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/pucontext"
    23  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/secrets"
    24  	"go.aporeto.io/enforcerd/trireme-lib/policy"
    25  	"go.uber.org/zap"
    26  )
    27  
// Proxy maintains state for proxied connections from the listener to the backend.
//
// The mark field is passed to markedconn.DialMarkedWithContext for every
// downstream dial made by initiateDownstreamTCPConnection.
//
// NOTE(review): myControllerID is not assigned by NewTCPProxy in this file, so
// the flow properties built by getPolicyReporter carry an empty controller ID
// unless it is set elsewhere — confirm that this is intended.
type Proxy struct {
	collector      collector.EventCollector
	myControllerID string
	puID           string
	mark           int

	// TLS cert for the service
	certificate *tls.Certificate
	// caPool contains the system roots and in addition the services external CAs
	caPool *x509.CertPool

	// verifier implements ID and IP ACL rules using the Peer Certificate Validation Handler
	verifier verifier.Verifier

	// List of local IPs, refreshed periodically by RunNetworkServer. isLocal
	// looks up a connection's remote host in it (keys presumably textual IPs
	// as returned by markedconn.GetInterfaces — verify).
	localIPs map[string]struct{}

	agentVersion semver.Version

	// Guards certificate, caPool and localIPs: they are replaced under Lock by
	// UpdateSecrets and the interface refresher, and read under RLock on the
	// connection paths.
	sync.RWMutex
}
    50  
    51  // NewTCPProxy creates a new instance of proxy reate a new instance of Proxy
    52  func NewTCPProxy(
    53  	c collector.EventCollector,
    54  	puID string,
    55  	certificate *tls.Certificate,
    56  	caPool *x509.CertPool,
    57  	agentVersion semver.Version,
    58  	mark int,
    59  ) *Proxy {
    60  
    61  	localIPs := markedconn.GetInterfaces()
    62  
    63  	return &Proxy{
    64  		collector:    c,
    65  		puID:         puID,
    66  		verifier:     verifier.New(caPool),
    67  		localIPs:     localIPs,
    68  		certificate:  certificate,
    69  		caPool:       caPool,
    70  		agentVersion: agentVersion,
    71  		mark:         mark,
    72  	}
    73  }
    74  
    75  // RunNetworkServer implements enforcer.Enforcer interface
    76  func (p *Proxy) RunNetworkServer(
    77  	ctx context.Context,
    78  	listener net.Listener,
    79  ) error {
    80  
    81  	go func() {
    82  		for {
    83  			select {
    84  			case <-time.After(5 * time.Second):
    85  				p.Lock()
    86  				p.localIPs = markedconn.GetInterfaces()
    87  				p.Unlock()
    88  			case <-ctx.Done():
    89  				return
    90  			}
    91  		}
    92  	}()
    93  
    94  	// Encryption is done transparently for TCP.
    95  	go p.serve(ctx, listener)
    96  
    97  	return nil
    98  }
    99  
   100  // UpdateSecrets updates the secrets of the connections.
   101  func (p *Proxy) UpdateSecrets(
   102  	cert *tls.Certificate,
   103  	caPool *x509.CertPool,
   104  	s secrets.Secrets,
   105  	certPEM string,
   106  	keyPEM string,
   107  ) {
   108  	p.Lock()
   109  	defer p.Unlock()
   110  
   111  	p.certificate = cert
   112  	p.caPool = caPool
   113  
   114  	p.verifier.TrustCAs(caPool)
   115  }
   116  
   117  func (p *Proxy) serve(
   118  	ctx context.Context,
   119  	listener net.Listener,
   120  ) {
   121  	for {
   122  		select {
   123  		case <-ctx.Done():
   124  			return
   125  		default:
   126  			conn, err := listener.Accept()
   127  			if err != nil {
   128  				return
   129  			}
   130  			if protoListener, ok := listener.(*protomux.ProtoListener); ok {
   131  				// Windows: we don't really need the platform-specific data map for plain tcp (we can get it from the conn).
   132  				// So just remove from the map here.
   133  				markedconn.RemovePlatformData(protoListener.Listener, conn)
   134  			}
   135  			go p.handle(ctx, conn)
   136  		}
   137  	}
   138  }
   139  
   140  // ShutDown shuts down the server.
   141  func (p *Proxy) ShutDown() error {
   142  	return nil
   143  }
   144  
   145  func (p *Proxy) getService(
   146  	ip net.IP,
   147  	port int,
   148  	local bool,
   149  ) (*policy.ApplicationService, error) {
   150  
   151  	// If the destination is a local IP, it means that we are processing a client connection.
   152  	if local {
   153  		_, serviceData, err := serviceregistry.Instance().RetrieveDependentServiceDataByIDAndNetwork(p.puID, ip, port, "")
   154  		if err != nil {
   155  			return nil, fmt.Errorf("unknown dependent service pu:%s %s/%d: %s", p.puID, ip.String(), port, err)
   156  		}
   157  		return serviceData.ServiceObject, nil
   158  	}
   159  
   160  	portContext, err := serviceregistry.Instance().RetrieveExposedServiceContext(ip, port, "")
   161  	if err != nil {
   162  		return nil, fmt.Errorf("unknown exposed service %s/%d: %s", ip.String(), port, err)
   163  	}
   164  	return portContext.Service, nil
   165  }
   166  
   167  // handle handles a connection. upstream connection is the connection
   168  // to the next hop while downstream connection is the client who
   169  // initiated this connection.
   170  // Client PU:
   171  //   - upstream connection is from client to proxy.
   172  //   - downstream connection is from proxy to the nexthop (service, LB, PU)
   173  // Server PU:
   174  //   - upstream connection is from client or another enforcer
   175  //   - downstream connection is from proxy to the server nexthop
   176  func (p *Proxy) handle(ctx context.Context, upConn net.Conn) {
   177  
   178  	defer upConn.Close() // nolint
   179  
   180  	// TODO: handle proxy protocol
   181  
   182  	proxiedUpConn := upConn.(*markedconn.ProxiedConnection)
   183  	ip, port := proxiedUpConn.GetOriginalDestination()
   184  	platformData := proxiedUpConn.GetPlatformData()
   185  
   186  	service, err := p.getService(ip, port, p.isLocal(upConn))
   187  	if err != nil {
   188  		zap.L().Error("no service found", zap.Error(err))
   189  		return
   190  	}
   191  
   192  	puContext, err := p.puContextFromContextID(p.puID)
   193  	if err != nil {
   194  		zap.L().Error("no pu found", zap.String("puid", p.puID), zap.Error(err))
   195  		return
   196  	}
   197  
   198  	p.handleWithPUAndService(ctx, upConn, ip, port, platformData, puContext, service)
   199  }
   200  
   201  func (p *Proxy) getPolicyReporter(
   202  	puContext *pucontext.PUContext,
   203  	sip net.IP,
   204  	sport int,
   205  	dip net.IP,
   206  	dport int,
   207  	service *policy.ApplicationService,
   208  ) *lookup {
   209  
   210  	pfp := &proxyFlowProperties{
   211  		myControllerID: p.myControllerID,
   212  		DestIP:         dip.String(),
   213  		DestPort:       uint16(dport),
   214  		SourceIP:       sip.String(),
   215  		SourcePort:     0, // TODO: Investigate if this should be set
   216  		ServiceID:      service.ID,
   217  		DestType:       collector.EndPointTypePU,
   218  		SourceType:     collector.EndPointTypePU,
   219  	}
   220  
   221  	return &lookup{
   222  		SourceIP:   sip,
   223  		DestIP:     dip,
   224  		SourcePort: uint16(sport),
   225  		DestPort:   uint16(dport),
   226  		collector:  p.collector,
   227  		puContext:  puContext,
   228  		pfp:        pfp,
   229  	}
   230  }
   231  
   232  func (p *Proxy) handleWithPUAndService(
   233  	ctx context.Context,
   234  	upConn net.Conn,
   235  	origDestIP net.IP,
   236  	origDestPort int,
   237  	platformData *markedconn.PlatformData,
   238  	puContext *pucontext.PUContext,
   239  	service *policy.ApplicationService,
   240  ) {
   241  	// If we received connection isn't on private port, downstream connection has to be changed to
   242  	// service listening port.
   243  	downPort := origDestPort
   244  	if downPort == service.PublicPort() {
   245  		downPort = service.PrivatePort()
   246  	}
   247  
   248  	// Initialize a policy and reporting object
   249  	src := upConn.RemoteAddr().(*net.TCPAddr)
   250  	pr := p.getPolicyReporter(puContext, src.IP, src.Port, origDestIP, origDestPort, service)
   251  
   252  	downConn, err := p.initiateDownstreamTCPConnection(ctx, origDestIP, downPort, platformData)
   253  	if err != nil {
   254  		// Report rejection
   255  		pr.ReportStats(collector.EndPointTypeExternalIP, "", "default", collector.UnableToDial, nil, nil, false)
   256  		return
   257  	}
   258  	defer downConn.Close() // nolint
   259  
   260  	if err := p.proxyData(ctx, upConn, downConn, service, pr); err != nil {
   261  		zap.L().Debug("Error with proxying data", zap.Error(err))
   262  	}
   263  }
   264  
   265  func (p *Proxy) startEncryptedClientDataPath(
   266  	ctx context.Context,
   267  	downConn net.Conn,
   268  	upConn net.Conn,
   269  	service *policy.ApplicationService,
   270  	pr *lookup,
   271  ) error {
   272  
   273  	// Set a flag so policy engine knows if its on server or client
   274  	pr.client = true
   275  
   276  	// ServerName: Use first configured FQDN or the destination IP
   277  	serverName, err := common.GetTLSServerName(downConn.RemoteAddr().String(), service)
   278  	if err != nil {
   279  		return fmt.Errorf("unable to get the server name: %s", err)
   280  	}
   281  
   282  	// Encrypt Down Connection
   283  	p.RLock()
   284  	ca := p.caPool
   285  	certs := []tls.Certificate{}
   286  	if p.certificate != nil {
   287  		certs = append(certs, *p.certificate)
   288  	}
   289  	p.RUnlock()
   290  
   291  	t, err := getClientTLSConfig(ca, certs, serverName, service.External)
   292  	if err != nil {
   293  		return fmt.Errorf("unable to generate tls configuration: %s", err)
   294  	}
   295  
   296  	t.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
   297  		return p.verifier.VerifyPeerCertificate(rawCerts, verifiedChains, pr, false)
   298  	}
   299  
   300  	// Do TLS
   301  	tlsConn := tls.Client(downConn, t)
   302  	defer tlsConn.Close() // nolint errcheck
   303  	downConn = tlsConn
   304  
   305  	zap.L().Debug(
   306  		"Handle client connection",
   307  		zap.String("src", upConn.RemoteAddr().String()),
   308  		zap.String("dst", downConn.RemoteAddr().String()),
   309  		zap.String("tls.server", t.ServerName),
   310  		zap.Bool("tls.rootCAs", t.RootCAs != nil),
   311  		zap.Int("tls.certs", len(t.Certificates)),
   312  	)
   313  
   314  	// TLS will automatically start negotiation on write. Nothing to do for us.
   315  	p.copyData(ctx, upConn, downConn)
   316  	return nil
   317  }
   318  
   319  func (p *Proxy) startEncryptedServerDataPath(
   320  	ctx context.Context,
   321  	downConn net.Conn,
   322  	upConn net.Conn,
   323  	service *policy.ApplicationService,
   324  	pr *lookup,
   325  ) error {
   326  
   327  	zap.L().Debug(
   328  		"Handle server connection",
   329  		zap.String("src", upConn.RemoteAddr().String()),
   330  		zap.String("dst", downConn.RemoteAddr().String()),
   331  		zap.String("orig-dst", pr.DestIP.String()),
   332  		zap.Uint16("orig-dstport", pr.DestPort),
   333  	)
   334  
   335  	if service.PrivateTLSListener {
   336  		zap.L().Debug("convert connection to server as TLS")
   337  		downConn = tls.Client(downConn, &tls.Config{
   338  			InsecureSkipVerify: true,
   339  		})
   340  	}
   341  
   342  	proxiedUpConn := upConn.(*markedconn.ProxiedConnection)
   343  	_, originalPort := proxiedUpConn.GetOriginalDestination()
   344  
   345  	// Use Aporeto certs
   346  	p.RLock()
   347  	caPool := p.caPool
   348  	clientCerts := []tls.Certificate{}
   349  	if p.certificate != nil {
   350  		clientCerts = []tls.Certificate{*p.certificate}
   351  	}
   352  	p.RUnlock()
   353  
   354  	tlsConfig, err := getServerTLSConfig(
   355  		caPool,
   356  		clientCerts,
   357  		originalPort,
   358  		service,
   359  	)
   360  	if err != nil {
   361  		return fmt.Errorf("invalid tls server configuration: %s", err)
   362  	}
   363  
   364  	if tlsConfig != nil {
   365  		// Register Peer Certificate Verification so we can apply policies.
   366  		tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
   367  			return p.verifier.VerifyPeerCertificate(rawCerts, verifiedChains, pr, tlsConfig.ClientAuth == tls.RequireAndVerifyClientCert)
   368  		}
   369  
   370  		tlsConn := tls.Server(upConn.(*markedconn.ProxiedConnection).GetTCPConnection(), tlsConfig)
   371  		defer tlsConn.Close() // nolint errcheck
   372  
   373  		// Manually initiating the TLS handshake to get the connection state.
   374  		// The call to write will skip TLS handshake.
   375  		if err := tlsConn.Handshake(); err != nil {
   376  			return err
   377  		}
   378  
   379  		if pingEnabled(tlsConn) {
   380  			return p.processPingRequest(tlsConn, pr)
   381  		}
   382  
   383  		upConn = tlsConn
   384  	} else {
   385  		// In case of no TLS, apply IP policies right here.
   386  		action := pr.IPLookup()
   387  		zap.L().Debug("ip acl lookup", zap.Bool("action", action))
   388  		if !action {
   389  			return fmt.Errorf("ip acl drop")
   390  		}
   391  	}
   392  
   393  	// TLS will automatically start negotiation on write. Nothing to for us.
   394  	p.copyData(ctx, upConn, downConn)
   395  	return nil
   396  }
   397  
   398  func (p *Proxy) copyData(
   399  	ctx context.Context,
   400  	source, dest net.Conn,
   401  ) {
   402  	var wg sync.WaitGroup
   403  	wg.Add(2)
   404  	go func() {
   405  		dataprocessor(ctx, source, dest)
   406  		wg.Done()
   407  	}()
   408  	go func() {
   409  		dataprocessor(ctx, dest, source)
   410  		wg.Done()
   411  	}()
   412  	wg.Wait()
   413  }
   414  
   415  type readwithContext func(p []byte) (n int, err error)
   416  
   417  func (r readwithContext) Read(p []byte) (int, error) { return r(p) }
   418  
   419  func dataprocessor(
   420  	ctx context.Context,
   421  	source net.Conn,
   422  	dest net.Conn,
   423  ) { // nolint
   424  	defer func() {
   425  		switch connType := dest.(type) {
   426  		case *tls.Conn:
   427  			connType.CloseWrite() // nolint errcheck
   428  		case *net.TCPConn:
   429  			connType.CloseWrite() // nolint errcheck
   430  		case *markedconn.ProxiedConnection:
   431  			connType.GetTCPConnection().CloseWrite() // nolint errcheck
   432  		}
   433  	}()
   434  
   435  	if _, err := io.Copy(dest, readwithContext(
   436  		func(p []byte) (int, error) {
   437  			select {
   438  			case <-ctx.Done():
   439  				return 0, ctx.Err()
   440  			default:
   441  				return source.Read(p)
   442  			}
   443  		},
   444  	),
   445  	); err != nil { // nolint
   446  		logErr(err)
   447  	}
   448  }
   449  
   450  func (p *Proxy) proxyData(
   451  	ctx context.Context,
   452  	upConn net.Conn,
   453  	downConn net.Conn,
   454  	service *policy.ApplicationService,
   455  	pr *lookup,
   456  ) error {
   457  
   458  	// If the destination is not a local IP, it means that we are processing a client connection.
   459  	if p.isLocal(upConn) {
   460  		return p.startEncryptedClientDataPath(ctx, downConn, upConn, service, pr)
   461  	}
   462  
   463  	return p.startEncryptedServerDataPath(ctx, downConn, upConn, service, pr)
   464  }
   465  
   466  func (p *Proxy) puContextFromContextID(
   467  	puID string,
   468  ) (*pucontext.PUContext, error) {
   469  
   470  	sctx, err := serviceregistry.Instance().RetrieveServiceByID(puID)
   471  	if err != nil {
   472  		return nil, fmt.Errorf("Context not found %s", puID)
   473  	}
   474  
   475  	return sctx.PUContext, nil
   476  }
   477  
   478  // initiateDownstreamTCPConnection initiates a downstream TCP connection
   479  func (p *Proxy) initiateDownstreamTCPConnection(
   480  	ctx context.Context,
   481  	ip net.IP,
   482  	port int,
   483  	platformData *markedconn.PlatformData,
   484  ) (net.Conn, error) {
   485  
   486  	raddr := &net.TCPAddr{
   487  		IP:   ip,
   488  		Port: port,
   489  	}
   490  	return markedconn.DialMarkedWithContext(ctx, "tcp", raddr.String(), platformData, p.mark)
   491  }
   492  
   493  func (p *Proxy) isLocal(conn net.Conn) bool {
   494  
   495  	host, _, err := net.SplitHostPort(conn.RemoteAddr().String())
   496  	if err != nil {
   497  		return false
   498  	}
   499  
   500  	p.RLock()
   501  	defer p.RUnlock()
   502  
   503  	if _, ok := p.localIPs[host]; ok {
   504  		return true
   505  	}
   506  	return false
   507  }
   508  
   509  func logErr(err error) bool {
   510  	switch err.(type) {
   511  	case syscall.Errno:
   512  		zap.L().Error("Connection error to destination", zap.Error(err))
   513  	default:
   514  		zap.L().Error("Connection terminated", zap.Error(err))
   515  	}
   516  	return false
   517  }
   518  
   519  // getPublicServerTLSConfig provides the TLS configuration for the public port.
   520  // There is a valid case where we dont provide TLS configuration (nil) but error
   521  // is also nil to support the case of publicly exposed port.
   522  func getPublicServerTLSConfig(
   523  	caPool *x509.CertPool,
   524  	clientCerts []tls.Certificate,
   525  	service *policy.ApplicationService,
   526  ) (t *tls.Config, err error) {
   527  
   528  	// Apply Public configuration
   529  	if (service.PublicServiceTLSType != policy.ServiceTLSTypeCustom) && (service.PublicServiceTLSType != policy.ServiceTLSTypeAporeto) {
   530  		return nil, nil
   531  	}
   532  
   533  	t = tlshelper.NewBaseTLSServerConfig()
   534  
   535  	// Server Cert and Key.
   536  	if service.PublicServiceTLSType == policy.ServiceTLSTypeCustom {
   537  		// Use custom certs
   538  		if len(service.PublicServiceCertificate) > 0 && len(service.PublicServiceCertificateKey) > 0 {
   539  
   540  			cert, err := tls.X509KeyPair(service.PublicServiceCertificate, service.PublicServiceCertificateKey)
   541  			if err != nil {
   542  				return nil, fmt.Errorf("invalid public cert pair")
   543  			}
   544  			t.Certificates = []tls.Certificate{cert}
   545  		}
   546  	} else if service.PublicServiceTLSType == policy.ServiceTLSTypeAporeto {
   547  		// Use Aporeto certs
   548  		t.Certificates = clientCerts
   549  	}
   550  
   551  	// mTLS with client
   552  	if service.UserAuthorizationType == policy.UserAuthorizationMutualTLS {
   553  		t.ClientAuth = tls.RequireAndVerifyClientCert
   554  		t.ClientCAs = caPool
   555  		if len(service.MutualTLSTrustedRoots) > 0 {
   556  			if !t.ClientCAs.AppendCertsFromPEM(service.MutualTLSTrustedRoots) {
   557  				return nil, fmt.Errorf("Unable to process client CAs")
   558  			}
   559  		}
   560  	}
   561  
   562  	return t, nil
   563  }
   564  
   565  // getExposedServerMTLSConfig provides the mTLS configuration for the server.
   566  func getExposedServerMTLSConfig(
   567  	caPool *x509.CertPool,
   568  	certs []tls.Certificate,
   569  ) (t *tls.Config, err error) {
   570  
   571  	if len(certs) == 0 {
   572  		return nil, fmt.Errorf("Failed to start encryption")
   573  	}
   574  
   575  	t = tlshelper.NewBaseTLSServerConfig()
   576  	t.Certificates = certs
   577  	t.ClientCAs = caPool
   578  	t.ClientAuth = tls.RequireAndVerifyClientCert
   579  	return t, nil
   580  }
   581  
   582  // getServerTLSConfig provides the server TLS configuration. It handles the
   583  // server on public and exposed ports.
   584  // returns:
   585  //    - error
   586  //    - tls.Config which can be nil even when error is nil to indicate no TLS
   587  func getServerTLSConfig(
   588  	caPool *x509.CertPool,
   589  	certs []tls.Certificate,
   590  	originalPort int,
   591  	service *policy.ApplicationService,
   592  ) (t *tls.Config, err error) {
   593  
   594  	if originalPort != service.PublicPort() {
   595  		// mTLS for Up Connection for exposed ports protected by Aporeto
   596  		return getExposedServerMTLSConfig(caPool, certs)
   597  	}
   598  	// TLS configuration supported on public ports
   599  	return getPublicServerTLSConfig(caPool, certs, service)
   600  }
   601  
   602  // getTLSConfig generates a tls.Config for a given client based on the service it may be accessing.
   603  // - Services protected by Aporeto should do mTLS.
   604  // - External (Third Party) Services do TLS only.
   605  func getClientTLSConfig(
   606  	caPool *x509.CertPool,
   607  	clientCerts []tls.Certificate,
   608  	serverName string,
   609  	external bool,
   610  ) (t *tls.Config, err error) {
   611  
   612  	t = tlshelper.NewBaseTLSClientConfig()
   613  	t.RootCAs = caPool
   614  	t.ServerName = serverName
   615  
   616  	if !external {
   617  		if len(clientCerts) == 0 {
   618  			return nil, fmt.Errorf("no client certs provided for mTLS")
   619  		}
   620  		// Do mTLS enforcer protected services. TLS for external service.
   621  		t.Certificates = clientCerts
   622  	}
   623  	return t, nil
   624  }