github.com/apache/arrow/go/v14@v14.0.2/parquet/internal/encoding/encoder.go

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package encoding

import (
	"fmt"
	"math/bits"
	"reflect"

	"github.com/apache/arrow/go/v14/arrow"
	"github.com/apache/arrow/go/v14/arrow/bitutil"
	"github.com/apache/arrow/go/v14/arrow/memory"
	"github.com/apache/arrow/go/v14/internal/bitutils"
	"github.com/apache/arrow/go/v14/parquet"
	format "github.com/apache/arrow/go/v14/parquet/internal/gen-go/parquet"
	"github.com/apache/arrow/go/v14/parquet/internal/utils"
	"github.com/apache/arrow/go/v14/parquet/schema"
)

//go:generate go run ../../../arrow/_tools/tmpl/main.go -i -data=physical_types.tmpldata plain_encoder_types.gen.go.tmpl typed_encoder.gen.go.tmpl

// EncoderTraits is an interface implemented for each physical type, making it more
// convenient to construct encoders for specific types.
type EncoderTraits interface {
	Encoder(format.Encoding, bool, *schema.Column, memory.Allocator) TypedEncoder
}

// NewEncoder will return the appropriately typed encoder for the requested physical type
// and encoding.
//
// If mem is nil, memory.DefaultAllocator will be used.
func NewEncoder(t parquet.Type, e parquet.Encoding, useDict bool, descr *schema.Column, mem memory.Allocator) TypedEncoder {
	traits := getEncodingTraits(t)
	if traits == nil {
		return nil
	}

	if mem == nil {
		mem = memory.DefaultAllocator
	}
	return traits.Encoder(format.Encoding(e), useDict, descr, mem)
}
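
// The sketch below is an illustrative usage example, not part of the upstream file. It
// shows the typical lifecycle around NewEncoder: request an encoder for a physical type
// and encoding, buffer values through the generated typed interface, then flush and take
// ownership of the encoded buffer. It assumes the generated Int32Encoder interface (with
// a Put([]int32) method), the parquet.Encodings.Plain constant, and that the returned
// TypedEncoder and Buffer expose Release; the column descriptor is assumed to come from
// the caller's schema.
func examplePlainInt32Encoding(descr *schema.Column) ([]byte, error) {
	// A nil allocator falls back to memory.DefaultAllocator inside NewEncoder.
	enc := NewEncoder(parquet.Types.Int32, parquet.Encodings.Plain, false, descr, nil)
	if enc == nil {
		return nil, fmt.Errorf("no encoder available for the requested type and encoding")
	}
	defer enc.Release()

	// Buffer some values, then flush them into an encoded Buffer owned by the caller.
	enc.(Int32Encoder).Put([]int32{1, 2, 3, 4})
	buf, err := enc.FlushValues()
	if err != nil {
		return nil, err
	}
	defer buf.Release()

	// Copy the encoded bytes out before the pooled buffer is released.
	out := make([]byte, len(buf.Bytes()))
	copy(out, buf.Bytes())
	return out, nil
}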

type encoder struct {
	descr    *schema.Column
	encoding format.Encoding
	typeLen  int
	mem      memory.Allocator

	sink *PooledBufferWriter
}

// newEncoderBase constructs a new base encoder for embedding in the typed encoders,
// encapsulating the common functionality.
func newEncoderBase(e format.Encoding, descr *schema.Column, mem memory.Allocator) encoder {
	typelen := -1
	if descr != nil && descr.PhysicalType() == parquet.Types.FixedLenByteArray {
		typelen = int(descr.TypeLength())
	}
	return encoder{
		descr:    descr,
		encoding: e,
		mem:      mem,
		typeLen:  typelen,
		sink:     NewPooledBufferWriter(1024),
	}
}

func (e *encoder) Release() {
	poolbuf := e.sink.buf
	memory.Set(poolbuf.Buf(), 0)
	poolbuf.ResizeNoShrink(0)
	bufferPool.Put(poolbuf)
	e.sink = nil
}

// ReserveForWrite allocates n bytes so that the next n bytes written do not require new allocations.
func (e *encoder) ReserveForWrite(n int)           { e.sink.Reserve(n) }
func (e *encoder) EstimatedDataEncodedSize() int64 { return int64(e.sink.Len()) }
func (e *encoder) Encoding() parquet.Encoding      { return parquet.Encoding(e.encoding) }
func (e *encoder) Allocator() memory.Allocator     { return e.mem }
func (e *encoder) append(data []byte)              { e.sink.Write(data) }

// FlushValues flushes any unwritten data to the buffer and returns the finished encoded buffer of data.
// This also clears the encoder; ownership of the data belongs to whoever called FlushValues, and Release
// should be called on the resulting Buffer when done.
func (e *encoder) FlushValues() (Buffer, error) { return e.sink.Finish(), nil }

// Bytes returns the current bytes that have been written to the encoder's buffer but doesn't transfer ownership.
func (e *encoder) Bytes() []byte { return e.sink.Bytes() }

// Reset drops the data currently in the encoder and resets for new use.
func (e *encoder) Reset() { e.sink.Reset(0) }
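
// A second illustrative sketch (again not from the upstream file) of the base encoder's
// reuse contract: each FlushValues hands ownership of an encoded Buffer to the caller and
// clears the encoder, so the same encoder can keep accepting values for the next page,
// while Reset discards buffered data without producing a page. The writePage callback and
// the Int32Encoder assertion are assumptions made purely for illustration.
func exampleEncodeTwoPages(enc TypedEncoder, writePage func([]byte) error) error {
	typed := enc.(Int32Encoder)

	for _, page := range [][]int32{{1, 2, 3}, {4, 5, 6}} {
		typed.Put(page)

		// Every flush yields an independent Buffer that must be released by the caller.
		buf, err := enc.FlushValues()
		if err != nil {
			return err
		}
		err = writePage(buf.Bytes())
		buf.Release()
		if err != nil {
			return err
		}
	}
	return nil
}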

type dictEncoder struct {
	encoder

	dictEncodedSize int
	idxBuffer       *memory.Buffer
	idxValues       []int32
	memo            MemoTable

	preservedDict arrow.Array
}

// newDictEncoderBase constructs and returns a dictionary encoder for the appropriate type using the passed
// in memo table for constructing the index.
func newDictEncoderBase(descr *schema.Column, memo MemoTable, mem memory.Allocator) dictEncoder {
	return dictEncoder{
		encoder:   newEncoderBase(format.Encoding_PLAIN_DICTIONARY, descr, mem),
		idxBuffer: memory.NewResizableBuffer(mem),
		memo:      memo,
	}
}

// Reset drops the dictionary entries and all of the currently buffered index values so that
// the encoding process can be restarted.
func (d *dictEncoder) Reset() {
	d.encoder.Reset()
	d.dictEncodedSize = 0
	d.idxValues = d.idxValues[:0]
	d.idxBuffer.ResizeNoShrink(0)
	d.memo.Reset()
	if d.preservedDict != nil {
		d.preservedDict.Release()
		d.preservedDict = nil
	}
}

func (d *dictEncoder) Release() {
	d.encoder.Release()
	d.idxBuffer.Release()
	if m, ok := d.memo.(BinaryMemoTable); ok {
		m.Release()
	} else {
		d.memo.Reset()
	}
	if d.preservedDict != nil {
		d.preservedDict.Release()
		d.preservedDict = nil
	}
}

func (d *dictEncoder) expandBuffer(newCap int) {
	if cap(d.idxValues) >= newCap {
		return
	}

	curLen := len(d.idxValues)
	d.idxBuffer.ResizeNoShrink(arrow.Int32Traits.BytesRequired(bitutil.NextPowerOf2(newCap)))
	d.idxValues = arrow.Int32Traits.CastFromBytes(d.idxBuffer.Buf())[: curLen : d.idxBuffer.Len()/arrow.Int32SizeBytes]
}
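
// expandBuffer above grows the scratch buffer to the next power of two and then re-derives
// idxValues with a full (three-index) slice expression, so the slice keeps its current
// length while its capacity tracks the resized buffer. The tiny sketch below, which is
// illustrative only and not part of the upstream file, shows the same pattern on a plain
// resizable buffer.
func exampleFullSliceExpression(mem memory.Allocator) (length, capacity int) {
	buf := memory.NewResizableBuffer(mem)
	defer buf.Release()

	// Make room for 8 int32 values, then expose a zero-length slice whose capacity is
	// pinned to the buffer, mirroring how expandBuffer rebuilds idxValues.
	buf.ResizeNoShrink(arrow.Int32Traits.BytesRequired(8))
	vals := arrow.Int32Traits.CastFromBytes(buf.Bytes())[:0:8]
	vals = append(vals, 1, 2, 3) // writes into the buffer, no reallocation
	return len(vals), cap(vals)  // 3, 8
}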
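// PutIndices buffers dictionary indexes taken directly from the passed integer array (any
// signed or unsigned integer type up to 64 bits), widening them to int32 and skipping null
// slots via the array's validity bitmap. It is used when the data has already been
// dictionary-encoded and only the indexes need to be accumulated for the next data page;
// non-integer arrays are rejected with arrow.ErrInvalid.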
func (d *dictEncoder) PutIndices(data arrow.Array) error {
	newValues := data.Len() - data.NullN()
	curPos := len(d.idxValues)
	newLen := newValues + curPos
	d.expandBuffer(newLen)
	d.idxValues = d.idxValues[:newLen:cap(d.idxValues)]

	switch data.DataType().ID() {
	case arrow.UINT8, arrow.INT8:
		values := arrow.Uint8Traits.CastFromBytes(data.Data().Buffers()[1].Bytes())[data.Data().Offset():]
		bitutils.VisitSetBitRunsNoErr(data.NullBitmapBytes(),
			int64(data.Data().Offset()), int64(data.Len()),
			func(pos, length int64) {
				for i := int64(0); i < length; i++ {
					d.idxValues[curPos] = int32(values[i+pos])
					curPos++
				}
			})
	case arrow.UINT16, arrow.INT16:
		values := arrow.Uint16Traits.CastFromBytes(data.Data().Buffers()[1].Bytes())[data.Data().Offset():]
		bitutils.VisitSetBitRunsNoErr(data.NullBitmapBytes(),
			int64(data.Data().Offset()), int64(data.Len()),
			func(pos, length int64) {
				for i := int64(0); i < length; i++ {
					d.idxValues[curPos] = int32(values[i+pos])
					curPos++
				}
			})
	case arrow.UINT32, arrow.INT32:
		values := arrow.Uint32Traits.CastFromBytes(data.Data().Buffers()[1].Bytes())[data.Data().Offset():]
		bitutils.VisitSetBitRunsNoErr(data.NullBitmapBytes(),
			int64(data.Data().Offset()), int64(data.Len()),
			func(pos, length int64) {
				for i := int64(0); i < length; i++ {
					d.idxValues[curPos] = int32(values[i+pos])
					curPos++
				}
			})
	case arrow.UINT64, arrow.INT64:
		values := arrow.Uint64Traits.CastFromBytes(data.Data().Buffers()[1].Bytes())[data.Data().Offset():]
		bitutils.VisitSetBitRunsNoErr(data.NullBitmapBytes(),
			int64(data.Data().Offset()), int64(data.Len()),
			func(pos, length int64) {
				for i := int64(0); i < length; i++ {
					d.idxValues[curPos] = int32(values[i+pos])
					curPos++
				}
			})
	default:
		return fmt.Errorf("%w: passed non-integer array to PutIndices", arrow.ErrInvalid)
	}

	return nil
}

// addIndex appends the passed index to the index buffer.
func (d *dictEncoder) addIndex(idx int) {
	curLen := len(d.idxValues)
	d.expandBuffer(curLen + 1)
	d.idxValues = append(d.idxValues, int32(idx))
}

// FlushValues dumps all the currently buffered indexes, which would become the data page, to a
// buffer and returns it, or returns nil along with any error encountered.
func (d *dictEncoder) FlushValues() (Buffer, error) {
	buf := bufferPool.Get().(*memory.Buffer)
	buf.Reserve(int(d.EstimatedDataEncodedSize()))
	size, err := d.WriteIndices(buf.Buf())
	if err != nil {
		poolBuffer{buf}.Release()
		return nil, err
	}
	buf.ResizeNoShrink(size)
	return poolBuffer{buf}, nil
}

// EstimatedDataEncodedSize returns the maximum number of bytes needed to store the RLE encoded indexes, not including the
// dictionary index in the computation.
func (d *dictEncoder) EstimatedDataEncodedSize() int64 {
	return 1 + int64(utils.MaxBufferSize(d.BitWidth(), len(d.idxValues))+utils.MinBufferSize(d.BitWidth()))
}

// NumEntries returns the number of entries in the dictionary index for this encoder.
func (d *dictEncoder) NumEntries() int {
	return d.memo.Size()
}

// BitWidth returns the max bitwidth that would be necessary for encoding the index values currently
// in the dictionary based on the size of the dictionary index.
func (d *dictEncoder) BitWidth() int {
	switch d.NumEntries() {
	case 0:
		return 0
	case 1:
		return 1
	default:
		return bits.Len32(uint32(d.NumEntries() - 1))
	}
}
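
// For example: a dictionary with 1000 entries must address indexes 0 through 999, and
// bits.Len32(999) == 10, so each buffered index is bit-packed/RLE encoded with a 10-bit
// width; 2 entries need 1 bit, and 257 entries need 9 bits (bits.Len32(256) == 9).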

// WriteDict writes the dictionary index to the given byte slice.
func (d *dictEncoder) WriteDict(out []byte) {
	d.memo.WriteOut(out)
}

// WriteIndices performs run-length encoding on the indexes and then writes the encoded
// index value data to the provided byte slice, returning the number of bytes actually written.
// If any error is encountered, it will return -1 and the error.
func (d *dictEncoder) WriteIndices(out []byte) (int, error) {
	out[0] = byte(d.BitWidth())

	enc := utils.NewRleEncoder(utils.NewWriterAtBuffer(out[1:]), d.BitWidth())
	for _, idx := range d.idxValues {
		if err := enc.Put(uint64(idx)); err != nil {
			return -1, err
		}
	}
	nbytes := enc.Flush()

	d.idxValues = d.idxValues[:0]
	return nbytes + 1, nil
}
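
// The sketch below (illustrative only, not part of the upstream file) shows how the pieces
// above are typically combined when a dictionary-encoded column chunk is written out: the
// dictionary values go into one buffer sized by DictEncodedSize, while the bit-width byte
// plus the RLE-packed indexes go into another buffer sized by EstimatedDataEncodedSize.
// Turning those two buffers into an actual dictionary page and data page is handled by the
// column writer and is assumed here.
func exampleWriteDictAndIndices(d *dictEncoder) (dictBuf, indexBuf []byte, err error) {
	dictBuf = make([]byte, d.DictEncodedSize())
	d.WriteDict(dictBuf)

	indexBuf = make([]byte, int(d.EstimatedDataEncodedSize()))
	n, err := d.WriteIndices(indexBuf)
	if err != nil {
		return nil, nil, err
	}
	return dictBuf, indexBuf[:n], nil
}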

// Put adds a value to the dictionary data column, inserting the value if it
// didn't already exist in the dictionary.
func (d *dictEncoder) Put(v interface{}) {
	memoIdx, found, err := d.memo.GetOrInsert(v)
	if err != nil {
		panic(err)
	}
	if !found {
		d.dictEncodedSize += int(reflect.TypeOf(v).Size())
	}
	d.addIndex(memoIdx)
}
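
// A small sketch (not part of the upstream file) of the Put flow: every call appends one
// index, but only the first occurrence of a value grows the dictionary, so repeated values
// are cheap. The concrete value type must match what the embedding typed encoder's memo
// table expects; int32 values and an initially empty encoder are assumed here purely for
// illustration.
func examplePutRepeatedValues(d *dictEncoder) (numIndices, numDictEntries int) {
	for _, v := range []int32{7, 7, 8, 7} {
		d.Put(v)
	}
	// Four indexes were buffered, while the dictionary holds only the two distinct
	// values 7 and 8.
	return len(d.idxValues), d.NumEntries()
}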

// DictEncodedSize returns the current size of the encoded dictionary.
func (d *dictEncoder) DictEncodedSize() int {
	return d.dictEncodedSize
}

func (d *dictEncoder) canPutDictionary(values arrow.Array) error {
	switch {
	case values.NullN() > 0:
		return fmt.Errorf("%w: inserted dictionary cannot contain nulls",
			arrow.ErrInvalid)
	case d.NumEntries() > 0:
		return fmt.Errorf("%w: can only call PutDictionary on an empty DictEncoder",
			arrow.ErrInvalid)
	}

	return nil
}

func (d *dictEncoder) PreservedDictionary() arrow.Array { return d.preservedDict }

// spacedCompress is a helper function for encoders that removes the null slots from the spaced
// input slice, according to the validity bitmap, packing the valid values into an output slice
// that is no longer spaced out with slots for nulls.
func spacedCompress(src, out interface{}, validBits []byte, validBitsOffset int64) int {
	nvalid := 0

	// for efficiency we use a type switch because the copy runs significantly faster when typed
	// than calling reflect.Copy
	switch s := src.(type) {
	case []int32:
		o := out.([]int32)
		reader := bitutils.NewSetBitRunReader(validBits, validBitsOffset, int64(len(s)))
		for {
			run := reader.NextRun()
			if run.Length == 0 {
				break
			}
			copy(o[nvalid:], s[int(run.Pos):int(run.Pos+run.Length)])
			nvalid += int(run.Length)
		}
	case []int64:
		o := out.([]int64)
		reader := bitutils.NewSetBitRunReader(validBits, validBitsOffset, int64(len(s)))
		for {
			run := reader.NextRun()
			if run.Length == 0 {
				break
			}
			copy(o[nvalid:], s[int(run.Pos):int(run.Pos+run.Length)])
			nvalid += int(run.Length)
		}
	case []float32:
		o := out.([]float32)
		reader := bitutils.NewSetBitRunReader(validBits, validBitsOffset, int64(len(s)))
		for {
			run := reader.NextRun()
			if run.Length == 0 {
				break
			}
			copy(o[nvalid:], s[int(run.Pos):int(run.Pos+run.Length)])
			nvalid += int(run.Length)
		}
	case []float64:
		o := out.([]float64)
		reader := bitutils.NewSetBitRunReader(validBits, validBitsOffset, int64(len(s)))
		for {
			run := reader.NextRun()
			if run.Length == 0 {
				break
			}
			copy(o[nvalid:], s[int(run.Pos):int(run.Pos+run.Length)])
			nvalid += int(run.Length)
		}
	case []parquet.ByteArray:
		o := out.([]parquet.ByteArray)
		reader := bitutils.NewSetBitRunReader(validBits, validBitsOffset, int64(len(s)))
		for {
			run := reader.NextRun()
			if run.Length == 0 {
				break
			}
			copy(o[nvalid:], s[int(run.Pos):int(run.Pos+run.Length)])
			nvalid += int(run.Length)
		}
	case []parquet.FixedLenByteArray:
		o := out.([]parquet.FixedLenByteArray)
		reader := bitutils.NewSetBitRunReader(validBits, validBitsOffset, int64(len(s)))
		for {
			run := reader.NextRun()
			if run.Length == 0 {
				break
			}
			copy(o[nvalid:], s[int(run.Pos):int(run.Pos+run.Length)])
			nvalid += int(run.Length)
		}
	case []bool:
		o := out.([]bool)
		reader := bitutils.NewSetBitRunReader(validBits, validBitsOffset, int64(len(s)))
		for {
			run := reader.NextRun()
			if run.Length == 0 {
				break
			}
			copy(o[nvalid:], s[int(run.Pos):int(run.Pos+run.Length)])
			nvalid += int(run.Length)
		}
	}

	return nvalid
}
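
// A short illustrative sketch (not part of the upstream file) of spacedCompress: given a
// "spaced" slice in which null slots still occupy a position, plus a validity bitmap
// marking which slots are set, the valid values are packed to the front of the output
// slice and their count is returned.
func exampleSpacedCompress() ([]int32, int) {
	spaced := []int32{10, 0, 30, 0} // slots 1 and 3 are null placeholders
	validity := make([]byte, 1)
	bitutil.SetBit(validity, 0)
	bitutil.SetBit(validity, 2)

	out := make([]int32, len(spaced))
	n := spacedCompress(spaced, out, validity, 0)
	return out[:n], n // [10, 30], 2
}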