github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/col/colserde/record_batch.go

// Copyright 2019 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package colserde

import (
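	"bytes" // imported for the illustrative roundTripSketch at the end of the file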
	"encoding/binary"
	"io"

	"github.com/apache/arrow/go/arrow/array"
	"github.com/apache/arrow/go/arrow/memory"
	"github.com/cockroachdb/cockroach/pkg/col/colserde/arrowserde"
	"github.com/cockroachdb/cockroach/pkg/col/typeconv"
	"github.com/cockroachdb/cockroach/pkg/sql/types"
	"github.com/cockroachdb/errors"
	flatbuffers "github.com/google/flatbuffers/go"
)

const (
	// metadataLengthNumBytes is the number of bytes used to encode the length of
	// the metadata in bytes. These are the first bytes of any arrow IPC message.
	metadataLengthNumBytes           = 4
	flatbufferBuilderInitialCapacity = 1024
)

// numBuffersForType returns how many buffers are used to represent an array of
// the given type.
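//
// For example (an illustrative sketch; types.Int and types.Bytes are
// predefined *types.T values): a fixed-width INT column is represented by two
// buffers, while a variable-width BYTES column also needs an offsets buffer:
//
//	numBuffersForType(types.Int)   // 2: null bitmap + values
//	numBuffersForType(types.Bytes) // 3: null bitmap + offsets + values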
func numBuffersForType(t *types.T) int {
	// Most types are represented by 3 memory.Buffers (because most types are
	// serialized into a flat bytes representation): one buffer for the null
	// bitmap, one for the values, and one for the offsets.
	numBuffers := 3
	switch typeconv.TypeFamilyToCanonicalTypeFamily(t.Family()) {
	case types.BoolFamily, types.FloatFamily, types.IntFamily:
		// These types don't have an offsets buffer.
		numBuffers = 2
	}
	return numBuffers
}

// RecordBatchSerializer serializes RecordBatches in the standard Apache Arrow
// IPC format using flatbuffers. Note that only RecordBatch messages are
// supported, because implementing the full spec (which would require support
// for DictionaryBatches, Tensors, SparseTensors, and Schema messages) would
// be too much, and we only need the part of the spec that allows us to send
// data.
// The IPC format is described here:
// https://arrow.apache.org/docs/format/IPC.html
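//
// Each serialized message is framed as follows (a sketch of what Serialize
// below writes; both paddings align their sections to 8-byte boundaries):
//
//	[4-byte little-endian metadata length][metadata flatbuffer][padding][body buffers][padding]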
type RecordBatchSerializer struct {
	// numBuffers holds the number of buffers needed to represent an arrow array
	// of the type at the corresponding index of the typs slice passed to
	// NewRecordBatchSerializer.
	numBuffers []int

	builder *flatbuffers.Builder
	scratch struct {
		bufferLens     []int
		metadataLength [metadataLengthNumBytes]byte
		padding        []byte
	}
}

// NewRecordBatchSerializer creates a new RecordBatchSerializer according to
// typs. Note that serializing or deserializing data that does not follow the
// passed-in schema results in undefined behavior.
func NewRecordBatchSerializer(typs []*types.T) (*RecordBatchSerializer, error) {
	if len(typs) == 0 {
		return nil, errors.Errorf("zero length schema unsupported")
	}
	s := &RecordBatchSerializer{
		numBuffers: make([]int, len(typs)),
		builder:    flatbuffers.NewBuilder(flatbufferBuilderInitialCapacity),
	}
	for i, t := range typs {
		s.numBuffers[i] = numBuffersForType(t)
	}
	// s.scratch.padding is used to align metadata to an 8 byte boundary, so it
	// doesn't need to be larger than 7 bytes.
	s.scratch.padding = make([]byte, 7)
	return s, nil
}

// calculatePadding calculates how many bytes must be added to numBytes to round
// it up to the nearest multiple of 8.
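//
// A small worked example (illustrative only):
//
//	calculatePadding(13) // 3: 13+3 = 16, the next multiple of 8
//	calculatePadding(16) // 0: already aligned
//	calculatePadding(0)  // 0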
func calculatePadding(numBytes int) int {
	return (8 - (numBytes & 7)) & 7
}

// Serialize serializes data as an arrow RecordBatch message and writes it to w.
// Serializing data that does not match the schema given in
// NewRecordBatchSerializer results in undefined behavior.
func (s *RecordBatchSerializer) Serialize(
	w io.Writer, data []*array.Data,
) (metadataLen uint32, dataLen uint64, _ error) {
	if len(data) != len(s.numBuffers) {
		return 0, 0, errors.Errorf("mismatched schema length and number of columns: %d != %d", len(s.numBuffers), len(data))
	}
	// Ensure equal data length and expected number of buffers. We don't support
	// zero-length schemas, so data[0] is in bounds at this point.
	headerLength := data[0].Len()
	for i := range data {
		if data[i].Len() != headerLength {
			return 0, 0, errors.Errorf("mismatched data lengths at column %d: %d != %d", i, headerLength, data[i].Len())
		}
		if len(data[i].Buffers()) != s.numBuffers[i] {
			return 0, 0, errors.Errorf(
				"mismatched number of buffers at column %d: %d != %d", i, len(data[i].Buffers()), s.numBuffers[i],
			)
		}
	}

	// The following tutorial is helpful for understanding flatbuffers (i.e.
	// what is going on here):
	// https://google.github.io/flatbuffers/flatbuffers_guide_tutorial.html

	s.builder.Reset()
	s.scratch.bufferLens = s.scratch.bufferLens[:0]
	totalBufferLen := 0

	// Encode the nodes. These are structs that represent each element in data,
	// including the length and null count.
	// When constructing flatbuffers, we start at the leaves of whatever data
	// structure we are serializing so that we can deserialize from the beginning
	// of that buffer while taking advantage of cache prefetching. Vectors are
	// serialized backwards (hence the backwards iteration) because flatbuffer
	// builders can only be prepended to, and it simplifies the spec if vectors
	// follow the same back-to-front approach as other data.
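	// For example (illustrative): with two columns A and B, B's node is
	// prepended first, so the finished vector reads [A's node, B's node] from
	// front to back.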
	arrowserde.RecordBatchStartNodesVector(s.builder, len(data))
	for i := len(data) - 1; i >= 0; i-- {
		col := data[i]
		arrowserde.CreateFieldNode(s.builder, int64(col.Len()), int64(col.NullN()))
		buffers := col.Buffers()
		for j := len(buffers) - 1; j >= 0; j-- {
			bufferLen := 0
			// Some value buffers can be nil if the data are all zero values.
			if buffers[j] != nil {
				bufferLen = buffers[j].Len()
			}
			s.scratch.bufferLens = append(s.scratch.bufferLens, bufferLen)
			totalBufferLen += bufferLen
		}
	}
	nodes := s.builder.EndVector(len(data))

	// Encode the buffers vector. There are many buffers for each element in data
	// and the actual bytes will be added to the message body later. Here we
	// encode structs that hold the offset (relative to the start of the body)
	// and the length of each buffer so that the deserializer can seek to the
	// actual bytes in the body. Note that we iterate over s.scratch.bufferLens
	// forwards because the lengths were appended, while creating the nodes
	// vector above, in exactly the order in which we want to prepend the
	// buffer structs here.
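	// For example (illustrative): with bufferLens = [8, 16, 8] and
	// totalBufferLen = 32, the loop below creates buffer structs with
	// (offset, length) = (24, 8), (8, 16), (0, 8); since the builder prepends,
	// the finished vector reads [(0, 8), (8, 16), (24, 8)], tiling the body
	// from front to back.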
	arrowserde.RecordBatchStartBuffersVector(s.builder, len(s.scratch.bufferLens))
	for i, offset := 0, totalBufferLen; i < len(s.scratch.bufferLens); i++ {
		bufferLen := s.scratch.bufferLens[i]
		offset -= bufferLen
		arrowserde.CreateBuffer(s.builder, int64(offset), int64(bufferLen))
	}
	buffers := s.builder.EndVector(len(s.scratch.bufferLens))

	// Encode the RecordBatch. This is a table that holds both the nodes and
	// buffer information.
	arrowserde.RecordBatchStart(s.builder)
	arrowserde.RecordBatchAddLength(s.builder, int64(headerLength))
	arrowserde.RecordBatchAddNodes(s.builder, nodes)
	arrowserde.RecordBatchAddBuffers(s.builder, buffers)
	header := arrowserde.RecordBatchEnd(s.builder)

	// Finally, encode the Message table. This includes the RecordBatch above
	// as well as some metadata.
	arrowserde.MessageStart(s.builder)
	arrowserde.MessageAddVersion(s.builder, arrowserde.MetadataVersionV1)
	arrowserde.MessageAddHeaderType(s.builder, arrowserde.MessageHeaderRecordBatch)
	arrowserde.MessageAddHeader(s.builder, header)
	arrowserde.MessageAddBodyLength(s.builder, int64(totalBufferLen))
	s.builder.Finish(arrowserde.MessageEnd(s.builder))

	metadataBytes := s.builder.FinishedBytes()

	// Use s.scratch.padding to align the metadata to an 8-byte boundary.
	s.scratch.padding = s.scratch.padding[:calculatePadding(metadataLengthNumBytes+len(metadataBytes))]

	// Write the metadata + padding length as the first metadataLengthNumBytes
	// bytes.
	metadataLength := uint32(len(metadataBytes) + len(s.scratch.padding))
	binary.LittleEndian.PutUint32(s.scratch.metadataLength[:], metadataLength)
	if _, err := w.Write(s.scratch.metadataLength[:]); err != nil {
		return 0, 0, err
	}

	// Write metadata.
	if _, err := w.Write(metadataBytes); err != nil {
		return 0, 0, err
	}

	// Add metadata padding.
	if _, err := w.Write(s.scratch.padding); err != nil {
		return 0, 0, err
	}

	// Add the message body. The metadata holds the offsets and lengths of these
	// buffers.
	bodyLength := 0
	for i := 0; i < len(data); i++ {
		buffers := data[i].Buffers()
		for j := 0; j < len(buffers); j++ {
			var bufferBytes []byte
			if buffers[j] != nil {
				// Some value buffers can be nil if the data are all zero values.
				bufferBytes = buffers[j].Bytes()
			}
			bodyLength += len(bufferBytes)
			if _, err := w.Write(bufferBytes); err != nil {
				return 0, 0, err
			}
		}
	}

	// Add body padding. The body also needs to be a multiple of 8 bytes.
	s.scratch.padding = s.scratch.padding[:calculatePadding(bodyLength)]
	_, err := w.Write(s.scratch.padding)
	bodyLength += len(s.scratch.padding)
	return metadataLength, uint64(bodyLength), err
}

// Deserialize deserializes an arrow IPC RecordBatch message contained in bytes
// into data. Deserializing data whose schema does not match the schema given
// in NewRecordBatchSerializer results in undefined behavior.
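//
// bytes must contain a complete message as produced by Serialize, starting at
// the metadata length prefix; see the illustrative roundTripSketch at the end
// of this file for the expected call pattern.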
func (s *RecordBatchSerializer) Deserialize(data *[]*array.Data, bytes []byte) error {
	// Read the metadata by first reading its length.
	metadataLen := int(binary.LittleEndian.Uint32(bytes[:metadataLengthNumBytes]))
	metadata := arrowserde.GetRootAsMessage(
		bytes[metadataLengthNumBytes:metadataLengthNumBytes+metadataLen], 0,
	)

	bodyBytes := bytes[metadataLengthNumBytes+metadataLen : metadataLengthNumBytes+metadataLen+int(metadata.BodyLength())]

	// We don't check the version because we don't fully support arrow
	// serialization/deserialization, so checking it wouldn't be useful. Refer
	// to the RecordBatchSerializer struct comment for more information.
	_ = metadata.Version()

	if metadata.HeaderType() != arrowserde.MessageHeaderRecordBatch {
		return errors.Errorf(
			`cannot decode RecordBatch from %s message`,
			arrowserde.EnumNamesMessageHeader[metadata.HeaderType()],
		)
	}

	var (
		headerTab flatbuffers.Table
		header    arrowserde.RecordBatch
	)

	if !metadata.Header(&headerTab) {
		return errors.New(`unable to decode metadata table`)
	}

	header.Init(headerTab.Bytes, headerTab.Pos)
	if len(s.numBuffers) != header.NodesLength() {
		return errors.Errorf(
			`mismatched schema and header lengths: %d != %d`, len(s.numBuffers), header.NodesLength(),
		)
	}

	var (
		node arrowserde.FieldNode
		buf  arrowserde.Buffer
	)
	for fieldIdx, bufferIdx := 0, 0; fieldIdx < len(s.numBuffers); fieldIdx++ {
		header.Nodes(&node, fieldIdx)

		// Make sure that this node (i.e. column) is the same length as the
		// length in the header, which specifies how many rows there are in the
		// message body.
		if node.Length() != header.Length() {
			return errors.Errorf(
				`mismatched field and header lengths: %d != %d`, node.Length(), header.Length(),
			)
		}

		// Decode the message body by using the offset and length information in
		// the message header.
		buffers := make([]*memory.Buffer, s.numBuffers[fieldIdx])
		for i := 0; i < s.numBuffers[fieldIdx]; i++ {
			header.Buffers(&buf, bufferIdx)
			bufData := bodyBytes[int(buf.Offset()):int(buf.Offset()+buf.Length())]
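			// Note that NewBufferBytes wraps bufData without copying, so the
			// deserialized array.Data aliases the memory of the input bytes.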
			buffers[i] = memory.NewBufferBytes(bufData)
			bufferIdx++
		}

		*data = append(
			*data,
			array.NewData(
				nil, /* dType */
				int(header.Length()),
				buffers,
				nil, /* childData. Note that we do not support types with childData */
				int(node.NullCount()),
				0, /* offset */
			),
		)
	}

	return nil
}
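
// roundTripSketch is a minimal illustrative sketch (a hypothetical helper,
// not part of the upstream API): it serializes cols, which must match typs,
// into an in-memory buffer and deserializes them back, returning the
// reconstructed columns.
func roundTripSketch(typs []*types.T, cols []*array.Data) ([]*array.Data, error) {
	s, err := NewRecordBatchSerializer(typs)
	if err != nil {
		return nil, err
	}
	// Serialize writes one framed RecordBatch message to the buffer.
	var buf bytes.Buffer
	if _, _, err := s.Serialize(&buf, cols); err != nil {
		return nil, err
	}
	// Deserialize appends the reconstructed columns to result. Note that the
	// returned columns alias buf's memory.
	var result []*array.Data
	if err := s.Deserialize(&result, buf.Bytes()); err != nil {
		return nil, err
	}
	return result, nil
}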