github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/applicationproxy/tcp/ping_tcp.go (about) 1 package tcp 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "crypto/x509" 8 "encoding/json" 9 "fmt" 10 "io" 11 "net" 12 "time" 13 14 "go.aporeto.io/enforcerd/trireme-lib/collector" 15 "go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/common" 16 "go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/markedconn" 17 "go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/pingrequest" 18 "go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/applicationproxy/serviceregistry" 19 "go.aporeto.io/enforcerd/trireme-lib/controller/pkg/packet" 20 "go.aporeto.io/enforcerd/trireme-lib/policy" 21 "go.aporeto.io/gaia" 22 "go.aporeto.io/gaia/x509extensions" 23 "go.uber.org/zap" 24 ) 25 26 // InitiatePing initiates the ping request 27 func (p *Proxy) InitiatePing(ctx context.Context, sctx *serviceregistry.ServiceContext, sdata *serviceregistry.DependentServiceData, pingConfig *policy.PingConfig) error { 28 29 zap.L().Debug("Initiating L4 ping") 30 31 for i := 0; i < pingConfig.Iterations; i++ { 32 if err := p.sendPingRequest(ctx, pingConfig, sctx, sdata, i); err != nil { 33 return err 34 } 35 } 36 37 return nil 38 } 39 40 func (p *Proxy) sendPingRequest( 41 ctx context.Context, 42 pingConfig *policy.PingConfig, 43 sctx *serviceregistry.ServiceContext, 44 sdata *serviceregistry.DependentServiceData, 45 iterationID int) error { 46 47 pingID := pingConfig.ID 48 destIP := pingConfig.IP 49 destPort := pingConfig.Port 50 51 _, netaction, _ := sctx.PUContext.ApplicationACLPolicyFromAddr(destIP, destPort, packet.IPProtocolTCP) 52 53 pingErr := "dial" 54 if e := pingConfig.Error(); e != "" { 55 pingErr = e 56 } 57 58 pr := &collector.PingReport{ 59 PingID: pingID, 60 IterationID: iterationID, 61 PUID: sctx.PUContext.ManagementID(), 62 Namespace: sctx.PUContext.ManagementNamespace(), 63 Protocol: 6, 64 ServiceType: "L4", 65 AgentVersion: p.agentVersion.String(), 66 ApplicationListening: false, 67 ACLPolicyID: netaction.PolicyID, 68 ACLPolicyAction: netaction.Action, 69 Error: pingErr, 70 TargetTCPNetworks: pingConfig.TargetTCPNetworks, 71 ExcludedNetworks: pingConfig.ExcludedNetworks, 72 Type: gaia.PingProbeTypeRequest, 73 RemoteEndpointType: collector.EndPointTypeExternalIP, 74 ClaimsType: gaia.PingProbeClaimsTypeReceived, 75 RemoteNamespaceType: gaia.PingProbeRemoteNamespaceTypePlain, 76 PayloadSizeType: gaia.PingProbePayloadSizeTypeTransmitted, 77 } 78 79 defer p.collector.CollectPingEvent(pr) 80 81 conn, err := dial(ctx, destIP, destPort, p.mark) 82 if err != nil { 83 return err 84 } 85 defer conn.Close() // nolint: errcheck 86 87 src := conn.RemoteAddr().(*net.TCPAddr) 88 pl := p.getPolicyReporter(sctx.PUContext, src.IP, src.Port, destIP, int(destPort), sdata.ServiceObject) 89 pl.client = true 90 91 // ServerName: Use first configured FQDN or the destination IP 92 serverName, err := common.GetTLSServerName(conn.RemoteAddr().String(), sdata.ServiceObject) 93 if err != nil { 94 return fmt.Errorf("unable to get the server name: %s", err) 95 } 96 97 // Encrypt Down Connection 98 p.RLock() 99 ca := p.caPool 100 p.RUnlock() 101 102 tlsCert, err := tls.X509KeyPair([]byte(pingConfig.ServiceCertificate), []byte(pingConfig.ServiceKey)) 103 if err != nil { 104 return fmt.Errorf("unable to parse X509 certificate: %w", err) 105 } 106 107 certs := []tls.Certificate{ 108 tlsCert, 109 } 110 111 t, err := getClientTLSConfig(ca, certs, serverName, false) 112 if err != nil { 113 return fmt.Errorf("unable to generate tls configuration: %s", err) 114 } 115 116 // Do TLS 117 tlsConn := tls.Client(conn, t) 118 defer tlsConn.Close() // nolint errcheck 119 120 payload := &policy.PingPayload{ 121 PingID: pingID, 122 IterationID: iterationID, 123 ServiceType: policy.ServiceTCP, 124 } 125 126 host := fmt.Sprintf("https://%s:%d", destIP, destPort) 127 data, err := pingrequest.CreateRaw(host, payload) 128 if err != nil { 129 return err 130 } 131 132 laddr := tlsConn.LocalAddr().(*net.TCPAddr) 133 raddr := tlsConn.RemoteAddr().(*net.TCPAddr) 134 135 startTime := time.Now() 136 if err := write(tlsConn, data); err != nil { 137 pr.Error = err.Error() 138 pr.FourTuple = fmt.Sprintf( 139 "%s:%s:%d:%d", 140 laddr.IP.String(), 141 raddr.IP.String(), 142 laddr.Port, 143 raddr.Port, 144 ) 145 return err 146 } 147 148 pr.Error = "" 149 pr.RTT = time.Since(startTime).String() 150 pr.PayloadSize = len(data) 151 pr.ApplicationListening = true 152 pr.Type = gaia.PingProbeTypeResponse 153 pr.FourTuple = fmt.Sprintf( 154 "%s:%s:%d:%d", 155 raddr.IP.String(), 156 laddr.IP.String(), 157 raddr.Port, 158 laddr.Port, 159 ) 160 161 if len(tlsConn.ConnectionState().PeerCertificates) > 0 { 162 return extract(pr, tlsConn.ConnectionState().PeerCertificates[0], pl) 163 } 164 165 return nil 166 } 167 168 func (p *Proxy) processPingRequest(conn *tls.Conn, pl *lookup) error { 169 170 zap.L().Debug("Processing ping request") 171 172 if err := conn.SetReadDeadline(time.Now().Add(3 * time.Second)); err != nil { 173 return err 174 } 175 176 var dst bytes.Buffer 177 if _, err := io.Copy(&dst, conn); err != nil { 178 return err 179 } 180 181 pp, err := pingrequest.ExtractRaw(dst.Bytes()) 182 if err != nil { 183 return err 184 } 185 186 pr := &collector.PingReport{ 187 PingID: pp.PingID, 188 IterationID: pp.IterationID, 189 Type: gaia.PingProbeTypeRequest, 190 PUID: pl.puContext.ManagementID(), 191 Namespace: pl.puContext.ManagementNamespace(), 192 PayloadSize: len(dst.Bytes()), 193 PayloadSizeType: gaia.PingProbePayloadSizeTypeReceived, 194 Protocol: 6, 195 ServiceType: "L4", 196 FourTuple: fmt.Sprintf("%s:%s:%d:%d", 197 pl.SourceIP.String(), 198 pl.DestIP.String(), 199 pl.SourcePort, 200 pl.DestPort), 201 AgentVersion: p.agentVersion.String(), 202 RemoteEndpointType: collector.EndPointTypePU, 203 IsServer: true, 204 ClaimsType: gaia.PingProbeClaimsTypeReceived, 205 RemoteNamespaceType: gaia.PingProbeRemoteNamespaceTypePlain, 206 TargetTCPNetworks: true, 207 ExcludedNetworks: false, 208 } 209 210 if pp.ServiceType != policy.ServiceTCP { 211 pr.Error = fmt.Sprintf("service type mismatch, expected: %d, actual: %d", policy.ServiceTCP, pp.ServiceType) 212 } 213 214 if len(conn.ConnectionState().PeerCertificates) > 0 { 215 if err := extract(pr, conn.ConnectionState().PeerCertificates[0], pl); err != nil { 216 return err 217 } 218 } 219 220 p.collector.CollectPingEvent(pr) 221 222 return nil 223 } 224 225 func extract(pr *collector.PingReport, cert *x509.Certificate, pl *lookup) error { 226 227 pr.RemotePUID = cert.Subject.CommonName 228 pr.RemoteEndpointType = collector.EndPointTypePU 229 if len(cert.Subject.Organization) > 0 { 230 pr.RemoteNamespace = cert.Subject.Organization[0] 231 } 232 pr.PeerCertIssuer = cert.Issuer.String() 233 pr.PeerCertSubject = cert.Subject.String() 234 pr.PeerCertExpiry = cert.NotAfter 235 236 if found, controller := common.ExtractExtension(x509extensions.Controller(), cert.Extensions); found { 237 pr.RemoteController = string(controller) 238 } 239 240 if found, value := common.ExtractExtension(x509extensions.IdentityTags(), cert.Extensions); found { 241 242 claims := []string{} 243 if err := json.Unmarshal(value, &claims); err != nil { 244 return fmt.Errorf("unable to unmarshal tags: %w", err) 245 } 246 247 pr.Claims = claims 248 249 tags := policy.NewTagStoreFromSlice(claims) 250 _, pkt := pl.Policy(tags) 251 252 pr.PolicyID = pkt.PolicyID 253 pr.PolicyAction = pkt.Action 254 if pkt.Action.Rejected() { 255 pr.Error = collector.PolicyDrop 256 } 257 } 258 259 return nil 260 } 261 262 func pingEnabled(conn *tls.Conn) bool { 263 264 peerCerts := conn.ConnectionState().PeerCertificates 265 if len(peerCerts) <= 0 { 266 return false 267 } 268 269 found, _ := common.ExtractExtension(x509extensions.Ping(), peerCerts[0].Extensions) 270 return found 271 } 272 273 func dial(ctx context.Context, ip net.IP, port uint16, mark int) (net.Conn, error) { 274 275 raddr := &net.TCPAddr{ 276 IP: ip, 277 Port: int(port), 278 } 279 280 d := net.Dialer{ 281 Timeout: 5 * time.Second, 282 Control: markedconn.ControlFunc(mark, false, nil), 283 } 284 return d.DialContext(ctx, "tcp", raddr.String()) 285 } 286 287 func write(conn net.Conn, data []byte) error { 288 289 if err := conn.SetWriteDeadline(time.Now().Add(5 * time.Second)); err != nil { 290 return err 291 } 292 293 n, err := conn.Write(data) 294 if err != nil && err != io.EOF { 295 return err 296 } 297 298 if n != len(data) { 299 return fmt.Errorf("failed to write data, expected: %v, written: %v", len(data), n) 300 } 301 302 return nil 303 }