github.com/bir3/gocompiler@v0.9.2202/extra/compress/zstd/seqdec_amd64.go (about)

     1  //go:build amd64 && !appengine && !noasm && gc
     2  // +build amd64,!appengine,!noasm,gc
     3  
     4  package zstd
     5  
     6  import (
     7  	"fmt"
     8  	"io"
     9  
    10  	"github.com/bir3/gocompiler/extra/compress/internal/cpuinfo"
    11  )
    12  
// decodeSyncAsmContext carries the state shared between decodeSyncSimple and
// the sequenceDecs_decodeSync_* assembly routines, which receive a pointer to it.
//
// NOTE(review): the assembly is generated from _generate/gen.go and presumably
// addresses these fields by fixed offsets, so field order and types must not be
// changed without regenerating it — confirm before editing.
type decodeSyncAsmContext struct {
	llTable     []decSymbol // literal-length FSE decoding table (s.litLengths.fse.dt)
	mlTable     []decSymbol // match-length FSE decoding table (s.matchLengths.fse.dt)
	ofTable     []decSymbol // offset FSE decoding table (s.offsets.fse.dt)
	llState     uint64      // literal-length decoder state, initialised from s.litLengths.state.state
	mlState     uint64      // match-length decoder state, initialised from s.matchLengths.state.state
	ofState     uint64      // offset decoder state, initialised from s.offsets.state.state
	iteration   int         // initialised to s.nSeqs - 1; presumably counts down the remaining sequences
	litRemain   int         // literal bytes not yet consumed; starts at len(s.literals)
	out         []byte      // output buffer (s.out)
	outPosition int         // write position in out; starts at len(s.out), read back as the new length
	literals    []byte      // literal bytes of the block (s.literals)
	litPosition int         // number of literal bytes consumed from literals
	history     []byte      // history (hist) that match offsets may reach back into
	windowSize  int         // window size; a match offset may not exceed it
	ll          int         // set on error (not for all errors, please refer to _generate/gen.go)
	ml          int         // set on error (not for all errors, please refer to _generate/gen.go)
	mo          int         // set on error (not for all errors, please refer to _generate/gen.go)
}
    32  
// sequenceDecs_decodeSync_amd64 implements the main loop of sequenceDecs.decodeSync in x86 asm.
//
// Please refer to seqdec_generic.go for the reference implementation.
//
// It decodes the sequences described by ctx, reading bits from br, and writes
// their output into ctx.out starting at ctx.outPosition. The int result is an
// error code that the caller compares against noError and the error* constants.
//
//go:noescape
func sequenceDecs_decodeSync_amd64(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int

// sequenceDecs_decodeSync_bmi2 implements the main loop of sequenceDecs.decodeSync in x86 asm with BMI2 extensions.
// decodeSyncSimple selects it when cpuinfo.HasBMI2() reports BMI2 support.
//
//go:noescape
func sequenceDecs_decodeSync_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int

// sequenceDecs_decodeSync_safe_amd64 behaves like sequenceDecs_decodeSync_amd64,
// but never writes beyond the end of the output buffer.
//
//go:noescape
func sequenceDecs_decodeSync_safe_amd64(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int

// sequenceDecs_decodeSync_safe_bmi2 behaves like sequenceDecs_decodeSync_bmi2,
// but never writes beyond the end of the output buffer.
//
//go:noescape
func sequenceDecs_decodeSync_safe_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeSyncAsmContext) int
    54  
    55  // decode sequences from the stream with the provided history but without a dictionary.
    56  func (s *sequenceDecs) decodeSyncSimple(hist []byte) (bool, error) {
    57  	if len(s.dict) > 0 {
    58  		return false, nil
    59  	}
    60  	if s.maxSyncLen == 0 && cap(s.out)-len(s.out) < maxCompressedBlockSize {
    61  		return false, nil
    62  	}
    63  
    64  	// FIXME: Using unsafe memory copies leads to rare, random crashes
    65  	// with fuzz testing. It is therefore disabled for now.
    66  	const useSafe = true
    67  	/*
    68  		useSafe := false
    69  		if s.maxSyncLen == 0 && cap(s.out)-len(s.out) < maxCompressedBlockSizeAlloc {
    70  			useSafe = true
    71  		}
    72  		if s.maxSyncLen > 0 && cap(s.out)-len(s.out)-compressedBlockOverAlloc < int(s.maxSyncLen) {
    73  			useSafe = true
    74  		}
    75  		if cap(s.literals) < len(s.literals)+compressedBlockOverAlloc {
    76  			useSafe = true
    77  		}
    78  	*/
    79  
    80  	br := s.br
    81  
    82  	maxBlockSize := maxCompressedBlockSize
    83  	if s.windowSize < maxBlockSize {
    84  		maxBlockSize = s.windowSize
    85  	}
    86  
    87  	ctx := decodeSyncAsmContext{
    88  		llTable:     s.litLengths.fse.dt[:maxTablesize],
    89  		mlTable:     s.matchLengths.fse.dt[:maxTablesize],
    90  		ofTable:     s.offsets.fse.dt[:maxTablesize],
    91  		llState:     uint64(s.litLengths.state.state),
    92  		mlState:     uint64(s.matchLengths.state.state),
    93  		ofState:     uint64(s.offsets.state.state),
    94  		iteration:   s.nSeqs - 1,
    95  		litRemain:   len(s.literals),
    96  		out:         s.out,
    97  		outPosition: len(s.out),
    98  		literals:    s.literals,
    99  		windowSize:  s.windowSize,
   100  		history:     hist,
   101  	}
   102  
   103  	s.seqSize = 0
   104  	startSize := len(s.out)
   105  
   106  	var errCode int
   107  	if cpuinfo.HasBMI2() {
   108  		if useSafe {
   109  			errCode = sequenceDecs_decodeSync_safe_bmi2(s, br, &ctx)
   110  		} else {
   111  			errCode = sequenceDecs_decodeSync_bmi2(s, br, &ctx)
   112  		}
   113  	} else {
   114  		if useSafe {
   115  			errCode = sequenceDecs_decodeSync_safe_amd64(s, br, &ctx)
   116  		} else {
   117  			errCode = sequenceDecs_decodeSync_amd64(s, br, &ctx)
   118  		}
   119  	}
   120  	switch errCode {
   121  	case noError:
   122  		break
   123  
   124  	case errorMatchLenOfsMismatch:
   125  		return true, fmt.Errorf("zero matchoff and matchlen (%d) > 0", ctx.ml)
   126  
   127  	case errorMatchLenTooBig:
   128  		return true, fmt.Errorf("match len (%d) bigger than max allowed length", ctx.ml)
   129  
   130  	case errorMatchOffTooBig:
   131  		return true, fmt.Errorf("match offset (%d) bigger than current history (%d)",
   132  			ctx.mo, ctx.outPosition+len(hist)-startSize)
   133  
   134  	case errorNotEnoughLiterals:
   135  		return true, fmt.Errorf("unexpected literal count, want %d bytes, but only %d is available",
   136  			ctx.ll, ctx.litRemain+ctx.ll)
   137  
   138  	case errorOverread:
   139  		return true, io.ErrUnexpectedEOF
   140  
   141  	case errorNotEnoughSpace:
   142  		size := ctx.outPosition + ctx.ll + ctx.ml
   143  		if debugDecoder {
   144  			println("msl:", s.maxSyncLen, "cap", cap(s.out), "bef:", startSize, "sz:", size-startSize, "mbs:", maxBlockSize, "outsz:", cap(s.out)-startSize)
   145  		}
   146  		return true, fmt.Errorf("output bigger than max block size (%d)", maxBlockSize)
   147  
   148  	default:
   149  		return true, fmt.Errorf("sequenceDecs_decode returned erronous code %d", errCode)
   150  	}
   151  
   152  	s.seqSize += ctx.litRemain
   153  	if s.seqSize > maxBlockSize {
   154  		return true, fmt.Errorf("output bigger than max block size (%d)", maxBlockSize)
   155  	}
   156  	err := br.close()
   157  	if err != nil {
   158  		printf("Closing sequences: %v, %+v\n", err, *br)
   159  		return true, err
   160  	}
   161  
   162  	s.literals = s.literals[ctx.litPosition:]
   163  	t := ctx.outPosition
   164  	s.out = s.out[:t]
   165  
   166  	// Add final literals
   167  	s.out = append(s.out, s.literals...)
   168  	if debugDecoder {
   169  		t += len(s.literals)
   170  		if t != len(s.out) {
   171  			panic(fmt.Errorf("length mismatch, want %d, got %d", len(s.out), t))
   172  		}
   173  	}
   174  
   175  	return true, nil
   176  }
   177  
   178  // --------------------------------------------------------------------------------
   179  
// decodeAsmContext carries the state shared between decode and the
// sequenceDecs_decode_* assembly routines, which receive a pointer to it.
//
// NOTE(review): the assembly presumably addresses these fields by fixed offsets
// (generated from _generate/gen.go), so field order and types must not be
// changed without regenerating it — confirm before editing.
type decodeAsmContext struct {
	llTable   []decSymbol // literal-length FSE decoding table (s.litLengths.fse.dt)
	mlTable   []decSymbol // match-length FSE decoding table (s.matchLengths.fse.dt)
	ofTable   []decSymbol // offset FSE decoding table (s.offsets.fse.dt)
	llState   uint64      // literal-length decoder state
	mlState   uint64      // match-length decoder state
	ofState   uint64      // offset decoder state
	iteration int         // initialised to len(seqs) - 1; decode derives the failing sequence index from it
	seqs      []seqVals   // destination for the decoded sequences
	litRemain int         // literal bytes not yet claimed; negative if the sequences request more than available
}
   191  
   192  const noError = 0
   193  
   194  // error reported when mo == 0 && ml > 0
   195  const errorMatchLenOfsMismatch = 1
   196  
   197  // error reported when ml > maxMatchLen
   198  const errorMatchLenTooBig = 2
   199  
   200  // error reported when mo > available history or mo > s.windowSize
   201  const errorMatchOffTooBig = 3
   202  
   203  // error reported when the sum of literal lengths exeeceds the literal buffer size
   204  const errorNotEnoughLiterals = 4
   205  
   206  // error reported when capacity of `out` is too small
   207  const errorNotEnoughSpace = 5
   208  
   209  // error reported when bits are overread.
   210  const errorOverread = 6
   211  
// sequenceDecs_decode_amd64 implements the main loop of sequenceDecs.decode in x86 asm.
//
// Please refer to seqdec_generic.go for the reference implementation.
//
// It decodes the sequences described by ctx into ctx.seqs, reading bits from br.
// The int result is an error code that decode compares against noError and the
// error* constants.
//
//go:noescape
func sequenceDecs_decode_amd64(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int

// sequenceDecs_decode_56_amd64 is the variant of sequenceDecs_decode_amd64 that
// decode selects when s.maxBits plus the three FSE table logs is at most 56.
//
// Please refer to seqdec_generic.go for the reference implementation.
//
//go:noescape
func sequenceDecs_decode_56_amd64(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int

// sequenceDecs_decode_bmi2 implements the main loop of sequenceDecs.decode in x86 asm with BMI2 extensions.
// decode selects it when cpuinfo.HasBMI2() reports BMI2 support.
//
//go:noescape
func sequenceDecs_decode_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int

// sequenceDecs_decode_56_bmi2 is the BMI2 counterpart of sequenceDecs_decode_56_amd64,
// used under the same at-most-56-bits condition.
//
//go:noescape
func sequenceDecs_decode_56_bmi2(s *sequenceDecs, br *bitReader, ctx *decodeAsmContext) int
   235  
   236  // decode sequences from the stream without the provided history.
   237  func (s *sequenceDecs) decode(seqs []seqVals) error {
   238  	br := s.br
   239  
   240  	maxBlockSize := maxCompressedBlockSize
   241  	if s.windowSize < maxBlockSize {
   242  		maxBlockSize = s.windowSize
   243  	}
   244  
   245  	ctx := decodeAsmContext{
   246  		llTable:   s.litLengths.fse.dt[:maxTablesize],
   247  		mlTable:   s.matchLengths.fse.dt[:maxTablesize],
   248  		ofTable:   s.offsets.fse.dt[:maxTablesize],
   249  		llState:   uint64(s.litLengths.state.state),
   250  		mlState:   uint64(s.matchLengths.state.state),
   251  		ofState:   uint64(s.offsets.state.state),
   252  		seqs:      seqs,
   253  		iteration: len(seqs) - 1,
   254  		litRemain: len(s.literals),
   255  	}
   256  
   257  	if debugDecoder {
   258  		println("decode: decoding", len(seqs), "sequences", br.remain(), "bits remain on stream")
   259  	}
   260  
   261  	s.seqSize = 0
   262  	lte56bits := s.maxBits+s.offsets.fse.actualTableLog+s.matchLengths.fse.actualTableLog+s.litLengths.fse.actualTableLog <= 56
   263  	var errCode int
   264  	if cpuinfo.HasBMI2() {
   265  		if lte56bits {
   266  			errCode = sequenceDecs_decode_56_bmi2(s, br, &ctx)
   267  		} else {
   268  			errCode = sequenceDecs_decode_bmi2(s, br, &ctx)
   269  		}
   270  	} else {
   271  		if lte56bits {
   272  			errCode = sequenceDecs_decode_56_amd64(s, br, &ctx)
   273  		} else {
   274  			errCode = sequenceDecs_decode_amd64(s, br, &ctx)
   275  		}
   276  	}
   277  	if errCode != 0 {
   278  		i := len(seqs) - ctx.iteration - 1
   279  		switch errCode {
   280  		case errorMatchLenOfsMismatch:
   281  			ml := ctx.seqs[i].ml
   282  			return fmt.Errorf("zero matchoff and matchlen (%d) > 0", ml)
   283  
   284  		case errorMatchLenTooBig:
   285  			ml := ctx.seqs[i].ml
   286  			return fmt.Errorf("match len (%d) bigger than max allowed length", ml)
   287  
   288  		case errorNotEnoughLiterals:
   289  			ll := ctx.seqs[i].ll
   290  			return fmt.Errorf("unexpected literal count, want %d bytes, but only %d is available", ll, ctx.litRemain+ll)
   291  		case errorOverread:
   292  			return io.ErrUnexpectedEOF
   293  		}
   294  
   295  		return fmt.Errorf("sequenceDecs_decode_amd64 returned erronous code %d", errCode)
   296  	}
   297  
   298  	if ctx.litRemain < 0 {
   299  		return fmt.Errorf("literal count is too big: total available %d, total requested %d",
   300  			len(s.literals), len(s.literals)-ctx.litRemain)
   301  	}
   302  
   303  	s.seqSize += ctx.litRemain
   304  	if s.seqSize > maxBlockSize {
   305  		return fmt.Errorf("output bigger than max block size (%d)", maxBlockSize)
   306  	}
   307  	if debugDecoder {
   308  		println("decode: ", br.remain(), "bits remain on stream. code:", errCode)
   309  	}
   310  	err := br.close()
   311  	if err != nil {
   312  		printf("Closing sequences: %v, %+v\n", err, *br)
   313  	}
   314  	return err
   315  }
   316  
   317  // --------------------------------------------------------------------------------
   318  
// executeAsmContext carries the state shared between executeSimple and the
// sequenceDecs_executeSimple_* assembly routines, which receive a pointer to it.
//
// NOTE(review): the assembly presumably addresses these fields by fixed offsets,
// so field order and types must not be changed without regenerating it — confirm.
type executeAsmContext struct {
	seqs        []seqVals // sequences to execute
	seqIndex    int       // index of the current sequence; identifies the offending sequence on failure
	out         []byte    // output buffer, already sliced to its final length
	history     []byte    // history that match offsets may reach into beyond out
	literals    []byte    // literal bytes of the block (s.literals)
	outPosition int       // write position in out; starts at the previous len(s.out)
	litPosition int       // number of literal bytes consumed from literals
	windowSize  int       // window size (s.windowSize)
}
   329  
// sequenceDecs_executeSimple_amd64 implements the main loop of sequenceDecs.executeSimple in x86 asm.
//
// Returns false if a match offset is too big; ctx.seqIndex then identifies the
// offending sequence.
//
// Please refer to seqdec_generic.go for the reference implementation.
//
//go:noescape
func sequenceDecs_executeSimple_amd64(ctx *executeAsmContext) bool

// sequenceDecs_executeSimple_safe_amd64 is the same as sequenceDecs_executeSimple_amd64,
// but with safe memcopies. executeSimple uses it when s.literals has fewer than
// compressedBlockOverAlloc bytes of spare capacity.
//
//go:noescape
func sequenceDecs_executeSimple_safe_amd64(ctx *executeAsmContext) bool
   343  
   344  // executeSimple handles cases when dictionary is not used.
   345  func (s *sequenceDecs) executeSimple(seqs []seqVals, hist []byte) error {
   346  	// Ensure we have enough output size...
   347  	if len(s.out)+s.seqSize+compressedBlockOverAlloc > cap(s.out) {
   348  		addBytes := s.seqSize + len(s.out) + compressedBlockOverAlloc
   349  		s.out = append(s.out, make([]byte, addBytes)...)
   350  		s.out = s.out[:len(s.out)-addBytes]
   351  	}
   352  
   353  	if debugDecoder {
   354  		printf("Execute %d seqs with literals: %d into %d bytes\n", len(seqs), len(s.literals), s.seqSize)
   355  	}
   356  
   357  	var t = len(s.out)
   358  	out := s.out[:t+s.seqSize]
   359  
   360  	ctx := executeAsmContext{
   361  		seqs:        seqs,
   362  		seqIndex:    0,
   363  		out:         out,
   364  		history:     hist,
   365  		outPosition: t,
   366  		litPosition: 0,
   367  		literals:    s.literals,
   368  		windowSize:  s.windowSize,
   369  	}
   370  	var ok bool
   371  	if cap(s.literals) < len(s.literals)+compressedBlockOverAlloc {
   372  		ok = sequenceDecs_executeSimple_safe_amd64(&ctx)
   373  	} else {
   374  		ok = sequenceDecs_executeSimple_amd64(&ctx)
   375  	}
   376  	if !ok {
   377  		return fmt.Errorf("match offset (%d) bigger than current history (%d)",
   378  			seqs[ctx.seqIndex].mo, ctx.outPosition+len(hist))
   379  	}
   380  	s.literals = s.literals[ctx.litPosition:]
   381  	t = ctx.outPosition
   382  
   383  	// Add final literals
   384  	copy(out[t:], s.literals)
   385  	if debugDecoder {
   386  		t += len(s.literals)
   387  		if t != len(out) {
   388  			panic(fmt.Errorf("length mismatch, want %d, got %d, ss: %d", len(out), t, s.seqSize))
   389  		}
   390  	}
   391  	s.out = out
   392  
   393  	return nil
   394  }