gitee.com/aurawing/surguard-go@v0.3.1-0.20240409071558-96509a61ecf3/device/ip.go (about)

     1  /* SPDX-License-Identifier: MIT
     2   *
     3   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
     4   */
     5  
     6  package device
     7  
     8  import (
     9  	"encoding/binary"
    10  	"net"
    11  )
    12  
    13  const (
    14  	IPv4offsetTotalLength = 2
    15  	IPv4offsetSrc         = 12
    16  	IPv4offsetDst         = IPv4offsetSrc + net.IPv4len
    17  	IPv4offsetProtoType   = 9
    18  )
    19  
    20  const (
    21  	IPv6offsetPayloadLength = 4
    22  	IPv6offsetSrc           = 8
    23  	IPv6offsetDst           = IPv6offsetSrc + net.IPv6len
    24  )
    25  
    26  const (
    27  	UDPv4offset            = 20
    28  	UDPv4offsetTotalLength = 24
    29  	UDPv4offsetChecksum    = 26
    30  )
    31  
    32  const (
    33  	TCPv4offset         = 20
    34  	TCPv4offsetChecksum = 36
    35  )
    36  
    37  func IPv4CheckSum(data []byte) uint16 {
    38  	var (
    39  		sum    uint32
    40  		length int = len(data)
    41  		index  int
    42  	)
    43  	//以每16位为单位进行求和,直到所有的字节全部求完或者只剩下一个8位字节(如果剩余一个8位字节说明字节数为奇数个)
    44  	for length > 1 {
    45  		sum += uint32(data[index])<<8 + uint32(data[index+1])
    46  		index += 2
    47  		length -= 2
    48  	}
    49  	//如果字节数为奇数个,要加上最后剩下的那个8位字节
    50  	if length > 0 {
    51  		sum += uint32(data[index])
    52  	}
    53  	//加上高16位进位的部分
    54  	for {
    55  		if sum>>16 == 0 {
    56  			break
    57  		}
    58  		sum = (sum & 0xffff) + (sum >> 16)
    59  	}
    60  	//别忘了返回的时候先求反
    61  	return uint16(^sum)
    62  }
    63  
    64  func checksumPartial(data []byte) uint32 {
    65  	var (
    66  		sum    uint32
    67  		length int = len(data)
    68  		index  int
    69  	)
    70  	//以每16位为单位进行求和,直到所有的字节全部求完或者只剩下一个8位字节(如果剩余一个8位字节说明字节数为奇数个)
    71  	for length > 1 {
    72  		sum += uint32(data[index])<<8 + uint32(data[index+1])
    73  		index += 2
    74  		length -= 2
    75  	}
    76  	//如果字节数为奇数个,要加上最后剩下的那个8位字节
    77  	if length > 0 {
    78  		sum += uint32(data[index]) << 8
    79  	}
    80  	return sum
    81  }
    82  
    83  func UDPv4CheckSum(packet []byte) {
    84  	packet[UDPv4offsetChecksum] = 0
    85  	packet[UDPv4offsetChecksum+1] = 0
    86  	csum := checksumPartial(packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len])
    87  	csum += checksumPartial(packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len])
    88  	csum += uint32(packet[IPv4offsetProtoType])
    89  	csum += checksumPartial(packet[UDPv4offsetTotalLength : UDPv4offsetTotalLength+2])
    90  
    91  	udplen := binary.BigEndian.Uint16(packet[IPv4offsetTotalLength:IPv4offsetTotalLength+2]) - 20
    92  	csum += checksumPartial(packet[UDPv4offset : UDPv4offset+udplen])
    93  	for {
    94  		if csum>>16 == 0 {
    95  			break
    96  		}
    97  		csum = (csum & 0xffff) + (csum >> 16)
    98  	}
    99  	checksum := uint16(^csum)
   100  	binary.BigEndian.PutUint16(packet[UDPv4offsetChecksum:UDPv4offsetChecksum+2], checksum)
   101  }
   102  
   103  func TCPv4CheckSum(packet []byte) {
   104  	packet[TCPv4offsetChecksum] = 0
   105  	packet[TCPv4offsetChecksum+1] = 0
   106  	csum := checksumPartial(packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len])
   107  	csum += checksumPartial(packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len])
   108  	csum += uint32(packet[IPv4offsetProtoType])
   109  	tcplen := binary.BigEndian.Uint16(packet[IPv4offsetTotalLength:IPv4offsetTotalLength+2]) - 20
   110  	csum += uint32(tcplen)
   111  	csum += checksumPartial(packet[TCPv4offset : TCPv4offset+tcplen])
   112  	for {
   113  		if csum>>16 == 0 {
   114  			break
   115  		}
   116  		csum = (csum & 0xffff) + (csum >> 16)
   117  	}
   118  	checksum := uint16(^csum)
   119  	binary.BigEndian.PutUint16(packet[TCPv4offsetChecksum:TCPv4offsetChecksum+2], checksum)
   120  }