github.com/slackhq/nebula@v1.9.0/service/service.go (about) 1 package service 2 3 import ( 4 "bytes" 5 "context" 6 "errors" 7 "fmt" 8 "log" 9 "math" 10 "net" 11 "os" 12 "strings" 13 "sync" 14 15 "github.com/sirupsen/logrus" 16 "github.com/slackhq/nebula" 17 "github.com/slackhq/nebula/config" 18 "github.com/slackhq/nebula/overlay" 19 "golang.org/x/sync/errgroup" 20 "gvisor.dev/gvisor/pkg/buffer" 21 "gvisor.dev/gvisor/pkg/tcpip" 22 "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" 23 "gvisor.dev/gvisor/pkg/tcpip/header" 24 "gvisor.dev/gvisor/pkg/tcpip/link/channel" 25 "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" 26 "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" 27 "gvisor.dev/gvisor/pkg/tcpip/stack" 28 "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" 29 "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" 30 "gvisor.dev/gvisor/pkg/tcpip/transport/udp" 31 "gvisor.dev/gvisor/pkg/waiter" 32 ) 33 34 const nicID = 1 35 36 type Service struct { 37 eg *errgroup.Group 38 control *nebula.Control 39 ipstack *stack.Stack 40 41 mu struct { 42 sync.Mutex 43 44 listeners map[uint16]*tcpListener 45 } 46 } 47 48 func New(config *config.C) (*Service, error) { 49 logger := logrus.New() 50 logger.Out = os.Stdout 51 52 control, err := nebula.Main(config, false, "custom-app", logger, overlay.NewUserDeviceFromConfig) 53 if err != nil { 54 return nil, err 55 } 56 control.Start() 57 58 ctx := control.Context() 59 eg, ctx := errgroup.WithContext(ctx) 60 s := Service{ 61 eg: eg, 62 control: control, 63 } 64 s.mu.listeners = map[uint16]*tcpListener{} 65 66 device, ok := control.Device().(*overlay.UserDevice) 67 if !ok { 68 return nil, errors.New("must be using user device") 69 } 70 71 s.ipstack = stack.New(stack.Options{ 72 NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, 73 TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6}, 74 }) 75 sackEnabledOpt := tcpip.TCPSACKEnabled(true) // TCP SACK is disabled by default 76 tcpipErr := s.ipstack.SetTransportProtocolOption(tcp.ProtocolNumber, &sackEnabledOpt) 77 if tcpipErr != nil { 78 return nil, fmt.Errorf("could not enable TCP SACK: %v", tcpipErr) 79 } 80 linkEP := channel.New( /*size*/ 512 /*mtu*/, 1280, "") 81 if tcpipProblem := s.ipstack.CreateNIC(nicID, linkEP); tcpipProblem != nil { 82 return nil, fmt.Errorf("could not create netstack NIC: %v", tcpipProblem) 83 } 84 ipv4Subnet, _ := tcpip.NewSubnet(tcpip.AddrFrom4([4]byte{0x00, 0x00, 0x00, 0x00}), tcpip.MaskFrom(strings.Repeat("\x00", 4))) 85 s.ipstack.SetRouteTable([]tcpip.Route{ 86 { 87 Destination: ipv4Subnet, 88 NIC: nicID, 89 }, 90 }) 91 92 ipNet := device.Cidr() 93 pa := tcpip.ProtocolAddress{ 94 AddressWithPrefix: tcpip.AddrFromSlice(ipNet.IP).WithPrefix(), 95 Protocol: ipv4.ProtocolNumber, 96 } 97 if err := s.ipstack.AddProtocolAddress(nicID, pa, stack.AddressProperties{ 98 PEB: stack.CanBePrimaryEndpoint, // zero value default 99 ConfigType: stack.AddressConfigStatic, // zero value default 100 }); err != nil { 101 return nil, fmt.Errorf("error creating IP: %s", err) 102 } 103 104 const tcpReceiveBufferSize = 0 105 const maxInFlightConnectionAttempts = 1024 106 tcpFwd := tcp.NewForwarder(s.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts, s.tcpHandler) 107 s.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket) 108 109 reader, writer := device.Pipe() 110 111 go func() { 112 <-ctx.Done() 113 reader.Close() 114 writer.Close() 115 }() 116 117 // create Goroutines to forward packets between Nebula and Gvisor 118 eg.Go(func() error { 119 buf := make([]byte, header.IPv4MaximumHeaderSize+header.IPv4MaximumPayloadSize) 120 for { 121 // this will read exactly one packet 122 n, err := reader.Read(buf) 123 if err != nil { 124 return err 125 } 126 packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ 127 Payload: buffer.MakeWithData(bytes.Clone(buf[:n])), 128 }) 129 linkEP.InjectInbound(header.IPv4ProtocolNumber, packetBuf) 130 131 if err := ctx.Err(); err != nil { 132 return err 133 } 134 } 135 }) 136 eg.Go(func() error { 137 for { 138 packet := linkEP.ReadContext(ctx) 139 if packet == nil { 140 if err := ctx.Err(); err != nil { 141 return err 142 } 143 continue 144 } 145 bufView := packet.ToView() 146 if _, err := bufView.WriteTo(writer); err != nil { 147 return err 148 } 149 bufView.Release() 150 } 151 }) 152 153 return &s, nil 154 } 155 156 // DialContext dials the provided address. Currently only TCP is supported. 157 func (s *Service) DialContext(ctx context.Context, network, address string) (net.Conn, error) { 158 if network != "tcp" && network != "tcp4" { 159 return nil, errors.New("only tcp is supported") 160 } 161 162 addr, err := net.ResolveTCPAddr(network, address) 163 if err != nil { 164 return nil, err 165 } 166 167 fullAddr := tcpip.FullAddress{ 168 NIC: nicID, 169 Addr: tcpip.AddrFromSlice(addr.IP), 170 Port: uint16(addr.Port), 171 } 172 173 return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, ipv4.ProtocolNumber) 174 } 175 176 // Listen listens on the provided address. Currently only TCP with wildcard 177 // addresses are supported. 178 func (s *Service) Listen(network, address string) (net.Listener, error) { 179 if network != "tcp" && network != "tcp4" { 180 return nil, errors.New("only tcp is supported") 181 } 182 addr, err := net.ResolveTCPAddr(network, address) 183 if err != nil { 184 return nil, err 185 } 186 if addr.IP != nil && !bytes.Equal(addr.IP, []byte{0, 0, 0, 0}) { 187 return nil, fmt.Errorf("only wildcard address supported, got %q %v", address, addr.IP) 188 } 189 if addr.Port == 0 { 190 return nil, errors.New("specific port required, got 0") 191 } 192 if addr.Port < 0 || addr.Port >= math.MaxUint16 { 193 return nil, fmt.Errorf("invalid port %d", addr.Port) 194 } 195 port := uint16(addr.Port) 196 197 l := &tcpListener{ 198 port: port, 199 s: s, 200 addr: addr, 201 accept: make(chan net.Conn), 202 } 203 204 s.mu.Lock() 205 defer s.mu.Unlock() 206 207 if _, ok := s.mu.listeners[port]; ok { 208 return nil, fmt.Errorf("already listening on port %d", port) 209 } 210 s.mu.listeners[port] = l 211 212 return l, nil 213 } 214 215 func (s *Service) Wait() error { 216 return s.eg.Wait() 217 } 218 219 func (s *Service) Close() error { 220 s.control.Stop() 221 return nil 222 } 223 224 func (s *Service) tcpHandler(r *tcp.ForwarderRequest) { 225 endpointID := r.ID() 226 227 s.mu.Lock() 228 defer s.mu.Unlock() 229 230 l, ok := s.mu.listeners[endpointID.LocalPort] 231 if !ok { 232 r.Complete(true) 233 return 234 } 235 236 var wq waiter.Queue 237 ep, err := r.CreateEndpoint(&wq) 238 if err != nil { 239 log.Printf("got error creating endpoint %q", err) 240 r.Complete(true) 241 return 242 } 243 r.Complete(false) 244 ep.SocketOptions().SetKeepAlive(true) 245 246 conn := gonet.NewTCPConn(&wq, ep) 247 l.accept <- conn 248 }