golang.org/x/net@v0.25.1-0.20240516223405-c87a5b62e243/http2/hpack/huffman.go (about)

     1  // Copyright 2014 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package hpack
     6  
     7  import (
     8  	"bytes"
     9  	"errors"
    10  	"io"
    11  	"sync"
    12  )
    13  
    14  var bufPool = sync.Pool{
    15  	New: func() interface{} { return new(bytes.Buffer) },
    16  }
    17  
    18  // HuffmanDecode decodes the string in v and writes the expanded
    19  // result to w, returning the number of bytes written to w and the
    20  // Write call's return value. At most one Write call is made.
    21  func HuffmanDecode(w io.Writer, v []byte) (int, error) {
    22  	buf := bufPool.Get().(*bytes.Buffer)
    23  	buf.Reset()
    24  	defer bufPool.Put(buf)
    25  	if err := huffmanDecode(buf, 0, v); err != nil {
    26  		return 0, err
    27  	}
    28  	return w.Write(buf.Bytes())
    29  }
    30  
    31  // HuffmanDecodeToString decodes the string in v.
    32  func HuffmanDecodeToString(v []byte) (string, error) {
    33  	buf := bufPool.Get().(*bytes.Buffer)
    34  	buf.Reset()
    35  	defer bufPool.Put(buf)
    36  	if err := huffmanDecode(buf, 0, v); err != nil {
    37  		return "", err
    38  	}
    39  	return buf.String(), nil
    40  }
    41  
    42  // ErrInvalidHuffman is returned for errors found decoding
    43  // Huffman-encoded strings.
    44  var ErrInvalidHuffman = errors.New("hpack: invalid Huffman-encoded data")
    45  
// huffmanDecode decodes the Huffman-encoded bytes in v and appends the
// decoded symbols to buf.
// If maxLen is greater than 0, attempts to write more to buf than
// maxLen bytes will return ErrStringLength (declared elsewhere in the package).
//
// Decoding walks the tree returned by getRootHuffmanNode, consuming input
// 8 bits at a time. Each internal node's children array is indexed by the
// next 8 bits. A leaf can be reached before all 8 of those bits belong to
// its code. In that case only the leaf's codeLen bits are consumed, and the
// rest start the next symbol.
//
// It returns ErrInvalidHuffman in these cases:
//   - v contains a bit pattern that matches no code.
//   - v ends in an incomplete symbol or more than 7 bits of padding.
//   - The trailing padding is not all 1 bits.
func huffmanDecode(buf *bytes.Buffer, maxLen int, v []byte) error {
	rootHuffmanNode := getRootHuffmanNode()
	n := rootHuffmanNode
	// cur is the bit buffer that has not been fed into n.
	// cbits is the number of low order bits in cur that are valid.
	// sbits is the number of bits of the symbol prefix being decoded.
	cur, cbits, sbits := uint(0), uint8(0), uint8(0)
	for _, b := range v {
		// Shift the next input byte into the low end of cur. Bits above the
		// low cbits are stale. They are truncated away by the byte()
		// conversions below and masked off in the final padding check.
		cur = cur<<8 | uint(b)
		cbits += 8
		sbits += 8
		for cbits >= 8 {
			// Index the current node with the 8 most significant
			// unconsumed bits.
			idx := byte(cur >> (cbits - 8))
			n = n.children[idx]
			if n == nil {
				// No code begins with this bit pattern.
				return ErrInvalidHuffman
			}
			if n.children == nil {
				// Leaf: emit its symbol and consume only the code's bits in
				// this 8-bit window. Decoding of the next symbol restarts at
				// the root with the leftover bits.
				if maxLen != 0 && buf.Len() == maxLen {
					return ErrStringLength
				}
				buf.WriteByte(n.sym)
				cbits -= n.codeLen
				n = rootHuffmanNode
				sbits = cbits
			} else {
				// Internal node: all 8 bits belong to the current code.
				// Consume them and keep descending from n.
				cbits -= 8
			}
		}
	}
	// Fewer than 8 bits remain. Left-align them into an 8-bit index, whose
	// low bits are filled with zeros. Keep emitting symbols whose codes fit
	// entirely within the remaining real bits. Whatever is left is padding.
	for cbits > 0 {
		n = n.children[byte(cur<<(8-cbits))]
		if n == nil {
			return ErrInvalidHuffman
		}
		if n.children != nil || n.codeLen > cbits {
			// The match needs more bits than remain: the code either
			// continues past the input or relied on the zero fill. Stop,
			// and validate the remaining bits as padding below.
			break
		}
		if maxLen != 0 && buf.Len() == maxLen {
			return ErrStringLength
		}
		buf.WriteByte(n.sym)
		cbits -= n.codeLen
		n = rootHuffmanNode
		sbits = cbits
	}
	if sbits > 7 {
		// Either there was an incomplete symbol, or overlong padding.
		// Both are decoding errors per RFC 7541 section 5.2.
		return ErrInvalidHuffman
	}
	if mask := uint(1<<cbits - 1); cur&mask != mask {
		// Trailing bits must be a prefix of EOS per RFC 7541 section 5.2.
		// The check requires every remaining bit to be 1.
		return ErrInvalidHuffman
	}

	return nil
}
   107  
// incomparable is a zero-width, non-comparable type. Adding it to a struct
// makes that struct also non-comparable, and generally doesn't add
// any size (as long as it's first).
// Function types are not comparable, so an array of them is not comparable
// either, even one of length zero.
type incomparable [0]func()
   112  
// node is one entry of the Huffman decoding tree built by
// buildRootHuffmanNode. The decoder consumes input 8 bits at a time:
// - An internal node maps every possible next byte of input to a child.
// - A leaf holds a decoded symbol.
type node struct {
	// Kept first so node is non-comparable without growing in size.
	_ incomparable

	// children is non-nil for internal nodes. It is indexed by the next
	// 8 bits of input.
	children *[256]*node

	// The following are only valid if children is nil.
	// codeLen counts only the code bits that fall within the last 8-bit
	// lookup. That is the full code length minus 8 for each internal node
	// on the path from the root. The decoder consumes exactly that many bits
	// when it reaches this leaf.
	codeLen uint8 // code bits consumed at this final level
	sym     byte  // output symbol
}
   123  
   124  func newInternalNode() *node {
   125  	return &node{children: new([256]*node)}
   126  }
   127  
   128  var (
   129  	buildRootOnce       sync.Once
   130  	lazyRootHuffmanNode *node
   131  )
   132  
   133  func getRootHuffmanNode() *node {
   134  	buildRootOnce.Do(buildRootHuffmanNode)
   135  	return lazyRootHuffmanNode
   136  }
   137  
   138  func buildRootHuffmanNode() {
   139  	if len(huffmanCodes) != 256 {
   140  		panic("unexpected size")
   141  	}
   142  	lazyRootHuffmanNode = newInternalNode()
   143  	// allocate a leaf node for each of the 256 symbols
   144  	leaves := new([256]node)
   145  
   146  	for sym, code := range huffmanCodes {
   147  		codeLen := huffmanCodeLen[sym]
   148  
   149  		cur := lazyRootHuffmanNode
   150  		for codeLen > 8 {
   151  			codeLen -= 8
   152  			i := uint8(code >> codeLen)
   153  			if cur.children[i] == nil {
   154  				cur.children[i] = newInternalNode()
   155  			}
   156  			cur = cur.children[i]
   157  		}
   158  		shift := 8 - codeLen
   159  		start, end := int(uint8(code<<shift)), int(1<<shift)
   160  
   161  		leaves[sym].sym = byte(sym)
   162  		leaves[sym].codeLen = codeLen
   163  		for i := start; i < start+end; i++ {
   164  			cur.children[i] = &leaves[sym]
   165  		}
   166  	}
   167  }
   168  
// AppendHuffmanString appends s, as encoded in Huffman codes, to dst
// and returns the extended buffer.
//
// Codes are collected in a 64-bit bit buffer and flushed to dst 32 bits at
// a time, most significant byte first. If the output does not end on a byte
// boundary, the final byte is padded with 1 bits. These are the high-order
// bits of the EOS code, as RFC 7541 section 5.2 requires.
func AppendHuffmanString(dst []byte, s string) []byte {
	// This relies on the maximum Huffman code length being 30 (see the
	// huffmanCodeLen array in tables.go). A uint64 buffer holding fewer
	// than 32 valid bits can therefore always accommodate another code.
	var (
		x uint64 // buffer
		n uint   // number of valid bits present in x
	)
	for i := 0; i < len(s); i++ {
		c := s[i]
		// Shift the new code in at the low end of x. Bits above the low n
		// are already-emitted leftovers and are ignored.
		n += uint(huffmanCodeLen[c])
		x <<= huffmanCodeLen[c] % 64
		x |= uint64(huffmanCodes[c])
		if n >= 32 {
			// Emit the 32 oldest pending bits. The newest n bits (after the
			// reduction) stay pending in the low end of x.
			n %= 32             // Normally would be -= 32 but %= 32 informs compiler 0 <= n <= 31 for upcoming shift
			y := uint32(x >> n) // Compiler doesn't combine memory writes if y isn't uint32
			dst = append(dst, byte(y>>24), byte(y>>16), byte(y>>8), byte(y))
		}
	}
	// Add padding bits if necessary
	if over := n % 8; over > 0 {
		const (
			eosCode    = 0x3fffffff
			eosNBits   = 30
			eosPadByte = eosCode >> (eosNBits - 8) // top 8 bits of EOS: 0xff
		)
		// Shift in `pad` one-bits. eosPadByte>>over has exactly pad low bits set.
		pad := 8 - over
		x = (x << pad) | (eosPadByte >> over)
		n += pad // 8 now divides into n exactly
	}
	// n in (0, 8, 16, 24, 32)
	// Flush the remaining n/8 whole bytes, most significant first.
	switch n / 8 {
	case 0:
		return dst
	case 1:
		return append(dst, byte(x))
	case 2:
		y := uint16(x)
		return append(dst, byte(y>>8), byte(y))
	case 3:
		y := uint16(x >> 8)
		return append(dst, byte(y>>8), byte(y), byte(x))
	}
	//	case 4:
	y := uint32(x)
	return append(dst, byte(y>>24), byte(y>>16), byte(y>>8), byte(y))
}
   217  
   218  // HuffmanEncodeLength returns the number of bytes required to encode
   219  // s in Huffman codes. The result is round up to byte boundary.
   220  func HuffmanEncodeLength(s string) uint64 {
   221  	n := uint64(0)
   222  	for i := 0; i < len(s); i++ {
   223  		n += uint64(huffmanCodeLen[s[i]])
   224  	}
   225  	return (n + 7) / 8
   226  }