github.com/apernet/sing-tun@v0.2.6-0.20240323130332-b9f6511036ad/internal/clashtcpip/tcp.go (about)

     1  package clashtcpip
     2  
     3  import (
     4  	"encoding/binary"
     5  	"net"
     6  )
     7  
     8  const (
     9  	TCPFin uint16 = 1 << 0
    10  	TCPSyn uint16 = 1 << 1
    11  	TCPRst uint16 = 1 << 2
    12  	TCPPuh uint16 = 1 << 3
    13  	TCPAck uint16 = 1 << 4
    14  	TCPUrg uint16 = 1 << 5
    15  	TCPEce uint16 = 1 << 6
    16  	TCPEwr uint16 = 1 << 7
    17  	TCPNs  uint16 = 1 << 8
    18  )
    19  
    20  const TCPHeaderSize = 20
    21  
    22  type TCPPacket []byte
    23  
    24  func (p TCPPacket) SourcePort() uint16 {
    25  	return binary.BigEndian.Uint16(p)
    26  }
    27  
    28  func (p TCPPacket) SetSourcePort(port uint16) {
    29  	binary.BigEndian.PutUint16(p, port)
    30  }
    31  
    32  func (p TCPPacket) DestinationPort() uint16 {
    33  	return binary.BigEndian.Uint16(p[2:])
    34  }
    35  
    36  func (p TCPPacket) SetDestinationPort(port uint16) {
    37  	binary.BigEndian.PutUint16(p[2:], port)
    38  }
    39  
    40  func (p TCPPacket) Flags() uint16 {
    41  	return uint16(p[13] | (p[12] & 0x1))
    42  }
    43  
    44  func (p TCPPacket) Checksum() uint16 {
    45  	return binary.BigEndian.Uint16(p[16:])
    46  }
    47  
    48  func (p TCPPacket) SetChecksum(sum [2]byte) {
    49  	p[16] = sum[0]
    50  	p[17] = sum[1]
    51  }
    52  
    53  func (p TCPPacket) OffloadChecksum() {
    54  	p.SetChecksum(zeroChecksum)
    55  }
    56  
    57  func (p TCPPacket) ResetChecksum(psum uint32) {
    58  	p.SetChecksum(zeroChecksum)
    59  	p.SetChecksum(Checksum(psum, p))
    60  }
    61  
    62  func (p TCPPacket) Valid() bool {
    63  	return len(p) >= TCPHeaderSize
    64  }
    65  
    66  func (p TCPPacket) Verify(sourceAddress net.IP, targetAddress net.IP) error {
    67  	var checksum [2]byte
    68  	checksum[0] = p[16]
    69  	checksum[1] = p[17]
    70  
    71  	// reset checksum
    72  	p[16] = 0
    73  	p[17] = 0
    74  
    75  	// restore checksum
    76  	defer func() {
    77  		p[16] = checksum[0]
    78  		p[17] = checksum[1]
    79  	}()
    80  
    81  	// check checksum
    82  	s := uint32(0)
    83  	s += Sum(sourceAddress)
    84  	s += Sum(targetAddress)
    85  	s += uint32(TCP)
    86  	s += uint32(len(p))
    87  
    88  	check := Checksum(s, p)
    89  	if checksum[0] != check[0] || checksum[1] != check[1] {
    90  		return ErrInvalidChecksum
    91  	}
    92  
    93  	return nil
    94  }