github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/nfqdatapath/diagnostics_tcp.go (about)

     1  // +build !windows
     2  
     3  package nfqdatapath
     4  
     5  import (
     6  	"context"
     7  	"fmt"
     8  	"math/rand"
     9  	"net"
    10  	"strconv"
    11  	"syscall"
    12  	"time"
    13  
    14  	"github.com/aporeto-inc/gopkt/layers"
    15  	"github.com/aporeto-inc/gopkt/packet/ipv4"
    16  	"github.com/aporeto-inc/gopkt/packet/raw"
    17  	"github.com/aporeto-inc/gopkt/packet/tcp"
    18  	"github.com/aporeto-inc/gopkt/routing"
    19  	"github.com/phayes/freeport"
    20  	"github.com/vmihailenco/msgpack"
    21  	"go.aporeto.io/trireme-lib/collector"
    22  	enforcerconstants "go.aporeto.io/trireme-lib/controller/internal/enforcer/constants"
    23  	"go.aporeto.io/trireme-lib/controller/pkg/claimsheader"
    24  	"go.aporeto.io/trireme-lib/controller/pkg/connection"
    25  	tpacket "go.aporeto.io/trireme-lib/controller/pkg/packet"
    26  	"go.aporeto.io/trireme-lib/controller/pkg/pucontext"
    27  	"go.aporeto.io/trireme-lib/controller/pkg/tokens"
    28  	"go.aporeto.io/trireme-lib/policy"
    29  	"go.aporeto.io/trireme-lib/utils/crypto"
    30  	"go.uber.org/zap"
    31  )
    32  
    33  func (d *Datapath) initiateDiagnostics(_ context.Context, contextID string, pingConfig *policy.PingConfig) error {
    34  
    35  	if pingConfig == nil {
    36  		return nil
    37  	}
    38  
    39  	zap.L().Debug("Initiating diagnostics (syn)")
    40  
    41  	srcIP, err := getSrcIP(pingConfig.IP)
    42  	if err != nil {
    43  		return fmt.Errorf("unable to get source ip: %v", err)
    44  	}
    45  
    46  	conn, err := dialWithMark(srcIP, pingConfig.IP)
    47  	if err != nil {
    48  		return fmt.Errorf("unable to dial on app syn: %v", err)
    49  	}
    50  	defer conn.Close() // nolint:errcheck
    51  
    52  	item, err := d.puFromContextID.Get(contextID)
    53  	if err != nil {
    54  		return fmt.Errorf("unable to find context with ID %s in cache: %v", contextID, err)
    55  	}
    56  
    57  	context, ok := item.(*pucontext.PUContext)
    58  	if !ok {
    59  		return fmt.Errorf("invalid pu context: %v", contextID)
    60  	}
    61  
    62  	for i := 1; i <= pingConfig.Requests; i++ {
    63  		for _, ports := range pingConfig.Ports {
    64  			for dstPort := ports.Min; dstPort <= ports.Max; dstPort++ {
    65  				if err := d.sendSynPacket(context, pingConfig, conn, srcIP, dstPort, i); err != nil {
    66  					return err
    67  				}
    68  			}
    69  		}
    70  	}
    71  
    72  	return nil
    73  }
    74  
    75  // sendSynPacket sends tcp syn packet to the socket. It also dispatches a report.
    76  func (d *Datapath) sendSynPacket(context *pucontext.PUContext, pingConfig *policy.PingConfig, conn net.Conn, srcIP net.IP, dstPort uint16, request int) error {
    77  
    78  	tcpConn := connection.NewTCPConnection(context, nil)
    79  	tcpConn.Secrets = d.secrets()
    80  
    81  	claimsHeader := claimsheader.NewClaimsHeader(
    82  		claimsheader.OptionPingType(pingConfig.Type),
    83  	)
    84  
    85  	tcpData, err := d.tokenAccessor.CreateSynPacketToken(context, &tcpConn.Auth, claimsHeader, d.secrets())
    86  	if err != nil {
    87  		return fmt.Errorf("unable to create syn token: %v", err)
    88  	}
    89  
    90  	srcPort, err := freeport.GetFreePort()
    91  	if err != nil {
    92  		return fmt.Errorf("unable to get free source port: %v", err)
    93  	}
    94  
    95  	p, err := constructTCPPacket(srcIP, pingConfig.IP, uint16(srcPort), dstPort, tcp.Syn, tcpData)
    96  	if err != nil {
    97  		return fmt.Errorf("unable to construct syn packet: %v", err)
    98  	}
    99  
   100  	sessionID, err := crypto.GenerateRandomString(20)
   101  	if err != nil {
   102  		return err
   103  	}
   104  
   105  	if err := write(conn, p); err != nil {
   106  		return fmt.Errorf("unable to send syn packet: %v", err)
   107  	}
   108  
   109  	tcpConn.PingConfig = &connection.PingConfig{
   110  		StartTime: time.Now(),
   111  		Type:      pingConfig.Type,
   112  		SessionID: sessionID,
   113  		Request:   request,
   114  	}
   115  
   116  	d.sendOriginPingReport(
   117  		sessionID,
   118  		d.agentVersion.String(),
   119  		flowTuple(
   120  			tpacket.PacketTypeApplication,
   121  			srcIP.String(),
   122  			pingConfig.IP.String(),
   123  			uint16(srcPort),
   124  			dstPort,
   125  		),
   126  		context,
   127  		pingConfig.Type,
   128  		len(tcpData),
   129  		request,
   130  	)
   131  
   132  	tcpConn.SetState(connection.TCPSynSend)
   133  	d.sourcePortConnectionCache.AddOrUpdate(
   134  		packetTuple(tpacket.PacketTypeApplication, srcIP.String(), pingConfig.IP.String(), uint16(srcPort), dstPort),
   135  		tcpConn,
   136  	)
   137  
   138  	return nil
   139  }
   140  
   141  // processDiagnosticNetSynPacket should only be called when the packet is recognized as a diagnostic syn packet.
   142  func (d *Datapath) processDiagnosticNetSynPacket(
   143  	context *pucontext.PUContext,
   144  	tcpConn *connection.TCPConnection,
   145  	tcpPacket *tpacket.Packet,
   146  	claims *tokens.ConnectionClaims,
   147  ) error {
   148  
   149  	ch := claims.H.ToClaimsHeader()
   150  	tcpConn.PingConfig.Type = ch.PingType()
   151  	tcpConn.SetState(connection.TCPSynReceived)
   152  
   153  	zap.L().Debug("Processing diagnostic network syn packet",
   154  		zap.String("pingType", ch.PingType().String()),
   155  	)
   156  
   157  	if ch.PingType() == claimsheader.PingTypeDefaultIdentityPassthrough {
   158  		zap.L().Debug("Processing diagnostic network syn packet: defaultpassthrough")
   159  
   160  		tcpConn.PingConfig.Passthrough = true
   161  		d.appReplyConnectionTracker.AddOrUpdate(tcpPacket.L4ReverseFlowHash(), tcpConn)
   162  		return nil
   163  	}
   164  
   165  	conn, err := dialWithMark(tcpPacket.DestinationAddress(), tcpPacket.SourceAddress())
   166  	if err != nil {
   167  		return fmt.Errorf("unable to dial on net syn: %v", err)
   168  	}
   169  	defer conn.Close() // nolint:errcheck
   170  
   171  	var tcpData []byte
   172  	// If diagnostic type is custom, we add custom payload.
   173  	// Else, we add default payload.
   174  	if ch.PingType() == claimsheader.PingTypeCustomIdentity {
   175  		ci := &customIdentity{
   176  			AgentVersion:         d.agentVersion.String(),
   177  			TransmitterID:        context.ManagementID(),
   178  			TransmitterNamespace: context.ManagementNamespace(),
   179  			FlowTuple: flowTuple(
   180  				tpacket.PacketTypeApplication,
   181  				tcpPacket.SourceAddress().String(),
   182  				tcpPacket.DestinationAddress().String(),
   183  				tcpPacket.SourcePort(),
   184  				tcpPacket.DestPort(),
   185  			),
   186  		}
   187  		tcpData, err = ci.encode()
   188  		if err != nil {
   189  			return err
   190  		}
   191  	} else {
   192  		tcpData, err = d.tokenAccessor.CreateSynAckPacketToken(context, &tcpConn.Auth, ch, d.secrets())
   193  		if err != nil {
   194  			return fmt.Errorf("unable to create default synack token: %v", err)
   195  		}
   196  	}
   197  
   198  	p, err := constructTCPPacket(
   199  		tcpPacket.DestinationAddress(),
   200  		tcpPacket.SourceAddress(),
   201  		tcpPacket.DestPort(),
   202  		tcpPacket.SourcePort(),
   203  		tcp.Syn|tcp.Ack,
   204  		tcpData,
   205  	)
   206  	if err != nil {
   207  		return fmt.Errorf("unable to construct synack packet: %v", err)
   208  	}
   209  
   210  	if err := write(conn, p); err != nil {
   211  		return fmt.Errorf("unable to send synack packet: %v", err)
   212  	}
   213  
   214  	tcpConn.SetState(connection.TCPSynAckSend)
   215  	return nil
   216  }
   217  
   218  // processDiagnosticNetSynAckPacket should only be called when the packet is recognized as a diagnostic synack packet.
   219  func (d *Datapath) processDiagnosticNetSynAckPacket(
   220  	context *pucontext.PUContext,
   221  	tcpConn *connection.TCPConnection,
   222  	tcpPacket *tpacket.Packet,
   223  	claims *tokens.ConnectionClaims,
   224  	ext bool,
   225  	custom bool,
   226  ) error {
   227  	zap.L().Debug("Processing diagnostic network synack packet",
   228  		zap.Bool("externalNetwork", ext),
   229  		zap.Bool("customPayload", custom),
   230  		zap.String("pingType", tcpConn.PingConfig.Type.String()),
   231  	)
   232  
   233  	if tcpConn.GetState() == connection.TCPSynAckReceived {
   234  		zap.L().Debug("Ignoring duplicate synack packets")
   235  		return nil
   236  	}
   237  
   238  	receiveTime := time.Since(tcpConn.PingConfig.StartTime)
   239  	tcpConn.SetState(connection.TCPSynAckReceived)
   240  
   241  	// Synack from externalnetwork.
   242  	if ext {
   243  		tcpConn.PingConfig.Passthrough = true
   244  		d.sendReplyPingReport(&customIdentity{}, tcpConn, context, receiveTime.String(), len(tcpPacket.ReadTCPData()))
   245  		return nil
   246  	}
   247  
   248  	// Synack from an endpoint with custom identity enabled.
   249  	if custom {
   250  		ci := &customIdentity{}
   251  		if err := ci.decode(tcpPacket.ReadTCPData()); err != nil {
   252  			return err
   253  		}
   254  
   255  		d.sendReplyPingReport(ci, tcpConn, context, receiveTime.String(), len(tcpPacket.ReadTCPData()))
   256  		return nil
   257  	}
   258  
   259  	txtID, ok := claims.T.Get(enforcerconstants.TransmitterLabel)
   260  	if !ok {
   261  		return fmt.Errorf("missing transmitter label")
   262  	}
   263  
   264  	ci := &customIdentity{
   265  		TransmitterID: txtID,
   266  	}
   267  
   268  	d.sendReplyPingReport(ci, tcpConn, context, receiveTime.String(), len(tcpPacket.ReadTCPData()))
   269  
   270  	if tcpConn.PingConfig.Type == claimsheader.PingTypeDefaultIdentityPassthrough {
   271  		zap.L().Debug("Processing diagnostic network synack packet: defaultpassthrough")
   272  		tcpConn.PingConfig.Passthrough = true
   273  		return nil
   274  	}
   275  
   276  	return nil
   277  }
   278  
   279  // constructTCPPacket constructs a valid tcp packet that can be sent on wire.
   280  func constructTCPPacket(srcIP, dstIP net.IP, srcPort, dstPort uint16, flag tcp.Flags, tcpData []byte) ([]byte, error) {
   281  
   282  	// pseudo header.
   283  	// NOTE: Used only for computing checksum.
   284  	ipPacket := ipv4.Make()
   285  	ipPacket.SrcAddr = srcIP
   286  	ipPacket.DstAddr = dstIP
   287  	ipPacket.Protocol = ipv4.TCP
   288  
   289  	// tcp.
   290  	tcpPacket := tcp.Make()
   291  	tcpPacket.SrcPort = srcPort
   292  	tcpPacket.DstPort = dstPort
   293  	tcpPacket.Flags = flag
   294  	tcpPacket.Seq = rand.Uint32()
   295  	tcpPacket.WindowSize = 0xAAAA
   296  	tcpPacket.Options = []tcp.Option{
   297  		{
   298  			Type: tcp.MSS,
   299  			Len:  4,
   300  			Data: []byte{0x05, 0x8C},
   301  		}, {
   302  			Type: 34, // tfo
   303  			Len:  enforcerconstants.TCPAuthenticationOptionBaseLen,
   304  			Data: make([]byte, 2),
   305  		},
   306  	}
   307  	tcpPacket.DataOff = uint8(7) // 5 (header size) + 2 * (4 byte options)
   308  
   309  	// payload.
   310  	payload := raw.Make()
   311  	payload.Data = tcpData
   312  
   313  	tcpPacket.SetPayload(payload)  // nolint:errcheck
   314  	ipPacket.SetPayload(tcpPacket) // nolint:errcheck
   315  
   316  	// pack the layers together.
   317  	buf, err := layers.Pack(tcpPacket, payload)
   318  	if err != nil {
   319  		return nil, fmt.Errorf("unable to encode packet to wire format: %v", err)
   320  	}
   321  
   322  	return buf, nil
   323  }
   324  
   325  // getSrcIP returns the interface ip that can reach the destination.
   326  func getSrcIP(dstIP net.IP) (net.IP, error) {
   327  
   328  	route, err := routing.RouteTo(dstIP)
   329  	if err != nil || route == nil {
   330  		return nil, fmt.Errorf("no route found for destination %s: %v", dstIP.String(), err)
   331  	}
   332  
   333  	ip, err := route.GetIfaceIPv4Addr()
   334  	if err != nil {
   335  		return nil, fmt.Errorf("unable to get interface ip address: %v", err)
   336  	}
   337  
   338  	return ip, nil
   339  }
   340  
   341  // flowTuple returns the tuple based on the stage in format <sip:dip:spt:dpt> or <dip:sip:dpt:spt>
   342  func flowTuple(stage uint64, srcIP, dstIP string, srcPort, dstPort uint16) string {
   343  
   344  	if stage == tpacket.PacketTypeNetwork {
   345  		return fmt.Sprintf("%s:%s:%s:%s", dstIP, srcIP, strconv.Itoa(int(dstPort)), strconv.Itoa(int(srcPort)))
   346  	}
   347  
   348  	return fmt.Sprintf("%s:%s:%s:%s", srcIP, dstIP, strconv.Itoa(int(srcPort)), strconv.Itoa(int(dstPort)))
   349  }
   350  
   351  // packetTuple returns the tuple based on the stage in format <sip:spt> or <dip:dpt>
   352  func packetTuple(stage uint64, srcIP, dstIP string, srcPort, dstPort uint16) string {
   353  
   354  	if stage == tpacket.PacketTypeNetwork {
   355  		return dstIP + ":" + strconv.Itoa(int(dstPort))
   356  	}
   357  
   358  	return srcIP + ":" + strconv.Itoa(int(srcPort))
   359  }
   360  
   361  // dialWithMark opens raw ipv4:tcp socket and connects to the remote network.
   362  func dialWithMark(srcIP, dstIP net.IP) (net.Conn, error) {
   363  
   364  	d := net.Dialer{
   365  		Timeout:   5 * time.Second,
   366  		KeepAlive: -1, // keepalive disabled.
   367  		LocalAddr: &net.IPAddr{IP: srcIP},
   368  		Control: func(_, _ string, c syscall.RawConn) error {
   369  			return c.Control(func(fd uintptr) {
   370  				if err := syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, 0x24, 0x40); err != nil {
   371  					zap.L().Error("unable to assign mark", zap.Error(err))
   372  				}
   373  			})
   374  		},
   375  	}
   376  
   377  	return d.Dial("ip4:tcp", dstIP.String())
   378  }
   379  
   380  // write writes the given data to the conn.
   381  func write(conn net.Conn, data []byte) error {
   382  
   383  	n, err := conn.Write(data)
   384  	if err != nil {
   385  		return err
   386  	}
   387  
   388  	if n != len(data) {
   389  		return fmt.Errorf("partial data written, total: %v, written: %v", len(data), n)
   390  	}
   391  
   392  	return nil
   393  }
   394  
   395  // sendOriginPingReport sends a report on syn sent state.
   396  func (d *Datapath) sendOriginPingReport(
   397  	sessionID,
   398  	agentVersion,
   399  	flowTuple string,
   400  	context *pucontext.PUContext,
   401  	pingType claimsheader.PingType,
   402  	payloadSize,
   403  	request int,
   404  ) {
   405  	d.sendPingReport(
   406  		sessionID,
   407  		agentVersion,
   408  		flowTuple,
   409  		"",
   410  		context.ManagementID(),
   411  		context.ManagementNamespace(),
   412  		"",
   413  		"",
   414  		pingType,
   415  		collector.Origin,
   416  		payloadSize,
   417  		request,
   418  	)
   419  }
   420  
   421  // sendOriginPingReport sends a report on synack recv state.
   422  func (d *Datapath) sendReplyPingReport(
   423  	ci *customIdentity,
   424  	tcpConn *connection.TCPConnection,
   425  	context *pucontext.PUContext,
   426  	rtt string,
   427  	payloadSize int,
   428  ) {
   429  	d.sendPingReport(
   430  		tcpConn.PingConfig.SessionID,
   431  		ci.AgentVersion,
   432  		ci.FlowTuple,
   433  		rtt,
   434  		context.ManagementID(),
   435  		context.ManagementNamespace(),
   436  		ci.TransmitterID,
   437  		ci.TransmitterNamespace,
   438  		tcpConn.PingConfig.Type,
   439  		collector.Reply,
   440  		payloadSize,
   441  		tcpConn.PingConfig.Request,
   442  	)
   443  }
   444  
   445  func (d *Datapath) sendPingReport(
   446  	sessionID,
   447  	agentVersion,
   448  	flowTuple,
   449  	rtt,
   450  	srcID,
   451  	srcNS,
   452  	dstID,
   453  	dstNS string,
   454  	PingType claimsheader.PingType,
   455  	stage collector.Stage,
   456  	payloadSize,
   457  	request int,
   458  ) {
   459  
   460  	report := &collector.PingReport{
   461  		AgentVersion:         agentVersion,
   462  		FlowTuple:            flowTuple,
   463  		Latency:              rtt,
   464  		PayloadSize:          payloadSize,
   465  		Type:                 PingType,
   466  		Stage:                stage,
   467  		SourceID:             srcID,
   468  		SourceNamespace:      srcNS,
   469  		DestinationNamespace: dstNS,
   470  		DestinationID:        dstID,
   471  		SessionID:            sessionID,
   472  		Protocol:             tpacket.IPProtocolTCP,
   473  		ServiceType:          "L3",
   474  		Request:              request,
   475  	}
   476  
   477  	d.collector.CollectPingEvent(report)
   478  }
   479  
   480  // customIdentity holds data that needs to be passed on wire.
   481  type customIdentity struct {
   482  	AgentVersion         string
   483  	TransmitterID        string
   484  	TransmitterNamespace string
   485  	FlowTuple            string
   486  }
   487  
   488  // encode returns bytes of c, returns error on nil.
   489  func (c *customIdentity) encode() ([]byte, error) {
   490  
   491  	if c == nil {
   492  		return nil, fmt.Errorf("cannot encode nil custom identity")
   493  	}
   494  
   495  	b, err := msgpack.Marshal(c)
   496  	if err != nil {
   497  		return nil, fmt.Errorf("unable to encode custom identity: %v", err)
   498  	}
   499  
   500  	return b, nil
   501  }
   502  
   503  // decode returns customIdentity, returns error on nil.
   504  func (c *customIdentity) decode(b []byte) error {
   505  
   506  	if c == nil {
   507  		return fmt.Errorf("cannot decode nil custom identity")
   508  	}
   509  
   510  	if err := msgpack.Unmarshal(b, c); err != nil {
   511  		return fmt.Errorf("unable to decode custom identity: %v", err)
   512  	}
   513  
   514  	return nil
   515  }