github.com/slackhq/nebula@v1.9.0/outside.go (about)

     1  package nebula
     2  
     3  import (
     4  	"encoding/binary"
     5  	"errors"
     6  	"fmt"
     7  	"time"
     8  
     9  	"github.com/flynn/noise"
    10  	"github.com/sirupsen/logrus"
    11  	"github.com/slackhq/nebula/cert"
    12  	"github.com/slackhq/nebula/firewall"
    13  	"github.com/slackhq/nebula/header"
    14  	"github.com/slackhq/nebula/iputil"
    15  	"github.com/slackhq/nebula/udp"
    16  	"golang.org/x/net/ipv4"
    17  	"google.golang.org/protobuf/proto"
    18  )
    19  
    20  const (
    21  	minFwPacketLen = 4
    22  )
    23  
    24  func readOutsidePackets(f *Interface) udp.EncReader {
    25  	return func(
    26  		addr *udp.Addr,
    27  		out []byte,
    28  		packet []byte,
    29  		header *header.H,
    30  		fwPacket *firewall.Packet,
    31  		lhh udp.LightHouseHandlerFunc,
    32  		nb []byte,
    33  		q int,
    34  		localCache firewall.ConntrackCache,
    35  	) {
    36  		f.readOutsidePackets(addr, nil, out, packet, header, fwPacket, lhh, nb, q, localCache)
    37  	}
    38  }
    39  
    40  func (f *Interface) readOutsidePackets(addr *udp.Addr, via *ViaSender, out []byte, packet []byte, h *header.H, fwPacket *firewall.Packet, lhf udp.LightHouseHandlerFunc, nb []byte, q int, localCache firewall.ConntrackCache) {
    41  	err := h.Parse(packet)
    42  	if err != nil {
    43  		// TODO: best if we return this and let caller log
    44  		// TODO: Might be better to send the literal []byte("holepunch") packet and ignore that?
    45  		// Hole punch packets are 0 or 1 byte big, so lets ignore printing those errors
    46  		if len(packet) > 1 {
    47  			f.l.WithField("packet", packet).Infof("Error while parsing inbound packet from %s: %s", addr, err)
    48  		}
    49  		return
    50  	}
    51  
    52  	//l.Error("in packet ", header, packet[HeaderLen:])
    53  	if addr != nil {
    54  		if ip4 := addr.IP.To4(); ip4 != nil {
    55  			if ipMaskContains(f.lightHouse.myVpnIp, f.lightHouse.myVpnZeros, iputil.VpnIp(binary.BigEndian.Uint32(ip4))) {
    56  				if f.l.Level >= logrus.DebugLevel {
    57  					f.l.WithField("udpAddr", addr).Debug("Refusing to process double encrypted packet")
    58  				}
    59  				return
    60  			}
    61  		}
    62  	}
    63  
    64  	var hostinfo *HostInfo
    65  	// verify if we've seen this index before, otherwise respond to the handshake initiation
    66  	if h.Type == header.Message && h.Subtype == header.MessageRelay {
    67  		hostinfo = f.hostMap.QueryRelayIndex(h.RemoteIndex)
    68  	} else {
    69  		hostinfo = f.hostMap.QueryIndex(h.RemoteIndex)
    70  	}
    71  
    72  	var ci *ConnectionState
    73  	if hostinfo != nil {
    74  		ci = hostinfo.ConnectionState
    75  	}
    76  
    77  	switch h.Type {
    78  	case header.Message:
    79  		// TODO handleEncrypted sends directly to addr on error. Handle this in the tunneling case.
    80  		if !f.handleEncrypted(ci, addr, h) {
    81  			return
    82  		}
    83  
    84  		switch h.Subtype {
    85  		case header.MessageNone:
    86  			if !f.decryptToTun(hostinfo, h.MessageCounter, out, packet, fwPacket, nb, q, localCache) {
    87  				return
    88  			}
    89  		case header.MessageRelay:
    90  			// The entire body is sent as AD, not encrypted.
    91  			// The packet consists of a 16-byte parsed Nebula header, Associated Data-protected payload, and a trailing 16-byte AEAD signature value.
    92  			// The packet is guaranteed to be at least 16 bytes at this point, b/c it got past the h.Parse() call above. If it's
    93  			// otherwise malformed (meaning, there is no trailing 16 byte AEAD value), then this will result in at worst a 0-length slice
    94  			// which will gracefully fail in the DecryptDanger call.
    95  			signedPayload := packet[:len(packet)-hostinfo.ConnectionState.dKey.Overhead()]
    96  			signatureValue := packet[len(packet)-hostinfo.ConnectionState.dKey.Overhead():]
    97  			out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, signedPayload, signatureValue, h.MessageCounter, nb)
    98  			if err != nil {
    99  				return
   100  			}
   101  			// Successfully validated the thing. Get rid of the Relay header.
   102  			signedPayload = signedPayload[header.Len:]
   103  			// Pull the Roaming parts up here, and return in all call paths.
   104  			f.handleHostRoaming(hostinfo, addr)
   105  			// Track usage of both the HostInfo and the Relay for the received & authenticated packet
   106  			f.connectionManager.In(hostinfo.localIndexId)
   107  			f.connectionManager.RelayUsed(h.RemoteIndex)
   108  
   109  			relay, ok := hostinfo.relayState.QueryRelayForByIdx(h.RemoteIndex)
   110  			if !ok {
   111  				// The only way this happens is if hostmap has an index to the correct HostInfo, but the HostInfo is missing
   112  				// its internal mapping. This should never happen.
   113  				hostinfo.logger(f.l).WithFields(logrus.Fields{"vpnIp": hostinfo.vpnIp, "remoteIndex": h.RemoteIndex}).Error("HostInfo missing remote relay index")
   114  				return
   115  			}
   116  
   117  			switch relay.Type {
   118  			case TerminalType:
   119  				// If I am the target of this relay, process the unwrapped packet
   120  				// From this recursive point, all these variables are 'burned'. We shouldn't rely on them again.
   121  				f.readOutsidePackets(nil, &ViaSender{relayHI: hostinfo, remoteIdx: relay.RemoteIndex, relay: relay}, out[:0], signedPayload, h, fwPacket, lhf, nb, q, localCache)
   122  				return
   123  			case ForwardingType:
   124  				// Find the target HostInfo relay object
   125  				targetHI, targetRelay, err := f.hostMap.QueryVpnIpRelayFor(hostinfo.vpnIp, relay.PeerIp)
   126  				if err != nil {
   127  					hostinfo.logger(f.l).WithField("relayTo", relay.PeerIp).WithError(err).Info("Failed to find target host info by ip")
   128  					return
   129  				}
   130  
   131  				// If that relay is Established, forward the payload through it
   132  				if targetRelay.State == Established {
   133  					switch targetRelay.Type {
   134  					case ForwardingType:
   135  						// Forward this packet through the relay tunnel
   136  						// Find the target HostInfo
   137  						f.SendVia(targetHI, targetRelay, signedPayload, nb, out, false)
   138  						return
   139  					case TerminalType:
   140  						hostinfo.logger(f.l).Error("Unexpected Relay Type of Terminal")
   141  					}
   142  				} else {
   143  					hostinfo.logger(f.l).WithFields(logrus.Fields{"relayTo": relay.PeerIp, "relayFrom": hostinfo.vpnIp, "targetRelayState": targetRelay.State}).Info("Unexpected target relay state")
   144  					return
   145  				}
   146  			}
   147  		}
   148  
   149  	case header.LightHouse:
   150  		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
   151  		if !f.handleEncrypted(ci, addr, h) {
   152  			return
   153  		}
   154  
   155  		d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
   156  		if err != nil {
   157  			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
   158  				WithField("packet", packet).
   159  				Error("Failed to decrypt lighthouse packet")
   160  
   161  			//TODO: maybe after build 64 is out? 06/14/2018 - NB
   162  			//f.sendRecvError(net.Addr(addr), header.RemoteIndex)
   163  			return
   164  		}
   165  
   166  		lhf(addr, hostinfo.vpnIp, d)
   167  
   168  		// Fallthrough to the bottom to record incoming traffic
   169  
   170  	case header.Test:
   171  		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
   172  		if !f.handleEncrypted(ci, addr, h) {
   173  			return
   174  		}
   175  
   176  		d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
   177  		if err != nil {
   178  			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
   179  				WithField("packet", packet).
   180  				Error("Failed to decrypt test packet")
   181  
   182  			//TODO: maybe after build 64 is out? 06/14/2018 - NB
   183  			//f.sendRecvError(net.Addr(addr), header.RemoteIndex)
   184  			return
   185  		}
   186  
   187  		if h.Subtype == header.TestRequest {
   188  			// This testRequest might be from TryPromoteBest, so we should roam
   189  			// to the new IP address before responding
   190  			f.handleHostRoaming(hostinfo, addr)
   191  			f.send(header.Test, header.TestReply, ci, hostinfo, d, nb, out)
   192  		}
   193  
   194  		// Fallthrough to the bottom to record incoming traffic
   195  
   196  		// Non encrypted messages below here, they should not fall through to avoid tracking incoming traffic since they
   197  		// are unauthenticated
   198  
   199  	case header.Handshake:
   200  		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
   201  		f.handshakeManager.HandleIncoming(addr, via, packet, h)
   202  		return
   203  
   204  	case header.RecvError:
   205  		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
   206  		f.handleRecvError(addr, h)
   207  		return
   208  
   209  	case header.CloseTunnel:
   210  		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
   211  		if !f.handleEncrypted(ci, addr, h) {
   212  			return
   213  		}
   214  
   215  		hostinfo.logger(f.l).WithField("udpAddr", addr).
   216  			Info("Close tunnel received, tearing down.")
   217  
   218  		f.closeTunnel(hostinfo)
   219  		return
   220  
   221  	case header.Control:
   222  		if !f.handleEncrypted(ci, addr, h) {
   223  			return
   224  		}
   225  
   226  		d, err := f.decrypt(hostinfo, h.MessageCounter, out, packet, h, nb)
   227  		if err != nil {
   228  			hostinfo.logger(f.l).WithError(err).WithField("udpAddr", addr).
   229  				WithField("packet", packet).
   230  				Error("Failed to decrypt Control packet")
   231  			return
   232  		}
   233  		m := &NebulaControl{}
   234  		err = m.Unmarshal(d)
   235  		if err != nil {
   236  			hostinfo.logger(f.l).WithError(err).Error("Failed to unmarshal control message")
   237  			break
   238  		}
   239  
   240  		f.relayManager.HandleControlMsg(hostinfo, m, f)
   241  
   242  	default:
   243  		f.messageMetrics.Rx(h.Type, h.Subtype, 1)
   244  		hostinfo.logger(f.l).Debugf("Unexpected packet received from %s", addr)
   245  		return
   246  	}
   247  
   248  	f.handleHostRoaming(hostinfo, addr)
   249  
   250  	f.connectionManager.In(hostinfo.localIndexId)
   251  }
   252  
   253  // closeTunnel closes a tunnel locally, it does not send a closeTunnel packet to the remote
   254  func (f *Interface) closeTunnel(hostInfo *HostInfo) {
   255  	final := f.hostMap.DeleteHostInfo(hostInfo)
   256  	if final {
   257  		// We no longer have any tunnels with this vpn ip, clear learned lighthouse state to lower memory usage
   258  		f.lightHouse.DeleteVpnIp(hostInfo.vpnIp)
   259  	}
   260  }
   261  
   262  // sendCloseTunnel is a helper function to send a proper close tunnel packet to a remote
   263  func (f *Interface) sendCloseTunnel(h *HostInfo) {
   264  	f.send(header.CloseTunnel, 0, h.ConnectionState, h, []byte{}, make([]byte, 12, 12), make([]byte, mtu))
   265  }
   266  
   267  func (f *Interface) handleHostRoaming(hostinfo *HostInfo, addr *udp.Addr) {
   268  	if addr != nil && !hostinfo.remote.Equals(addr) {
   269  		if !f.lightHouse.GetRemoteAllowList().Allow(hostinfo.vpnIp, addr.IP) {
   270  			hostinfo.logger(f.l).WithField("newAddr", addr).Debug("lighthouse.remote_allow_list denied roaming")
   271  			return
   272  		}
   273  		if !hostinfo.lastRoam.IsZero() && addr.Equals(hostinfo.lastRoamRemote) && time.Since(hostinfo.lastRoam) < RoamingSuppressSeconds*time.Second {
   274  			if f.l.Level >= logrus.DebugLevel {
   275  				hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
   276  					Debugf("Suppressing roam back to previous remote for %d seconds", RoamingSuppressSeconds)
   277  			}
   278  			return
   279  		}
   280  
   281  		hostinfo.logger(f.l).WithField("udpAddr", hostinfo.remote).WithField("newAddr", addr).
   282  			Info("Host roamed to new udp ip/port.")
   283  		hostinfo.lastRoam = time.Now()
   284  		hostinfo.lastRoamRemote = hostinfo.remote
   285  		hostinfo.SetRemote(addr)
   286  	}
   287  
   288  }
   289  
   290  func (f *Interface) handleEncrypted(ci *ConnectionState, addr *udp.Addr, h *header.H) bool {
   291  	// If connectionstate exists and the replay protector allows, process packet
   292  	// Else, send recv errors for 300 seconds after a restart to allow fast reconnection.
   293  	if ci == nil || !ci.window.Check(f.l, h.MessageCounter) {
   294  		if addr != nil {
   295  			f.maybeSendRecvError(addr, h.RemoteIndex)
   296  			return false
   297  		} else {
   298  			return false
   299  		}
   300  	}
   301  
   302  	return true
   303  }
   304  
   305  // newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers
   306  func newPacket(data []byte, incoming bool, fp *firewall.Packet) error {
   307  	// Do we at least have an ipv4 header worth of data?
   308  	if len(data) < ipv4.HeaderLen {
   309  		return fmt.Errorf("packet is less than %v bytes", ipv4.HeaderLen)
   310  	}
   311  
   312  	// Is it an ipv4 packet?
   313  	if int((data[0]>>4)&0x0f) != 4 {
   314  		return fmt.Errorf("packet is not ipv4, type: %v", int((data[0]>>4)&0x0f))
   315  	}
   316  
   317  	// Adjust our start position based on the advertised ip header length
   318  	ihl := int(data[0]&0x0f) << 2
   319  
   320  	// Well formed ip header length?
   321  	if ihl < ipv4.HeaderLen {
   322  		return fmt.Errorf("packet had an invalid header length: %v", ihl)
   323  	}
   324  
   325  	// Check if this is the second or further fragment of a fragmented packet.
   326  	flagsfrags := binary.BigEndian.Uint16(data[6:8])
   327  	fp.Fragment = (flagsfrags & 0x1FFF) != 0
   328  
   329  	// Firewall handles protocol checks
   330  	fp.Protocol = data[9]
   331  
   332  	// Accounting for a variable header length, do we have enough data for our src/dst tuples?
   333  	minLen := ihl
   334  	if !fp.Fragment && fp.Protocol != firewall.ProtoICMP {
   335  		minLen += minFwPacketLen
   336  	}
   337  	if len(data) < minLen {
   338  		return fmt.Errorf("packet is less than %v bytes, ip header len: %v", minLen, ihl)
   339  	}
   340  
   341  	// Firewall packets are locally oriented
   342  	if incoming {
   343  		fp.RemoteIP = iputil.Ip2VpnIp(data[12:16])
   344  		fp.LocalIP = iputil.Ip2VpnIp(data[16:20])
   345  		if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
   346  			fp.RemotePort = 0
   347  			fp.LocalPort = 0
   348  		} else {
   349  			fp.RemotePort = binary.BigEndian.Uint16(data[ihl : ihl+2])
   350  			fp.LocalPort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
   351  		}
   352  	} else {
   353  		fp.LocalIP = iputil.Ip2VpnIp(data[12:16])
   354  		fp.RemoteIP = iputil.Ip2VpnIp(data[16:20])
   355  		if fp.Fragment || fp.Protocol == firewall.ProtoICMP {
   356  			fp.RemotePort = 0
   357  			fp.LocalPort = 0
   358  		} else {
   359  			fp.LocalPort = binary.BigEndian.Uint16(data[ihl : ihl+2])
   360  			fp.RemotePort = binary.BigEndian.Uint16(data[ihl+2 : ihl+4])
   361  		}
   362  	}
   363  
   364  	return nil
   365  }
   366  
   367  func (f *Interface) decrypt(hostinfo *HostInfo, mc uint64, out []byte, packet []byte, h *header.H, nb []byte) ([]byte, error) {
   368  	var err error
   369  	out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], mc, nb)
   370  	if err != nil {
   371  		return nil, err
   372  	}
   373  
   374  	if !hostinfo.ConnectionState.window.Update(f.l, mc) {
   375  		hostinfo.logger(f.l).WithField("header", h).
   376  			Debugln("dropping out of window packet")
   377  		return nil, errors.New("out of window packet")
   378  	}
   379  
   380  	return out, nil
   381  }
   382  
   383  func (f *Interface) decryptToTun(hostinfo *HostInfo, messageCounter uint64, out []byte, packet []byte, fwPacket *firewall.Packet, nb []byte, q int, localCache firewall.ConntrackCache) bool {
   384  	var err error
   385  
   386  	out, err = hostinfo.ConnectionState.dKey.DecryptDanger(out, packet[:header.Len], packet[header.Len:], messageCounter, nb)
   387  	if err != nil {
   388  		hostinfo.logger(f.l).WithError(err).Error("Failed to decrypt packet")
   389  		//TODO: maybe after build 64 is out? 06/14/2018 - NB
   390  		//f.sendRecvError(hostinfo.remote, header.RemoteIndex)
   391  		return false
   392  	}
   393  
   394  	err = newPacket(out, true, fwPacket)
   395  	if err != nil {
   396  		hostinfo.logger(f.l).WithError(err).WithField("packet", out).
   397  			Warnf("Error while validating inbound packet")
   398  		return false
   399  	}
   400  
   401  	if !hostinfo.ConnectionState.window.Update(f.l, messageCounter) {
   402  		hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
   403  			Debugln("dropping out of window packet")
   404  		return false
   405  	}
   406  
   407  	dropReason := f.firewall.Drop(*fwPacket, true, hostinfo, f.pki.GetCAPool(), localCache)
   408  	if dropReason != nil {
   409  		// NOTE: We give `packet` as the `out` here since we already decrypted from it and we don't need it anymore
   410  		// This gives us a buffer to build the reject packet in
   411  		f.rejectOutside(out, hostinfo.ConnectionState, hostinfo, nb, packet, q)
   412  		if f.l.Level >= logrus.DebugLevel {
   413  			hostinfo.logger(f.l).WithField("fwPacket", fwPacket).
   414  				WithField("reason", dropReason).
   415  				Debugln("dropping inbound packet")
   416  		}
   417  		return false
   418  	}
   419  
   420  	f.connectionManager.In(hostinfo.localIndexId)
   421  	_, err = f.readers[q].Write(out)
   422  	if err != nil {
   423  		f.l.WithError(err).Error("Failed to write to tun")
   424  	}
   425  	return true
   426  }
   427  
   428  func (f *Interface) maybeSendRecvError(endpoint *udp.Addr, index uint32) {
   429  	if f.sendRecvErrorConfig.ShouldSendRecvError(endpoint.IP) {
   430  		f.sendRecvError(endpoint, index)
   431  	}
   432  }
   433  
   434  func (f *Interface) sendRecvError(endpoint *udp.Addr, index uint32) {
   435  	f.messageMetrics.Tx(header.RecvError, 0, 1)
   436  
   437  	//TODO: this should be a signed message so we can trust that we should drop the index
   438  	b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0)
   439  	f.outside.WriteTo(b, endpoint)
   440  	if f.l.Level >= logrus.DebugLevel {
   441  		f.l.WithField("index", index).
   442  			WithField("udpAddr", endpoint).
   443  			Debug("Recv error sent")
   444  	}
   445  }
   446  
   447  func (f *Interface) handleRecvError(addr *udp.Addr, h *header.H) {
   448  	if f.l.Level >= logrus.DebugLevel {
   449  		f.l.WithField("index", h.RemoteIndex).
   450  			WithField("udpAddr", addr).
   451  			Debug("Recv error received")
   452  	}
   453  
   454  	hostinfo := f.hostMap.QueryReverseIndex(h.RemoteIndex)
   455  	if hostinfo == nil {
   456  		f.l.WithField("remoteIndex", h.RemoteIndex).Debugln("Did not find remote index in main hostmap")
   457  		return
   458  	}
   459  
   460  	if !hostinfo.RecvErrorExceeded() {
   461  		return
   462  	}
   463  
   464  	if hostinfo.remote != nil && !hostinfo.remote.Equals(addr) {
   465  		f.l.Infoln("Someone spoofing recv_errors? ", addr, hostinfo.remote)
   466  		return
   467  	}
   468  
   469  	f.closeTunnel(hostinfo)
   470  	// We also delete it from pending hostmap to allow for fast reconnect.
   471  	f.handshakeManager.DeleteHostInfo(hostinfo)
   472  }
   473  
   474  /*
   475  func (f *Interface) sendMeta(ci *ConnectionState, endpoint *net.UDPAddr, meta *NebulaMeta) {
   476  	if ci.eKey != nil {
   477  		//TODO: log error?
   478  		return
   479  	}
   480  
   481  	msg, err := proto.Marshal(meta)
   482  	if err != nil {
   483  		l.Debugln("failed to encode header")
   484  	}
   485  
   486  	c := ci.messageCounter
   487  	b := HeaderEncode(nil, Version, uint8(metadata), 0, hostinfo.remoteIndexId, c)
   488  	ci.messageCounter++
   489  
   490  	msg := ci.eKey.EncryptDanger(b, nil, msg, c)
   491  	//msg := ci.eKey.EncryptDanger(b, nil, []byte(fmt.Sprintf("%d", counter)), c)
   492  	f.outside.WriteTo(msg, endpoint)
   493  }
   494  */
   495  
   496  func RecombineCertAndValidate(h *noise.HandshakeState, rawCertBytes []byte, caPool *cert.NebulaCAPool) (*cert.NebulaCertificate, error) {
   497  	pk := h.PeerStatic()
   498  
   499  	if pk == nil {
   500  		return nil, errors.New("no peer static key was present")
   501  	}
   502  
   503  	if rawCertBytes == nil {
   504  		return nil, errors.New("provided payload was empty")
   505  	}
   506  
   507  	r := &cert.RawNebulaCertificate{}
   508  	err := proto.Unmarshal(rawCertBytes, r)
   509  	if err != nil {
   510  		return nil, fmt.Errorf("error unmarshaling cert: %s", err)
   511  	}
   512  
   513  	// If the Details are nil, just exit to avoid crashing
   514  	if r.Details == nil {
   515  		return nil, fmt.Errorf("certificate did not contain any details")
   516  	}
   517  
   518  	r.Details.PublicKey = pk
   519  	recombined, err := proto.Marshal(r)
   520  	if err != nil {
   521  		return nil, fmt.Errorf("error while recombining certificate: %s", err)
   522  	}
   523  
   524  	c, _ := cert.UnmarshalNebulaCertificate(recombined)
   525  	isValid, err := c.Verify(time.Now(), caPool)
   526  	if err != nil {
   527  		return c, fmt.Errorf("certificate validation failed: %s", err)
   528  	} else if !isValid {
   529  		// This case should never happen but here's to defensive programming!
   530  		return c, errors.New("certificate validation failed but did not return an error")
   531  	}
   532  
   533  	return c, nil
   534  }