github.com/nicocha30/gvisor-ligolo@v0.0.0-20230726075806-989fa2c0a413/pkg/tcpip/checksum/checksum.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package checksum computes the Internet checksum (RFC 1071) used by
    16  // network protocol headers.
    17  package checksum
    18  
    19  import (
    20  	"encoding/binary"
    21  )
    22  
    23  // Size is the size of a checksum.
    24  //
    25  // The checksum is held in a uint16 which is 2 bytes.
    26  const Size = 2
    27  
    28  // Put puts the checksum in the provided byte slice.
    29  func Put(b []byte, xsum uint16) {
    30  	binary.BigEndian.PutUint16(b, xsum)
    31  }
    32  
    33  func unrolledCalculateChecksum(buf []byte, odd bool, initial uint16) (uint16, bool) {
    34  	v := uint32(initial)
    35  
    36  	if odd {
    37  		v += uint32(buf[0])
    38  		buf = buf[1:]
    39  	}
    40  
    41  	l := len(buf)
    42  	odd = l&1 != 0
    43  	if odd {
    44  		l--
    45  		v += uint32(buf[l]) << 8
    46  	}
    47  	for (l - 64) >= 0 {
    48  		i := 0
    49  		v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
    50  		v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
    51  		v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
    52  		v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
    53  		v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
    54  		v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
    55  		v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
    56  		v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
    57  		i += 16
    58  		v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
    59  		v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
    60  		v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
    61  		v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
    62  		v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
    63  		v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
    64  		v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
    65  		v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
    66  		i += 16
    67  		v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
    68  		v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
    69  		v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
    70  		v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
    71  		v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
    72  		v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
    73  		v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
    74  		v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
    75  		i += 16
    76  		v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
    77  		v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
    78  		v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
    79  		v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
    80  		v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
    81  		v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
    82  		v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
    83  		v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
    84  		buf = buf[64:]
    85  		l = l - 64
    86  	}
    87  	if (l - 32) >= 0 {
    88  		i := 0
    89  		v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
    90  		v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
    91  		v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
    92  		v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
    93  		v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
    94  		v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
    95  		v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
    96  		v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
    97  		i += 16
    98  		v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
    99  		v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
   100  		v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
   101  		v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
   102  		v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
   103  		v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
   104  		v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
   105  		v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
   106  		buf = buf[32:]
   107  		l = l - 32
   108  	}
   109  	if (l - 16) >= 0 {
   110  		i := 0
   111  		v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
   112  		v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
   113  		v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
   114  		v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
   115  		v += (uint32(buf[i+8]) << 8) + uint32(buf[i+9])
   116  		v += (uint32(buf[i+10]) << 8) + uint32(buf[i+11])
   117  		v += (uint32(buf[i+12]) << 8) + uint32(buf[i+13])
   118  		v += (uint32(buf[i+14]) << 8) + uint32(buf[i+15])
   119  		buf = buf[16:]
   120  		l = l - 16
   121  	}
   122  	if (l - 8) >= 0 {
   123  		i := 0
   124  		v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
   125  		v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
   126  		v += (uint32(buf[i+4]) << 8) + uint32(buf[i+5])
   127  		v += (uint32(buf[i+6]) << 8) + uint32(buf[i+7])
   128  		buf = buf[8:]
   129  		l = l - 8
   130  	}
   131  	if (l - 4) >= 0 {
   132  		i := 0
   133  		v += (uint32(buf[i]) << 8) + uint32(buf[i+1])
   134  		v += (uint32(buf[i+2]) << 8) + uint32(buf[i+3])
   135  		buf = buf[4:]
   136  		l = l - 4
   137  	}
   138  
   139  	// At this point since l was even before we started unrolling
   140  	// there can be only two bytes left to add.
   141  	if l != 0 {
   142  		v += (uint32(buf[0]) << 8) + uint32(buf[1])
   143  	}
   144  
   145  	return Combine(uint16(v), uint16(v>>16)), odd
   146  }
   147  
   148  // Checksum calculates the checksum (as defined in RFC 1071) of the bytes in the
   149  // given byte array. This function uses an optimized version of the checksum
   150  // algorithm.
   151  //
   152  // The initial checksum must have been computed on an even number of bytes.
   153  func Checksum(buf []byte, initial uint16) uint16 {
   154  	s, _ := calculateChecksum(buf, false, initial)
   155  	return s
   156  }
   157  
   158  // Checksumer calculates checksum defined in RFC 1071.
   159  type Checksumer struct {
   160  	sum uint16
   161  	odd bool
   162  }
   163  
   164  // Add adds b to checksum.
   165  func (c *Checksumer) Add(b []byte) {
   166  	if len(b) > 0 {
   167  		c.sum, c.odd = calculateChecksum(b, c.odd, c.sum)
   168  	}
   169  }
   170  
   171  // Checksum returns the latest checksum value.
   172  func (c *Checksumer) Checksum() uint16 {
   173  	return c.sum
   174  }
   175  
   176  // Combine combines the two uint16 to form their checksum. This is done
   177  // by adding them and the carry.
   178  //
   179  // Note that checksum a must have been computed on an even number of bytes.
   180  func Combine(a, b uint16) uint16 {
   181  	v := uint32(a) + uint32(b)
   182  	return uint16(v + v>>16)
   183  }