github.com/google/syzkaller@v0.0.0-20240517125934-c0f1611a36d6/pkg/image/compression_optimized.go (about)

     1  // Copyright 2022 syzkaller project authors. All rights reserved.
     2  // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
     3  
     4  //go:build !windows && !386 && !arm
     5  
     6  package image
     7  
     8  import (
     9  	"bytes"
    10  	"compress/zlib"
    11  	"fmt"
    12  	"io"
    13  	"sync"
    14  	"syscall"
    15  	"unsafe"
    16  )
    17  
    18  // Temporary scratch data used by the decompression procedure.
    19  type decompressScratch struct {
    20  	r   bytes.Reader
    21  	zr  io.Reader
    22  	buf []byte
    23  }
    24  
    25  var decompressPool = sync.Pool{New: func() interface{} {
    26  	return &decompressScratch{
    27  		buf: make([]byte, 8<<10),
    28  	}
    29  }}
    30  
    31  func mustDecompress(compressed []byte) (data []byte, dtor func()) {
    32  	// Optimized decompression procedure that is ~2x faster than a naive version
    33  	// and consumes significantly less memory and generates less garbage.
    34  	// Images tend to contain lots of 0s, especially the larger images.
    35  	// The main idea is that we mmap a buffer and then don't write 0s into it
    36  	// (since it already contains all 0s). As the result if a page is all 0s
    37  	// then we don't page it in and don't consume memory for it.
    38  	// Executor uses the same optimization during decompression.
    39  	scratch := decompressPool.Get().(*decompressScratch)
    40  	defer decompressPool.Put(scratch)
    41  	scratch.r.Reset(compressed)
    42  	if scratch.zr == nil {
    43  		zr, err := zlib.NewReader(&scratch.r)
    44  		if err != nil {
    45  			panic(err)
    46  		}
    47  		scratch.zr = zr
    48  	} else {
    49  		if err := scratch.zr.(zlib.Resetter).Reset(&scratch.r, nil); err != nil {
    50  			panic(err)
    51  		}
    52  	}
    53  	// We don't know the size of the uncompressed image.
    54  	// We could uncompress it into ioutil.Discard first, then allocate memory and uncompress second time
    55  	// (and it's still faster than the naive uncompress into bytes.Buffer!).
    56  	// But we know maximum size of images, so just mmap the max size.
    57  	// It's fast and unused part does not consume memory.
    58  	// Note: executor/common_zlib.h also knows this const.
    59  	const maxImageSize = 132 << 20
    60  	var err error
    61  	data, err = syscall.Mmap(-1, 0, maxImageSize, syscall.PROT_READ|syscall.PROT_WRITE,
    62  		syscall.MAP_ANON|syscall.MAP_PRIVATE)
    63  	if err != nil {
    64  		panic(err)
    65  	}
    66  	dtor = func() {
    67  		if err := syscall.Munmap(data[:maxImageSize]); err != nil {
    68  			panic(err)
    69  		}
    70  	}
    71  	offset := 0
    72  	for {
    73  		n, err := scratch.zr.Read(scratch.buf)
    74  		if err != nil && err != io.EOF {
    75  			panic(err)
    76  		}
    77  		if n == 0 {
    78  			break
    79  		}
    80  		if offset+n > len(data) {
    81  			panic(fmt.Sprintf("bad image size: offset=%v n=%v data=%v", offset, n, len(data)))
    82  		}
    83  		// Copy word-at-a-time and avoid bounds checks in the loop,
    84  		// this is considerably faster than a naive byte loop.
    85  		// We already checked bounds above.
    86  		type word uint64
    87  		const wordSize = unsafe.Sizeof(word(0))
    88  		// Don't copy the last word b/c otherwise we calculate pointer outside of scratch.buf object
    89  		// on the last iteration. We don't use it, but unsafe rules prohibit even calculating
    90  		// such pointers. Alternatively we could add 8 unused bytes to scratch.buf, but it will
    91  		// play badly with memory allocator size classes (it will consume whole additional page,
    92  		// or whatever is the alignment for such large objects). We could also break from the middle
    93  		// of the loop before updating src/dst pointers, but it hurts codegen a lot (compilers like
    94  		// canonical loop forms).
    95  		words := uintptr(n-1) / wordSize
    96  		src := (*word)(unsafe.Pointer(&scratch.buf[0]))
    97  		dst := (*word)(unsafe.Pointer(&data[offset]))
    98  		for i := uintptr(0); i < words; i++ {
    99  			if *src != 0 {
   100  				*dst = *src
   101  			}
   102  			src = (*word)(unsafe.Pointer(uintptr(unsafe.Pointer(src)) + wordSize))
   103  			dst = (*word)(unsafe.Pointer(uintptr(unsafe.Pointer(dst)) + wordSize))
   104  		}
   105  		// Copy any remaining trailing bytes.
   106  		for i := words * wordSize; i < uintptr(n); i++ {
   107  			v := scratch.buf[i]
   108  			if v != 0 {
   109  				data[uintptr(offset)+i] = v
   110  			}
   111  		}
   112  		offset += n
   113  	}
   114  	data = data[:offset]
   115  	return
   116  }