github.com/bodgit/sevenzip@v1.5.1/internal/bcj2/reader.go (about)

     1  package bcj2
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"errors"
     7  	"io"
     8  
     9  	"github.com/bodgit/sevenzip/internal/util"
    10  	"github.com/hashicorp/go-multierror"
    11  )
    12  
    13  const (
    14  	numMoveBits               = 5
    15  	numbitModelTotalBits      = 11
    16  	bitModelTotal        uint = 1 << numbitModelTotalBits
    17  	numTopBits                = 24
    18  	topValue             uint = 1 << numTopBits
    19  )
    20  
    21  func isJcc(b0, b1 byte) bool {
    22  	return b0 == 0x0f && (b1&0xf0) == 0x80
    23  }
    24  
    25  func isJ(b0, b1 byte) bool {
    26  	return (b1&0xfe) == 0xe8 || isJcc(b0, b1)
    27  }
    28  
    29  func index(b0, b1 byte) int {
    30  	switch b1 {
    31  	case 0xe8:
    32  		return int(b0)
    33  	case 0xe9:
    34  		return 256
    35  	default:
    36  		return 257
    37  	}
    38  }
    39  
    40  type readCloser struct {
    41  	main util.ReadCloser
    42  	call io.ReadCloser
    43  	jump io.ReadCloser
    44  
    45  	rd     util.ReadCloser
    46  	nrange uint
    47  	code   uint
    48  
    49  	sd [256 + 2]uint
    50  
    51  	previous byte
    52  	written  uint64
    53  
    54  	buf *bytes.Buffer
    55  }
    56  
    57  // NewReader returns a new BCJ2 io.ReadCloser.
    58  func NewReader(_ []byte, _ uint64, readers []io.ReadCloser) (io.ReadCloser, error) {
    59  	if len(readers) != 4 {
    60  		return nil, errors.New("bcj2: need exactly four readers")
    61  	}
    62  
    63  	rc := &readCloser{
    64  		main:   util.ByteReadCloser(readers[0]),
    65  		call:   readers[1],
    66  		jump:   readers[2],
    67  		rd:     util.ByteReadCloser(readers[3]),
    68  		nrange: 0xffffffff,
    69  		buf:    new(bytes.Buffer),
    70  	}
    71  	rc.buf.Grow(1 << 16)
    72  
    73  	b := make([]byte, 5)
    74  	if _, err := io.ReadFull(rc.rd, b); err != nil {
    75  		return nil, err
    76  	}
    77  
    78  	for _, x := range b {
    79  		rc.code = (rc.code << 8) | uint(x)
    80  	}
    81  
    82  	for i := range rc.sd {
    83  		rc.sd[i] = bitModelTotal >> 1
    84  	}
    85  
    86  	return rc, nil
    87  }
    88  
    89  func (rc *readCloser) Close() error {
    90  	var err *multierror.Error
    91  	if rc.main != nil {
    92  		err = multierror.Append(err, rc.main.Close(), rc.call.Close(), rc.jump.Close(), rc.rd.Close())
    93  	}
    94  
    95  	return err.ErrorOrNil()
    96  }
    97  
    98  func (rc *readCloser) Read(p []byte) (int, error) {
    99  	if rc.main == nil {
   100  		return 0, errors.New("bcj2: Read after Close")
   101  	}
   102  
   103  	if err := rc.read(); err != nil && !errors.Is(err, io.EOF) {
   104  		return 0, err
   105  	}
   106  
   107  	return rc.buf.Read(p)
   108  }
   109  
   110  func (rc *readCloser) update() error {
   111  	if rc.nrange < topValue {
   112  		b, err := rc.rd.ReadByte()
   113  		if err != nil {
   114  			return err
   115  		}
   116  
   117  		rc.code = (rc.code << 8) | uint(b)
   118  		rc.nrange <<= 8
   119  	}
   120  
   121  	return nil
   122  }
   123  
   124  func (rc *readCloser) decode(i int) (bool, error) {
   125  	newBound := (rc.nrange >> numbitModelTotalBits) * rc.sd[i]
   126  
   127  	if rc.code < newBound {
   128  		rc.nrange = newBound
   129  		rc.sd[i] += (bitModelTotal - rc.sd[i]) >> numMoveBits
   130  
   131  		if err := rc.update(); err != nil {
   132  			return false, err
   133  		}
   134  
   135  		return false, nil
   136  	}
   137  
   138  	rc.nrange -= newBound
   139  	rc.code -= newBound
   140  	rc.sd[i] -= rc.sd[i] >> numMoveBits
   141  
   142  	if err := rc.update(); err != nil {
   143  		return false, err
   144  	}
   145  
   146  	return true, nil
   147  }
   148  
   149  func (rc *readCloser) read() error {
   150  	var (
   151  		b   byte
   152  		err error
   153  	)
   154  
   155  	for {
   156  		if b, err = rc.main.ReadByte(); err != nil {
   157  			return err
   158  		}
   159  
   160  		rc.written++
   161  		_ = rc.buf.WriteByte(b)
   162  
   163  		if isJ(rc.previous, b) {
   164  			break
   165  		}
   166  
   167  		rc.previous = b
   168  
   169  		if rc.buf.Len() == rc.buf.Cap() {
   170  			return nil
   171  		}
   172  	}
   173  
   174  	bit, err := rc.decode(index(rc.previous, b))
   175  	if err != nil {
   176  		return err
   177  	}
   178  
   179  	if bit {
   180  		var r io.Reader
   181  		if b == 0xe8 {
   182  			r = rc.call
   183  		} else {
   184  			r = rc.jump
   185  		}
   186  
   187  		var dest uint32
   188  		if err = binary.Read(r, binary.BigEndian, &dest); err != nil {
   189  			return err
   190  		}
   191  
   192  		dest -= uint32(rc.written + 4)
   193  		_ = binary.Write(rc.buf, binary.LittleEndian, dest)
   194  
   195  		rc.previous = byte(dest >> 24)
   196  		rc.written += 4
   197  	} else {
   198  		rc.previous = b
   199  	}
   200  
   201  	return nil
   202  }