github.com/slackhq/nebula@v1.9.0/iputil/packet.go (about)

     1  package iputil
     2  
     3  import (
     4  	"encoding/binary"
     5  
     6  	"golang.org/x/net/ipv4"
     7  )
     8  
     9  const (
    10  	// Need 96 bytes for the largest reject packet:
    11  	// - 20 byte ipv4 header
    12  	// - 8 byte icmpv4 header
    13  	// - 68 byte body (60 byte max orig ipv4 header + 8 byte orig icmpv4 header)
    14  	MaxRejectPacketSize = ipv4.HeaderLen + 8 + 60 + 8
    15  )
    16  
    17  func CreateRejectPacket(packet []byte, out []byte) []byte {
    18  	if len(packet) < ipv4.HeaderLen || int(packet[0]>>4) != ipv4.Version {
    19  		return nil
    20  	}
    21  
    22  	switch packet[9] {
    23  	case 6: // tcp
    24  		return ipv4CreateRejectTCPPacket(packet, out)
    25  	default:
    26  		return ipv4CreateRejectICMPPacket(packet, out)
    27  	}
    28  }
    29  
    30  func ipv4CreateRejectICMPPacket(packet []byte, out []byte) []byte {
    31  	ihl := int(packet[0]&0x0f) << 2
    32  
    33  	if len(packet) < ihl {
    34  		// We need at least this many bytes for this to be a valid packet
    35  		return nil
    36  	}
    37  
    38  	// ICMP reply includes original header and first 8 bytes of the packet
    39  	packetLen := len(packet)
    40  	if packetLen > ihl+8 {
    41  		packetLen = ihl + 8
    42  	}
    43  
    44  	outLen := ipv4.HeaderLen + 8 + packetLen
    45  	if outLen > cap(out) {
    46  		return nil
    47  	}
    48  
    49  	out = out[:outLen]
    50  
    51  	ipHdr := out[0:ipv4.HeaderLen]
    52  	ipHdr[0] = ipv4.Version<<4 | (ipv4.HeaderLen >> 2)    // version, ihl
    53  	ipHdr[1] = 0                                          // DSCP, ECN
    54  	binary.BigEndian.PutUint16(ipHdr[2:], uint16(outLen)) // Total Length
    55  
    56  	ipHdr[4] = 0  // id
    57  	ipHdr[5] = 0  //  .
    58  	ipHdr[6] = 0  // flags, fragment offset
    59  	ipHdr[7] = 0  //  .
    60  	ipHdr[8] = 64 // TTL
    61  	ipHdr[9] = 1  // protocol (icmp)
    62  	ipHdr[10] = 0 // checksum
    63  	ipHdr[11] = 0 //  .
    64  
    65  	// Swap dest / src IPs
    66  	copy(ipHdr[12:16], packet[16:20])
    67  	copy(ipHdr[16:20], packet[12:16])
    68  
    69  	// Calculate checksum
    70  	binary.BigEndian.PutUint16(ipHdr[10:], tcpipChecksum(ipHdr, 0))
    71  
    72  	// ICMP Destination Unreachable
    73  	icmpOut := out[ipv4.HeaderLen:]
    74  	icmpOut[0] = 3 // type (Destination unreachable)
    75  	icmpOut[1] = 3 // code (Port unreachable error)
    76  	icmpOut[2] = 0 // checksum
    77  	icmpOut[3] = 0 //  .
    78  	icmpOut[4] = 0 // unused
    79  	icmpOut[5] = 0 //  .
    80  	icmpOut[6] = 0 //  .
    81  	icmpOut[7] = 0 //  .
    82  
    83  	// Copy original IP header and first 8 bytes as body
    84  	copy(icmpOut[8:], packet[:packetLen])
    85  
    86  	// Calculate checksum
    87  	binary.BigEndian.PutUint16(icmpOut[2:], tcpipChecksum(icmpOut, 0))
    88  
    89  	return out
    90  }
    91  
    92  func ipv4CreateRejectTCPPacket(packet []byte, out []byte) []byte {
    93  	const tcpLen = 20
    94  
    95  	ihl := int(packet[0]&0x0f) << 2
    96  	outLen := ipv4.HeaderLen + tcpLen
    97  
    98  	if len(packet) < ihl+tcpLen {
    99  		// We need at least this many bytes for this to be a valid packet
   100  		return nil
   101  	}
   102  	if outLen > cap(out) {
   103  		return nil
   104  	}
   105  
   106  	out = out[:outLen]
   107  
   108  	ipHdr := out[0:ipv4.HeaderLen]
   109  	ipHdr[0] = ipv4.Version<<4 | (ipv4.HeaderLen >> 2)    // version, ihl
   110  	ipHdr[1] = 0                                          // DSCP, ECN
   111  	binary.BigEndian.PutUint16(ipHdr[2:], uint16(outLen)) // Total Length
   112  	ipHdr[4] = 0                                          // id
   113  	ipHdr[5] = 0                                          //  .
   114  	ipHdr[6] = 0                                          // flags, fragment offset
   115  	ipHdr[7] = 0                                          //  .
   116  	ipHdr[8] = 64                                         // TTL
   117  	ipHdr[9] = 6                                          // protocol (tcp)
   118  	ipHdr[10] = 0                                         // checksum
   119  	ipHdr[11] = 0                                         //  .
   120  
   121  	// Swap dest / src IPs
   122  	copy(ipHdr[12:16], packet[16:20])
   123  	copy(ipHdr[16:20], packet[12:16])
   124  
   125  	// Calculate checksum
   126  	binary.BigEndian.PutUint16(ipHdr[10:], tcpipChecksum(ipHdr, 0))
   127  
   128  	// TCP RST
   129  	tcpIn := packet[ihl:]
   130  	var ackSeq, seq uint32
   131  	outFlags := byte(0b00000100) // RST
   132  
   133  	// Set seq and ackSeq based on how iptables/netfilter does it in Linux:
   134  	// - https://github.com/torvalds/linux/blob/v5.19/net/ipv4/netfilter/nf_reject_ipv4.c#L193-L221
   135  	inAck := tcpIn[13]&0b00010000 != 0
   136  	if inAck {
   137  		seq = binary.BigEndian.Uint32(tcpIn[8:])
   138  	} else {
   139  		inSyn := uint32((tcpIn[13] & 0b00000010) >> 1)
   140  		inFin := uint32(tcpIn[13] & 0b00000001)
   141  		// seq from the packet + syn + fin + tcp segment length
   142  		ackSeq = binary.BigEndian.Uint32(tcpIn[4:]) + inSyn + inFin + uint32(len(tcpIn)) - uint32(tcpIn[12]>>4)<<2
   143  		outFlags |= 0b00010000 // ACK
   144  	}
   145  
   146  	tcpOut := out[ipv4.HeaderLen:]
   147  	// Swap dest / src ports
   148  	copy(tcpOut[0:2], tcpIn[2:4])
   149  	copy(tcpOut[2:4], tcpIn[0:2])
   150  	binary.BigEndian.PutUint32(tcpOut[4:], seq)
   151  	binary.BigEndian.PutUint32(tcpOut[8:], ackSeq)
   152  	tcpOut[12] = (tcpLen >> 2) << 4 // data offset,  reserved,  NS
   153  	tcpOut[13] = outFlags           // CWR, ECE, URG, ACK, PSH, RST, SYN, FIN
   154  	tcpOut[14] = 0                  // window size
   155  	tcpOut[15] = 0                  //  .
   156  	tcpOut[16] = 0                  // checksum
   157  	tcpOut[17] = 0                  //  .
   158  	tcpOut[18] = 0                  // URG Pointer
   159  	tcpOut[19] = 0                  //  .
   160  
   161  	// Calculate checksum
   162  	csum := ipv4PseudoheaderChecksum(ipHdr[12:16], ipHdr[16:20], 6, tcpLen)
   163  	binary.BigEndian.PutUint16(tcpOut[16:], tcpipChecksum(tcpOut, csum))
   164  
   165  	return out
   166  }
   167  
   168  func CreateICMPEchoResponse(packet, out []byte) []byte {
   169  	// Return early if this is not a simple ICMP Echo Request
   170  	//TODO: make constants out of these
   171  	if !(len(packet) >= 28 && len(packet) <= 9001 && packet[0] == 0x45 && packet[9] == 0x01 && packet[20] == 0x08) {
   172  		return nil
   173  	}
   174  
   175  	// We don't support fragmented packets
   176  	if packet[7] != 0 || (packet[6]&0x2F != 0) {
   177  		return nil
   178  	}
   179  
   180  	out = out[:len(packet)]
   181  
   182  	copy(out, packet)
   183  
   184  	// Swap dest / src IPs and recalculate checksum
   185  	ipv4 := out[0:20]
   186  	copy(ipv4[12:16], packet[16:20])
   187  	copy(ipv4[16:20], packet[12:16])
   188  	ipv4[10] = 0
   189  	ipv4[11] = 0
   190  	binary.BigEndian.PutUint16(ipv4[10:], tcpipChecksum(ipv4, 0))
   191  
   192  	// Change type to ICMP Echo Reply and recalculate checksum
   193  	icmp := out[20:]
   194  	icmp[0] = 0
   195  	icmp[2] = 0
   196  	icmp[3] = 0
   197  	binary.BigEndian.PutUint16(icmp[2:], tcpipChecksum(icmp, 0))
   198  
   199  	return out
   200  }
   201  
   202  // calculates the TCP/IP checksum defined in rfc1071. The passed-in
   203  // csum is any initial checksum data that's already been computed.
   204  //
   205  // based on:
   206  // - https://github.com/google/gopacket/blob/v1.1.19/layers/tcpip.go#L50-L70
   207  func tcpipChecksum(data []byte, csum uint32) uint16 {
   208  	// to handle odd lengths, we loop to length - 1, incrementing by 2, then
   209  	// handle the last byte specifically by checking against the original
   210  	// length.
   211  	length := len(data) - 1
   212  	for i := 0; i < length; i += 2 {
   213  		// For our test packet, doing this manually is about 25% faster
   214  		// (740 ns vs. 1000ns) than doing it by calling binary.BigEndian.Uint16.
   215  		csum += uint32(data[i]) << 8
   216  		csum += uint32(data[i+1])
   217  	}
   218  	if len(data)%2 == 1 {
   219  		csum += uint32(data[length]) << 8
   220  	}
   221  	for csum > 0xffff {
   222  		csum = (csum >> 16) + (csum & 0xffff)
   223  	}
   224  	return ^uint16(csum)
   225  }
   226  
   227  // based on:
   228  // - https://github.com/google/gopacket/blob/v1.1.19/layers/tcpip.go#L26-L35
   229  func ipv4PseudoheaderChecksum(src, dst []byte, proto, length uint32) (csum uint32) {
   230  	csum += (uint32(src[0]) + uint32(src[2])) << 8
   231  	csum += uint32(src[1]) + uint32(src[3])
   232  	csum += (uint32(dst[0]) + uint32(dst[2])) << 8
   233  	csum += uint32(dst[1]) + uint32(dst[3])
   234  	csum += proto
   235  	csum += length & 0xffff
   236  	csum += length >> 16
   237  	return csum
   238  }