github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/applicationproxy/markedconn/markedconn.go (about)

     1  // +build linux windows
     2  
     3  package markedconn
     4  
     5  import (
     6  	"context"
     7  	"fmt"
     8  	"net"
     9  	"syscall"
    10  	"time"
    11  
    12  	"go.aporeto.io/enforcerd/trireme-lib/utils/netinterfaces"
    13  	"go.uber.org/zap"
    14  )
    15  
    16  // Control represents the dial control used to manipulate the raw connection.
    17  type Control func(network, address string, c syscall.RawConn) error
    18  
    19  func makeDialer(mark int, platformData *PlatformData) net.Dialer {
    20  	// platformData is the destHandle
    21  	return net.Dialer{
    22  		Control: ControlFunc(mark, true, platformData),
    23  	}
    24  }
    25  
    26  // DialMarkedWithContext will dial a TCP connection to the provide address and mark the socket
    27  // with the provided mark.
    28  func DialMarkedWithContext(ctx context.Context, network string, addr string, platformData *PlatformData, mark int) (net.Conn, error) {
    29  	// platformData is for Windows
    30  	if platformData != nil && platformData.postConnectFunc != nil {
    31  		defer platformData.postConnectFunc(platformData.handle)
    32  	}
    33  	d := makeDialer(mark, platformData)
    34  
    35  	conn, err := d.DialContext(ctx, network, addr)
    36  	if err != nil {
    37  		zap.L().Error("Failed to dial to downstream node",
    38  			zap.Error(err),
    39  			zap.String("Address", addr),
    40  			zap.String("Network type", network),
    41  		)
    42  	}
    43  	return conn, err
    44  }
    45  
    46  // NewSocketListener will create a listener and mark the socket with the provided mark.
    47  func NewSocketListener(ctx context.Context, port string, mark int) (net.Listener, error) {
    48  	listenerCfg := makeListenerConfig(mark)
    49  
    50  	listener, err := listenerCfg.Listen(ctx, "tcp", port)
    51  
    52  	if err != nil {
    53  		return nil, fmt.Errorf("Failed to create listener: %s", err)
    54  	}
    55  
    56  	return ProxiedListener{
    57  		netListener:      listener,
    58  		mark:             mark,
    59  		platformDataCtrl: NewPlatformDataControl(),
    60  	}, nil
    61  }
    62  
    63  // ProxiedConnection is a proxied connection where we can recover the
    64  // original destination.
    65  type ProxiedConnection struct {
    66  	originalIP            net.IP
    67  	originalPort          int
    68  	originalTCPConnection *net.TCPConn
    69  	platformData          *PlatformData
    70  }
    71  
    72  // PlatformData is proxy/socket data (platform-specific)
    73  type PlatformData struct {
    74  	handle          uintptr
    75  	postConnectFunc func(fd uintptr)
    76  }
    77  
    78  // GetOriginalDestination sets the original destination of the connection.
    79  func (p *ProxiedConnection) GetOriginalDestination() (net.IP, int) {
    80  	return p.originalIP, p.originalPort
    81  }
    82  
    83  // GetPlatformData gets the platform-specific socket data (needed for Windows)
    84  func (p *ProxiedConnection) GetPlatformData() *PlatformData {
    85  	return p.platformData
    86  }
    87  
    88  // GetTCPConnection returns the TCP connection object.
    89  func (p *ProxiedConnection) GetTCPConnection() *net.TCPConn {
    90  	return p.originalTCPConnection
    91  }
    92  
    93  // LocalAddr implements the corresponding method of net.Conn, but returns the original
    94  // address.
    95  func (p *ProxiedConnection) LocalAddr() net.Addr {
    96  
    97  	return &net.TCPAddr{
    98  		IP:   p.originalIP,
    99  		Port: p.originalPort,
   100  	}
   101  }
   102  
   103  // RemoteAddr returns the remote address
   104  func (p *ProxiedConnection) RemoteAddr() net.Addr {
   105  	return p.originalTCPConnection.RemoteAddr()
   106  }
   107  
   108  // Read reads data from the connection.
   109  func (p *ProxiedConnection) Read(b []byte) (n int, err error) {
   110  	return p.originalTCPConnection.Read(b)
   111  }
   112  
   113  // Write writes data to the connection.
   114  func (p *ProxiedConnection) Write(b []byte) (n int, err error) {
   115  	return p.originalTCPConnection.Write(b)
   116  }
   117  
   118  // Close closes the connection.
   119  func (p *ProxiedConnection) Close() error {
   120  	return p.originalTCPConnection.Close()
   121  }
   122  
   123  // SetDeadline passes the read deadline to the original TCP connection.
   124  func (p *ProxiedConnection) SetDeadline(t time.Time) error {
   125  	return p.originalTCPConnection.SetDeadline(t)
   126  }
   127  
   128  // SetReadDeadline implements the call by passing it to the original connection.
   129  func (p *ProxiedConnection) SetReadDeadline(t time.Time) error {
   130  	return p.originalTCPConnection.SetReadDeadline(t)
   131  }
   132  
   133  // SetWriteDeadline implements the call by passing it to the original connection.
   134  func (p *ProxiedConnection) SetWriteDeadline(t time.Time) error {
   135  	return p.originalTCPConnection.SetWriteDeadline(t)
   136  }
   137  
   138  // ProxiedListener is a proxied listener that uses proxied connections.
   139  type ProxiedListener struct {
   140  	netListener      net.Listener
   141  	mark             int
   142  	platformDataCtrl *PlatformDataControl
   143  }
   144  
   145  type passFD interface {
   146  	Control(func(uintptr)) error
   147  }
   148  
   149  func getOriginalDestination(conn *net.TCPConn) (net.IP, int, *PlatformData, error) { // nolint interfacer
   150  
   151  	rawconn, err := conn.SyscallConn()
   152  	if err != nil {
   153  		return nil, 0, nil, err
   154  	}
   155  
   156  	localIPString, _, err := net.SplitHostPort(conn.LocalAddr().String())
   157  	if err != nil {
   158  		return nil, 0, nil, err
   159  	}
   160  
   161  	localIP := net.ParseIP(localIPString)
   162  
   163  	return getOriginalDestPlatform(rawconn, localIP.To4() != nil)
   164  }
   165  
   166  // Accept implements the accept method of the interface.
   167  func (l ProxiedListener) Accept() (c net.Conn, err error) {
   168  	nc, err := l.netListener.Accept()
   169  	if err != nil {
   170  		return nil, err
   171  	}
   172  
   173  	tcpConn, ok := nc.(*net.TCPConn)
   174  	if !ok {
   175  		zap.L().Error("Received a non-TCP connection - this should never happen", zap.Error(err))
   176  		return nil, fmt.Errorf("Not a tcp connection - ignoring")
   177  	}
   178  
   179  	ip, port, platformData, err := getOriginalDestination(tcpConn)
   180  	if err != nil {
   181  		zap.L().Error("Failed to discover original destination - aborting", zap.Error(err))
   182  		return nil, err
   183  	}
   184  	l.platformDataCtrl.StorePlatformData(ip, port, platformData)
   185  
   186  	return &ProxiedConnection{
   187  		originalIP:            ip,
   188  		originalPort:          port,
   189  		originalTCPConnection: tcpConn,
   190  		platformData:          platformData,
   191  	}, nil
   192  }
   193  
   194  // Addr implements the Addr method of net.Listener.
   195  func (l ProxiedListener) Addr() net.Addr {
   196  	return l.netListener.Addr()
   197  }
   198  
   199  // Close implements the Close method of the net.Listener.
   200  func (l ProxiedListener) Close() error {
   201  	return l.netListener.Close()
   202  }
   203  
   204  // GetInterfaces retrieves all the local interfaces.
   205  func GetInterfaces() map[string]struct{} {
   206  	ipmap := map[string]struct{}{}
   207  
   208  	ifaces, err := netinterfaces.GetInterfacesInfo()
   209  	if err != nil {
   210  		zap.L().Error("Unable to get interfaces info", zap.Error(err))
   211  	}
   212  
   213  	for _, iface := range ifaces {
   214  		for _, ip := range iface.IPs {
   215  			ipmap[ip.String()] = struct{}{}
   216  		}
   217  	}
   218  
   219  	return ipmap
   220  }