github.com/amnezia-vpn/amneziawg-go@v0.2.8/tun/checksum.go (about)

     1  package tun
     2  
     3  import "encoding/binary"
     4  
     5  // TODO: Explore SIMD and/or other assembly optimizations.
     6  // TODO: Test native endian loads. See RFC 1071 section 2 part B.
     7  func checksumNoFold(b []byte, initial uint64) uint64 {
     8  	ac := initial
     9  
    10  	for len(b) >= 128 {
    11  		ac += uint64(binary.BigEndian.Uint32(b[:4]))
    12  		ac += uint64(binary.BigEndian.Uint32(b[4:8]))
    13  		ac += uint64(binary.BigEndian.Uint32(b[8:12]))
    14  		ac += uint64(binary.BigEndian.Uint32(b[12:16]))
    15  		ac += uint64(binary.BigEndian.Uint32(b[16:20]))
    16  		ac += uint64(binary.BigEndian.Uint32(b[20:24]))
    17  		ac += uint64(binary.BigEndian.Uint32(b[24:28]))
    18  		ac += uint64(binary.BigEndian.Uint32(b[28:32]))
    19  		ac += uint64(binary.BigEndian.Uint32(b[32:36]))
    20  		ac += uint64(binary.BigEndian.Uint32(b[36:40]))
    21  		ac += uint64(binary.BigEndian.Uint32(b[40:44]))
    22  		ac += uint64(binary.BigEndian.Uint32(b[44:48]))
    23  		ac += uint64(binary.BigEndian.Uint32(b[48:52]))
    24  		ac += uint64(binary.BigEndian.Uint32(b[52:56]))
    25  		ac += uint64(binary.BigEndian.Uint32(b[56:60]))
    26  		ac += uint64(binary.BigEndian.Uint32(b[60:64]))
    27  		ac += uint64(binary.BigEndian.Uint32(b[64:68]))
    28  		ac += uint64(binary.BigEndian.Uint32(b[68:72]))
    29  		ac += uint64(binary.BigEndian.Uint32(b[72:76]))
    30  		ac += uint64(binary.BigEndian.Uint32(b[76:80]))
    31  		ac += uint64(binary.BigEndian.Uint32(b[80:84]))
    32  		ac += uint64(binary.BigEndian.Uint32(b[84:88]))
    33  		ac += uint64(binary.BigEndian.Uint32(b[88:92]))
    34  		ac += uint64(binary.BigEndian.Uint32(b[92:96]))
    35  		ac += uint64(binary.BigEndian.Uint32(b[96:100]))
    36  		ac += uint64(binary.BigEndian.Uint32(b[100:104]))
    37  		ac += uint64(binary.BigEndian.Uint32(b[104:108]))
    38  		ac += uint64(binary.BigEndian.Uint32(b[108:112]))
    39  		ac += uint64(binary.BigEndian.Uint32(b[112:116]))
    40  		ac += uint64(binary.BigEndian.Uint32(b[116:120]))
    41  		ac += uint64(binary.BigEndian.Uint32(b[120:124]))
    42  		ac += uint64(binary.BigEndian.Uint32(b[124:128]))
    43  		b = b[128:]
    44  	}
    45  	if len(b) >= 64 {
    46  		ac += uint64(binary.BigEndian.Uint32(b[:4]))
    47  		ac += uint64(binary.BigEndian.Uint32(b[4:8]))
    48  		ac += uint64(binary.BigEndian.Uint32(b[8:12]))
    49  		ac += uint64(binary.BigEndian.Uint32(b[12:16]))
    50  		ac += uint64(binary.BigEndian.Uint32(b[16:20]))
    51  		ac += uint64(binary.BigEndian.Uint32(b[20:24]))
    52  		ac += uint64(binary.BigEndian.Uint32(b[24:28]))
    53  		ac += uint64(binary.BigEndian.Uint32(b[28:32]))
    54  		ac += uint64(binary.BigEndian.Uint32(b[32:36]))
    55  		ac += uint64(binary.BigEndian.Uint32(b[36:40]))
    56  		ac += uint64(binary.BigEndian.Uint32(b[40:44]))
    57  		ac += uint64(binary.BigEndian.Uint32(b[44:48]))
    58  		ac += uint64(binary.BigEndian.Uint32(b[48:52]))
    59  		ac += uint64(binary.BigEndian.Uint32(b[52:56]))
    60  		ac += uint64(binary.BigEndian.Uint32(b[56:60]))
    61  		ac += uint64(binary.BigEndian.Uint32(b[60:64]))
    62  		b = b[64:]
    63  	}
    64  	if len(b) >= 32 {
    65  		ac += uint64(binary.BigEndian.Uint32(b[:4]))
    66  		ac += uint64(binary.BigEndian.Uint32(b[4:8]))
    67  		ac += uint64(binary.BigEndian.Uint32(b[8:12]))
    68  		ac += uint64(binary.BigEndian.Uint32(b[12:16]))
    69  		ac += uint64(binary.BigEndian.Uint32(b[16:20]))
    70  		ac += uint64(binary.BigEndian.Uint32(b[20:24]))
    71  		ac += uint64(binary.BigEndian.Uint32(b[24:28]))
    72  		ac += uint64(binary.BigEndian.Uint32(b[28:32]))
    73  		b = b[32:]
    74  	}
    75  	if len(b) >= 16 {
    76  		ac += uint64(binary.BigEndian.Uint32(b[:4]))
    77  		ac += uint64(binary.BigEndian.Uint32(b[4:8]))
    78  		ac += uint64(binary.BigEndian.Uint32(b[8:12]))
    79  		ac += uint64(binary.BigEndian.Uint32(b[12:16]))
    80  		b = b[16:]
    81  	}
    82  	if len(b) >= 8 {
    83  		ac += uint64(binary.BigEndian.Uint32(b[:4]))
    84  		ac += uint64(binary.BigEndian.Uint32(b[4:8]))
    85  		b = b[8:]
    86  	}
    87  	if len(b) >= 4 {
    88  		ac += uint64(binary.BigEndian.Uint32(b))
    89  		b = b[4:]
    90  	}
    91  	if len(b) >= 2 {
    92  		ac += uint64(binary.BigEndian.Uint16(b))
    93  		b = b[2:]
    94  	}
    95  	if len(b) == 1 {
    96  		ac += uint64(b[0]) << 8
    97  	}
    98  
    99  	return ac
   100  }
   101  
   102  func checksum(b []byte, initial uint64) uint16 {
   103  	ac := checksumNoFold(b, initial)
   104  	ac = (ac >> 16) + (ac & 0xffff)
   105  	ac = (ac >> 16) + (ac & 0xffff)
   106  	ac = (ac >> 16) + (ac & 0xffff)
   107  	ac = (ac >> 16) + (ac & 0xffff)
   108  	return uint16(ac)
   109  }
   110  
   111  func pseudoHeaderChecksumNoFold(protocol uint8, srcAddr, dstAddr []byte, totalLen uint16) uint64 {
   112  	sum := checksumNoFold(srcAddr, 0)
   113  	sum = checksumNoFold(dstAddr, sum)
   114  	sum = checksumNoFold([]byte{0, protocol}, sum)
   115  	tmp := make([]byte, 2)
   116  	binary.BigEndian.PutUint16(tmp, totalLen)
   117  	return checksumNoFold(tmp, sum)
   118  }