github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/applicationproxy/http/ping_http.go (about)

     1  package httpproxy
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net"
     7  	"net/http"
     8  	"time"
     9  
    10  	"go.aporeto.io/enforcerd/trireme-lib/collector"
    11  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/apiauth"
    12  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/common"
    13  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/markedconn"
    14  	"go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/serviceregistry"
    15  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/packet"
    16  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/servicetokens"
    17  	"go.aporeto.io/enforcerd/trireme-lib/policy"
    18  	"go.aporeto.io/gaia"
    19  	"go.aporeto.io/gaia/x509extensions"
    20  	"go.uber.org/zap"
    21  )
    22  
    23  const fourTupleKey = "fourTuple"
    24  
    25  type fourTuple struct {
    26  	sourceAddress      net.IP
    27  	destinationAddress net.IP
    28  	sourcePort         int
    29  	destinationPort    int
    30  }
    31  
    32  // InitiatePing starts an encrypted connection to the given config.
    33  func (p *Config) InitiatePing(ctx context.Context, sctx *serviceregistry.ServiceContext, sdata *serviceregistry.DependentServiceData, pingConfig *policy.PingConfig) error {
    34  
    35  	zap.L().Debug("Initiating L7 ping")
    36  
    37  	for i := 0; i < pingConfig.Iterations; i++ {
    38  		if err := p.sendPingRequest(ctx, pingConfig, sctx, sdata, i); err != nil {
    39  			return err
    40  		}
    41  	}
    42  
    43  	return nil
    44  }
    45  
    46  func (p *Config) sendPingRequest(
    47  	ctx context.Context,
    48  	pingConfig *policy.PingConfig,
    49  	sctx *serviceregistry.ServiceContext,
    50  	sdata *serviceregistry.DependentServiceData,
    51  	iterationID int) error {
    52  
    53  	pingID := pingConfig.ID
    54  	destIP := pingConfig.IP
    55  	destPort := pingConfig.Port
    56  
    57  	_, netaction, _ := sctx.PUContext.ApplicationACLPolicyFromAddr(destIP, destPort, packet.IPProtocolTCP)
    58  
    59  	pingErr := "dial"
    60  	if e := pingConfig.Error(); e != "" {
    61  		pingErr = e
    62  	}
    63  
    64  	pr := &collector.PingReport{
    65  		PingID:               pingID,
    66  		IterationID:          iterationID,
    67  		ServiceID:            sdata.APICache.ID,
    68  		PUID:                 sctx.PUContext.ManagementID(),
    69  		Namespace:            sctx.PUContext.ManagementNamespace(),
    70  		Protocol:             6,
    71  		ServiceType:          "L7",
    72  		AgentVersion:         p.agentVersion.String(),
    73  		ApplicationListening: false,
    74  		ACLPolicyID:          netaction.PolicyID,
    75  		ACLPolicyAction:      netaction.Action,
    76  		Error:                pingErr,
    77  		TargetTCPNetworks:    pingConfig.TargetTCPNetworks,
    78  		ExcludedNetworks:     pingConfig.ExcludedNetworks,
    79  		Type:                 gaia.PingProbeTypeRequest,
    80  		RemoteEndpointType:   collector.EndPointTypeExternalIP,
    81  		Claims:               sctx.PUContext.Identity().GetSlice(),
    82  		ClaimsType:           gaia.PingProbeClaimsTypeTransmitted,
    83  		RemoteNamespaceType:  gaia.PingProbeRemoteNamespaceTypePlain,
    84  		PayloadSizeType:      gaia.PingProbePayloadSizeTypeTransmitted,
    85  	}
    86  
    87  	ft := &fourTuple{}
    88  
    89  	p.RLock()
    90  	encodingKey := p.secrets.EncodingKey()
    91  	pubKey := p.secrets.TransmittedKey()
    92  	p.RUnlock()
    93  
    94  	pingPayload := &policy.PingPayload{
    95  		PingID:      pingID,
    96  		IterationID: iterationID,
    97  	}
    98  
    99  	token, err := servicetokens.CreateAndSign(
   100  		"",
   101  		sctx.PUContext.Identity().GetSlice(),
   102  		sctx.PUContext.Scopes(),
   103  		sctx.PUContext.ManagementID(),
   104  		apiauth.DefaultValidity,
   105  		encodingKey,
   106  		pingPayload,
   107  	)
   108  	if err != nil {
   109  		return err
   110  	}
   111  
   112  	networkDialerWithContext := func(ctx context.Context, _, addr string) (net.Conn, error) {
   113  
   114  		conn, err := dial(ctx, addr, p.mark)
   115  		if err != nil {
   116  			return nil, fmt.Errorf("unable to dial remote: %s", err)
   117  		}
   118  
   119  		if v := ctx.Value(fourTupleKey); v != nil {
   120  			if r, ok := v.(*fourTuple); ok {
   121  				laddr := conn.LocalAddr().(*net.TCPAddr)
   122  				raddr := conn.RemoteAddr().(*net.TCPAddr)
   123  				r.sourceAddress = laddr.IP
   124  				r.sourcePort = laddr.Port
   125  				r.destinationAddress = raddr.IP
   126  				r.destinationPort = raddr.Port
   127  			}
   128  		}
   129  
   130  		return conn, nil
   131  	}
   132  
   133  	raddr := &net.TCPAddr{
   134  		IP:   destIP,
   135  		Port: int(destPort),
   136  	}
   137  
   138  	// ServerName: Use first configured FQDN or the destination IP
   139  	serverName, err := common.GetTLSServerName(raddr.String(), sdata.ServiceObject)
   140  	if err != nil {
   141  		return fmt.Errorf("unable to get the server name: %s", err)
   142  	}
   143  
   144  	// Used to validate the hostname in the returned server certs.
   145  	// TODO: Maybe we should elevate this as first class citizen ?
   146  	p.tlsClientConfig.ServerName = serverName
   147  
   148  	encryptedTransport := &http.Transport{
   149  		TLSClientConfig:     p.tlsClientConfig,
   150  		DialContext:         networkDialerWithContext,
   151  		MaxIdleConnsPerHost: 2000,
   152  		MaxIdleConns:        2000,
   153  		ForceAttemptHTTP2:   true,
   154  	}
   155  
   156  	client := &http.Client{
   157  		Transport: encryptedTransport,
   158  		Timeout:   5 * time.Second,
   159  	}
   160  
   161  	host := fmt.Sprintf("https://%s:%d", destIP, destPort)
   162  	ctxWithReport := context.WithValue(ctx, fourTupleKey, ft) // nolint: golint,staticcheck
   163  	req, err := http.NewRequestWithContext(ctxWithReport, "GET", host, nil)
   164  	if err != nil {
   165  		return err
   166  	}
   167  
   168  	defer p.collector.CollectPingEvent(pr)
   169  
   170  	pr.PayloadSize = len(pubKey) + len(token)
   171  
   172  	req.Header.Add("X-APORETO-KEY", string(pubKey))
   173  	req.Header.Add("X-APORETO-AUTH", token)
   174  
   175  	startTime := time.Now()
   176  	res, err := client.Do(req)
   177  	if err != nil {
   178  		pr.Error = err.Error()
   179  		pr.FourTuple = fmt.Sprintf(
   180  			"%s:%s:%d:%d",
   181  			ft.sourceAddress.String(),
   182  			ft.destinationAddress.String(),
   183  			ft.sourcePort,
   184  			ft.destinationPort,
   185  		)
   186  		return err
   187  	}
   188  
   189  	res.Body.Close() // nolint: errcheck
   190  
   191  	pr.Error = ""
   192  	pr.RTT = time.Since(startTime).String()
   193  	pr.ApplicationListening = true
   194  	pr.Type = gaia.PingProbeTypeResponse
   195  	pr.FourTuple = fmt.Sprintf(
   196  		"%s:%s:%d:%d",
   197  		ft.destinationAddress.String(),
   198  		ft.sourceAddress.String(),
   199  		ft.destinationPort,
   200  		ft.sourcePort,
   201  	)
   202  
   203  	if len(res.TLS.PeerCertificates) > 0 {
   204  		pr.RemotePUID = res.TLS.PeerCertificates[0].Subject.CommonName
   205  		pr.RemoteEndpointType = collector.EndPointTypePU
   206  		if len(res.TLS.PeerCertificates[0].Subject.Organization) > 0 {
   207  			pr.RemoteNamespace = res.TLS.PeerCertificates[0].Subject.Organization[0]
   208  		}
   209  		pr.PeerCertIssuer = res.TLS.PeerCertificates[0].Issuer.String()
   210  		pr.PeerCertSubject = res.TLS.PeerCertificates[0].Subject.String()
   211  		pr.PeerCertExpiry = res.TLS.PeerCertificates[0].NotAfter
   212  
   213  		if found, controller := common.ExtractExtension(x509extensions.Controller(), res.TLS.PeerCertificates[0].Extensions); found {
   214  			pr.RemoteController = string(controller)
   215  		}
   216  	}
   217  
   218  	return nil
   219  }
   220  
   221  func dial(ctx context.Context, addr string, mark int) (net.Conn, error) {
   222  
   223  	d := net.Dialer{
   224  		Timeout: 5 * time.Second,
   225  		Control: markedconn.ControlFunc(mark, false, nil),
   226  	}
   227  	return d.DialContext(ctx, "tcp", addr)
   228  }