github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/applicationproxy/tcp/ping_tcp.go (about)

     1  package tcp
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/tls"
     7  	"crypto/x509"
     8  	"encoding/json"
     9  	"fmt"
    10  	"io"
    11  	"net"
    12  	"time"
    13  
    14  	"go.aporeto.io/enforcerd/trireme-lib/collector"
    15  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/common"
    16  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/markedconn"
    17  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/pingrequest"
    18  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/serviceregistry"
    19  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/packet"
    20  	"go.aporeto.io/enforcerd/trireme-lib/policy"
    21  	"go.aporeto.io/gaia"
    22  	"go.aporeto.io/gaia/x509extensions"
    23  	"go.uber.org/zap"
    24  )
    25  
    26  // InitiatePing initiates the ping request
    27  func (p *Proxy) InitiatePing(ctx context.Context, sctx *serviceregistry.ServiceContext, sdata *serviceregistry.DependentServiceData, pingConfig *policy.PingConfig) error {
    28  
    29  	zap.L().Debug("Initiating L4 ping")
    30  
    31  	for i := 0; i < pingConfig.Iterations; i++ {
    32  		if err := p.sendPingRequest(ctx, pingConfig, sctx, sdata, i); err != nil {
    33  			return err
    34  		}
    35  	}
    36  
    37  	return nil
    38  }
    39  
    40  func (p *Proxy) sendPingRequest(
    41  	ctx context.Context,
    42  	pingConfig *policy.PingConfig,
    43  	sctx *serviceregistry.ServiceContext,
    44  	sdata *serviceregistry.DependentServiceData,
    45  	iterationID int) error {
    46  
    47  	pingID := pingConfig.ID
    48  	destIP := pingConfig.IP
    49  	destPort := pingConfig.Port
    50  
    51  	_, netaction, _ := sctx.PUContext.ApplicationACLPolicyFromAddr(destIP, destPort, packet.IPProtocolTCP)
    52  
    53  	pingErr := "dial"
    54  	if e := pingConfig.Error(); e != "" {
    55  		pingErr = e
    56  	}
    57  
    58  	pr := &collector.PingReport{
    59  		PingID:               pingID,
    60  		IterationID:          iterationID,
    61  		PUID:                 sctx.PUContext.ManagementID(),
    62  		Namespace:            sctx.PUContext.ManagementNamespace(),
    63  		Protocol:             6,
    64  		ServiceType:          "L4",
    65  		AgentVersion:         p.agentVersion.String(),
    66  		ApplicationListening: false,
    67  		ACLPolicyID:          netaction.PolicyID,
    68  		ACLPolicyAction:      netaction.Action,
    69  		Error:                pingErr,
    70  		TargetTCPNetworks:    pingConfig.TargetTCPNetworks,
    71  		ExcludedNetworks:     pingConfig.ExcludedNetworks,
    72  		Type:                 gaia.PingProbeTypeRequest,
    73  		RemoteEndpointType:   collector.EndPointTypeExternalIP,
    74  		ClaimsType:           gaia.PingProbeClaimsTypeReceived,
    75  		RemoteNamespaceType:  gaia.PingProbeRemoteNamespaceTypePlain,
    76  		PayloadSizeType:      gaia.PingProbePayloadSizeTypeTransmitted,
    77  	}
    78  
    79  	defer p.collector.CollectPingEvent(pr)
    80  
    81  	conn, err := dial(ctx, destIP, destPort, p.mark)
    82  	if err != nil {
    83  		return err
    84  	}
    85  	defer conn.Close() // nolint: errcheck
    86  
    87  	src := conn.RemoteAddr().(*net.TCPAddr)
    88  	pl := p.getPolicyReporter(sctx.PUContext, src.IP, src.Port, destIP, int(destPort), sdata.ServiceObject)
    89  	pl.client = true
    90  
    91  	// ServerName: Use first configured FQDN or the destination IP
    92  	serverName, err := common.GetTLSServerName(conn.RemoteAddr().String(), sdata.ServiceObject)
    93  	if err != nil {
    94  		return fmt.Errorf("unable to get the server name: %s", err)
    95  	}
    96  
    97  	// Encrypt Down Connection
    98  	p.RLock()
    99  	ca := p.caPool
   100  	p.RUnlock()
   101  
   102  	tlsCert, err := tls.X509KeyPair([]byte(pingConfig.ServiceCertificate), []byte(pingConfig.ServiceKey))
   103  	if err != nil {
   104  		return fmt.Errorf("unable to parse X509 certificate: %w", err)
   105  	}
   106  
   107  	certs := []tls.Certificate{
   108  		tlsCert,
   109  	}
   110  
   111  	t, err := getClientTLSConfig(ca, certs, serverName, false)
   112  	if err != nil {
   113  		return fmt.Errorf("unable to generate tls configuration: %s", err)
   114  	}
   115  
   116  	// Do TLS
   117  	tlsConn := tls.Client(conn, t)
   118  	defer tlsConn.Close() // nolint errcheck
   119  
   120  	payload := &policy.PingPayload{
   121  		PingID:      pingID,
   122  		IterationID: iterationID,
   123  		ServiceType: policy.ServiceTCP,
   124  	}
   125  
   126  	host := fmt.Sprintf("https://%s:%d", destIP, destPort)
   127  	data, err := pingrequest.CreateRaw(host, payload)
   128  	if err != nil {
   129  		return err
   130  	}
   131  
   132  	laddr := tlsConn.LocalAddr().(*net.TCPAddr)
   133  	raddr := tlsConn.RemoteAddr().(*net.TCPAddr)
   134  
   135  	startTime := time.Now()
   136  	if err := write(tlsConn, data); err != nil {
   137  		pr.Error = err.Error()
   138  		pr.FourTuple = fmt.Sprintf(
   139  			"%s:%s:%d:%d",
   140  			laddr.IP.String(),
   141  			raddr.IP.String(),
   142  			laddr.Port,
   143  			raddr.Port,
   144  		)
   145  		return err
   146  	}
   147  
   148  	pr.Error = ""
   149  	pr.RTT = time.Since(startTime).String()
   150  	pr.PayloadSize = len(data)
   151  	pr.ApplicationListening = true
   152  	pr.Type = gaia.PingProbeTypeResponse
   153  	pr.FourTuple = fmt.Sprintf(
   154  		"%s:%s:%d:%d",
   155  		raddr.IP.String(),
   156  		laddr.IP.String(),
   157  		raddr.Port,
   158  		laddr.Port,
   159  	)
   160  
   161  	if len(tlsConn.ConnectionState().PeerCertificates) > 0 {
   162  		return extract(pr, tlsConn.ConnectionState().PeerCertificates[0], pl)
   163  	}
   164  
   165  	return nil
   166  }
   167  
   168  func (p *Proxy) processPingRequest(conn *tls.Conn, pl *lookup) error {
   169  
   170  	zap.L().Debug("Processing ping request")
   171  
   172  	if err := conn.SetReadDeadline(time.Now().Add(3 * time.Second)); err != nil {
   173  		return err
   174  	}
   175  
   176  	var dst bytes.Buffer
   177  	if _, err := io.Copy(&dst, conn); err != nil {
   178  		return err
   179  	}
   180  
   181  	pp, err := pingrequest.ExtractRaw(dst.Bytes())
   182  	if err != nil {
   183  		return err
   184  	}
   185  
   186  	pr := &collector.PingReport{
   187  		PingID:          pp.PingID,
   188  		IterationID:     pp.IterationID,
   189  		Type:            gaia.PingProbeTypeRequest,
   190  		PUID:            pl.puContext.ManagementID(),
   191  		Namespace:       pl.puContext.ManagementNamespace(),
   192  		PayloadSize:     len(dst.Bytes()),
   193  		PayloadSizeType: gaia.PingProbePayloadSizeTypeReceived,
   194  		Protocol:        6,
   195  		ServiceType:     "L4",
   196  		FourTuple: fmt.Sprintf("%s:%s:%d:%d",
   197  			pl.SourceIP.String(),
   198  			pl.DestIP.String(),
   199  			pl.SourcePort,
   200  			pl.DestPort),
   201  		AgentVersion:        p.agentVersion.String(),
   202  		RemoteEndpointType:  collector.EndPointTypePU,
   203  		IsServer:            true,
   204  		ClaimsType:          gaia.PingProbeClaimsTypeReceived,
   205  		RemoteNamespaceType: gaia.PingProbeRemoteNamespaceTypePlain,
   206  		TargetTCPNetworks:   true,
   207  		ExcludedNetworks:    false,
   208  	}
   209  
   210  	if pp.ServiceType != policy.ServiceTCP {
   211  		pr.Error = fmt.Sprintf("service type mismatch, expected: %d, actual: %d", policy.ServiceTCP, pp.ServiceType)
   212  	}
   213  
   214  	if len(conn.ConnectionState().PeerCertificates) > 0 {
   215  		if err := extract(pr, conn.ConnectionState().PeerCertificates[0], pl); err != nil {
   216  			return err
   217  		}
   218  	}
   219  
   220  	p.collector.CollectPingEvent(pr)
   221  
   222  	return nil
   223  }
   224  
   225  func extract(pr *collector.PingReport, cert *x509.Certificate, pl *lookup) error {
   226  
   227  	pr.RemotePUID = cert.Subject.CommonName
   228  	pr.RemoteEndpointType = collector.EndPointTypePU
   229  	if len(cert.Subject.Organization) > 0 {
   230  		pr.RemoteNamespace = cert.Subject.Organization[0]
   231  	}
   232  	pr.PeerCertIssuer = cert.Issuer.String()
   233  	pr.PeerCertSubject = cert.Subject.String()
   234  	pr.PeerCertExpiry = cert.NotAfter
   235  
   236  	if found, controller := common.ExtractExtension(x509extensions.Controller(), cert.Extensions); found {
   237  		pr.RemoteController = string(controller)
   238  	}
   239  
   240  	if found, value := common.ExtractExtension(x509extensions.IdentityTags(), cert.Extensions); found {
   241  
   242  		claims := []string{}
   243  		if err := json.Unmarshal(value, &claims); err != nil {
   244  			return fmt.Errorf("unable to unmarshal tags: %w", err)
   245  		}
   246  
   247  		pr.Claims = claims
   248  
   249  		tags := policy.NewTagStoreFromSlice(claims)
   250  		_, pkt := pl.Policy(tags)
   251  
   252  		pr.PolicyID = pkt.PolicyID
   253  		pr.PolicyAction = pkt.Action
   254  		if pkt.Action.Rejected() {
   255  			pr.Error = collector.PolicyDrop
   256  		}
   257  	}
   258  
   259  	return nil
   260  }
   261  
   262  func pingEnabled(conn *tls.Conn) bool {
   263  
   264  	peerCerts := conn.ConnectionState().PeerCertificates
   265  	if len(peerCerts) <= 0 {
   266  		return false
   267  	}
   268  
   269  	found, _ := common.ExtractExtension(x509extensions.Ping(), peerCerts[0].Extensions)
   270  	return found
   271  }
   272  
   273  func dial(ctx context.Context, ip net.IP, port uint16, mark int) (net.Conn, error) {
   274  
   275  	raddr := &net.TCPAddr{
   276  		IP:   ip,
   277  		Port: int(port),
   278  	}
   279  
   280  	d := net.Dialer{
   281  		Timeout: 5 * time.Second,
   282  		Control: markedconn.ControlFunc(mark, false, nil),
   283  	}
   284  	return d.DialContext(ctx, "tcp", raddr.String())
   285  }
   286  
   287  func write(conn net.Conn, data []byte) error {
   288  
   289  	if err := conn.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil {
   290  		return err
   291  	}
   292  
   293  	n, err := conn.Write(data)
   294  	if err != nil && err != io.EOF {
   295  		return err
   296  	}
   297  
   298  	if n != len(data) {
   299  		return fmt.Errorf("failed to write data, expected: %v, written: %v", len(data), n)
   300  	}
   301  
   302  	return nil
   303  }