github.com/lestrrat-go/jwx/v2@v2.0.21/jwe/compress.go (about)

     1  package jwe
     2  
     3  import (
     4  	"bytes"
     5  	"compress/flate"
     6  	"fmt"
     7  	"io"
     8  
     9  	"github.com/lestrrat-go/jwx/v2/internal/pool"
    10  )
    11  
    12  func uncompress(src []byte, maxBufferSize int64) ([]byte, error) {
    13  	var dst bytes.Buffer
    14  	r := flate.NewReader(bytes.NewReader(src))
    15  	defer r.Close()
    16  	var buf [16384]byte
    17  	var sofar int64
    18  	for {
    19  		n, readErr := r.Read(buf[:])
    20  		sofar += int64(n)
    21  		if sofar > maxBufferSize {
    22  			return nil, fmt.Errorf(`compressed payload exceeds maximum allowed size`)
    23  		}
    24  		if readErr != nil {
    25  			// if we have a read error, and it's not EOF, then we need to stop
    26  			if readErr != io.EOF {
    27  				return nil, fmt.Errorf(`failed to read inflated data: %w`, readErr)
    28  			}
    29  		}
    30  
    31  		if _, err := dst.Write(buf[:n]); err != nil {
    32  			return nil, fmt.Errorf(`failed to write inflated data: %w`, err)
    33  		}
    34  
    35  		if readErr != nil {
    36  			// if it got here, then readErr == io.EOF, we're done
    37  			//nolint:nilerr
    38  			return dst.Bytes(), nil
    39  		}
    40  	}
    41  }
    42  
    43  func compress(plaintext []byte) ([]byte, error) {
    44  	buf := pool.GetBytesBuffer()
    45  	defer pool.ReleaseBytesBuffer(buf)
    46  
    47  	w, _ := flate.NewWriter(buf, 1)
    48  	in := plaintext
    49  	for len(in) > 0 {
    50  		n, err := w.Write(in)
    51  		if err != nil {
    52  			return nil, fmt.Errorf(`failed to write to compression writer: %w`, err)
    53  		}
    54  		in = in[n:]
    55  	}
    56  	if err := w.Close(); err != nil {
    57  		return nil, fmt.Errorf(`failed to close compression writer: %w`, err)
    58  	}
    59  
    60  	ret := make([]byte, buf.Len())
    61  	copy(ret, buf.Bytes())
    62  	return ret, nil
    63  }