github.com/slackhq/nebula@v1.9.0/service/service.go (about)

     1  package service
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"log"
     9  	"math"
    10  	"net"
    11  	"os"
    12  	"strings"
    13  	"sync"
    14  
    15  	"github.com/sirupsen/logrus"
    16  	"github.com/slackhq/nebula"
    17  	"github.com/slackhq/nebula/config"
    18  	"github.com/slackhq/nebula/overlay"
    19  	"golang.org/x/sync/errgroup"
    20  	"gvisor.dev/gvisor/pkg/buffer"
    21  	"gvisor.dev/gvisor/pkg/tcpip"
    22  	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
    23  	"gvisor.dev/gvisor/pkg/tcpip/header"
    24  	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
    25  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
    26  	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
    27  	"gvisor.dev/gvisor/pkg/tcpip/stack"
    28  	"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
    29  	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
    30  	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
    31  	"gvisor.dev/gvisor/pkg/waiter"
    32  )
    33  
// nicID is the ID of the single NIC that New creates on the netstack. The
// route table entry and outbound dials reference the same ID.
const nicID = 1
    35  
// Service couples a running nebula instance with a userspace network stack so
// that callers can dial and listen for TCP connections through it.
type Service struct {
	// eg runs the goroutines that New starts to move packets between nebula
	// and the network stack; Wait blocks on it.
	eg *errgroup.Group
	// control is the handle returned by nebula.Main; Close stops it.
	control *nebula.Control
	// ipstack is the gVisor network stack that dials go through and that
	// hands inbound TCP connection attempts to tcpHandler.
	ipstack *stack.Stack

	// mu groups the mutex with the state it protects.
	mu struct {
		sync.Mutex

		// listeners maps a local TCP port to the listener registered for it.
		// Listen and tcpHandler both access it with the mutex held.
		listeners map[uint16]*tcpListener
	}
}
    47  
    48  func New(config *config.C) (*Service, error) {
    49  	logger := logrus.New()
    50  	logger.Out = os.Stdout
    51  
    52  	control, err := nebula.Main(config, false, "custom-app", logger, overlay.NewUserDeviceFromConfig)
    53  	if err != nil {
    54  		return nil, err
    55  	}
    56  	control.Start()
    57  
    58  	ctx := control.Context()
    59  	eg, ctx := errgroup.WithContext(ctx)
    60  	s := Service{
    61  		eg:      eg,
    62  		control: control,
    63  	}
    64  	s.mu.listeners = map[uint16]*tcpListener{}
    65  
    66  	device, ok := control.Device().(*overlay.UserDevice)
    67  	if !ok {
    68  		return nil, errors.New("must be using user device")
    69  	}
    70  
    71  	s.ipstack = stack.New(stack.Options{
    72  		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
    73  		TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6},
    74  	})
    75  	sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default
    76  	tcpipErr := s.ipstack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt)
    77  	if tcpipErr != nil {
    78  		return nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr)
    79  	}
    80  	linkEP := channel.New( /*size*/ 512 /*mtu*/, 1280, "")
    81  	if tcpipProblem := s.ipstack.CreateNIC(nicID, linkEP); tcpipProblem != nil {
    82  		return nil, fmt.Errorf("could not create netstack NIC: %v", tcpipProblem)
    83  	}
    84  	ipv4Subnet, _ := tcpip.NewSubnet(tcpip.AddrFrom4([4]byte{0x00, 0x00, 0x00, 0x00}), tcpip.MaskFrom(strings.Repeat("\x00", 4)))
    85  	s.ipstack.SetRouteTable([]tcpip.Route{
    86  		{
    87  			Destination: ipv4Subnet,
    88  			NIC:         nicID,
    89  		},
    90  	})
    91  
    92  	ipNet := device.Cidr()
    93  	pa := tcpip.ProtocolAddress{
    94  		AddressWithPrefix: tcpip.AddrFromSlice(ipNet.IP).WithPrefix(),
    95  		Protocol:          ipv4.ProtocolNumber,
    96  	}
    97  	if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{
    98  		PEB:        stack.CanBePrimaryEndpoint, // zero value default
    99  		ConfigType: stack.AddressConfigStatic,  // zero value default
   100  	}); err != nil {
   101  		return nil, fmt.Errorf("error creating IP: %s", err)
   102  	}
   103  
   104  	const tcpReceiveBufferSize = 0
   105  	const maxInFlightConnectionAttempts = 1024
   106  	tcpFwd := tcp.NewForwarder(s.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts, s.tcpHandler)
   107  	s.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket)
   108  
   109  	reader, writer := device.Pipe()
   110  
   111  	go func() {
   112  		<-ctx.Done()
   113  		reader.Close()
   114  		writer.Close()
   115  	}()
   116  
   117  	// create Goroutines to forward packets between Nebula and Gvisor
   118  	eg.Go(func() error {
   119  		buf := make([]byte, header.IPv4MaximumHeaderSize+header.IPv4MaximumPayloadSize)
   120  		for {
   121  			// this will read exactly one packet
   122  			n, err := reader.Read(buf)
   123  			if err != nil {
   124  				return err
   125  			}
   126  			packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{
   127  				Payload: buffer.MakeWithData(bytes.Clone(buf[:n])),
   128  			})
   129  			linkEP.InjectInbound(header.IPv4ProtocolNumber, packetBuf)
   130  
   131  			if err := ctx.Err(); err != nil {
   132  				return err
   133  			}
   134  		}
   135  	})
   136  	eg.Go(func() error {
   137  		for {
   138  			packet := linkEP.ReadContext(ctx)
   139  			if packet == nil {
   140  				if err := ctx.Err(); err != nil {
   141  					return err
   142  				}
   143  				continue
   144  			}
   145  			bufView := packet.ToView()
   146  			if _, err := bufView.WriteTo(writer); err != nil {
   147  				return err
   148  			}
   149  			bufView.Release()
   150  		}
   151  	})
   152  
   153  	return &s, nil
   154  }
   155  
   156  // DialContext dials the provided address. Currently only TCP is supported.
   157  func (s *Service) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
   158  	if network != "tcp" && network != "tcp4" {
   159  		return nil, errors.New("only tcp is supported")
   160  	}
   161  
   162  	addr, err := net.ResolveTCPAddr(network, address)
   163  	if err != nil {
   164  		return nil, err
   165  	}
   166  
   167  	fullAddr := tcpip.FullAddress{
   168  		NIC:  nicID,
   169  		Addr: tcpip.AddrFromSlice(addr.IP),
   170  		Port: uint16(addr.Port),
   171  	}
   172  
   173  	return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, ipv4.ProtocolNumber)
   174  }
   175  
   176  // Listen listens on the provided address. Currently only TCP with wildcard
   177  // addresses are supported.
   178  func (s *Service) Listen(network, address string) (net.Listener, error) {
   179  	if network != "tcp" && network != "tcp4" {
   180  		return nil, errors.New("only tcp is supported")
   181  	}
   182  	addr, err := net.ResolveTCPAddr(network, address)
   183  	if err != nil {
   184  		return nil, err
   185  	}
   186  	if addr.IP != nil && !bytes.Equal(addr.IP, []byte{0, 0, 0, 0}) {
   187  		return nil, fmt.Errorf("only wildcard address supported, got %q %v", address, addr.IP)
   188  	}
   189  	if addr.Port == 0 {
   190  		return nil, errors.New("specific port required, got 0")
   191  	}
   192  	if addr.Port < 0 || addr.Port >= math.MaxUint16 {
   193  		return nil, fmt.Errorf("invalid port %d", addr.Port)
   194  	}
   195  	port := uint16(addr.Port)
   196  
   197  	l := &tcpListener{
   198  		port:   port,
   199  		s:      s,
   200  		addr:   addr,
   201  		accept: make(chan net.Conn),
   202  	}
   203  
   204  	s.mu.Lock()
   205  	defer s.mu.Unlock()
   206  
   207  	if _, ok := s.mu.listeners[port]; ok {
   208  		return nil, fmt.Errorf("already listening on port %d", port)
   209  	}
   210  	s.mu.listeners[port] = l
   211  
   212  	return l, nil
   213  }
   214  
// Wait blocks until the packet-forwarding goroutines that New started on the
// Service's errgroup have returned, and returns the result of the errgroup's
// Wait.
func (s *Service) Wait() error {
	return s.eg.Wait()
}
   218  
// Close stops the underlying nebula instance by calling Stop on its control
// handle. It always returns nil.
//
// NOTE(review): presumably Stop also cancels the control context and thereby
// ends the forwarding goroutines that Wait blocks on — confirm against
// nebula.Control.
func (s *Service) Close() error {
	s.control.Stop()
	return nil
}
   223  
   224  func (s *Service) tcpHandler(r *tcp.ForwarderRequest) {
   225  	endpointID := r.ID()
   226  
   227  	s.mu.Lock()
   228  	defer s.mu.Unlock()
   229  
   230  	l, ok := s.mu.listeners[endpointID.LocalPort]
   231  	if !ok {
   232  		r.Complete(true)
   233  		return
   234  	}
   235  
   236  	var wq waiter.Queue
   237  	ep, err := r.CreateEndpoint(&wq)
   238  	if err != nil {
   239  		log.Printf("got error creating endpoint %q", err)
   240  		r.Complete(true)
   241  		return
   242  	}
   243  	r.Complete(false)
   244  	ep.SocketOptions().SetKeepAlive(true)
   245  
   246  	conn := gonet.NewTCPConn(&wq, ep)
   247  	l.accept <- conn
   248  }