github.com/cockroachdb/cockroach@v20.2.0-alpha.1+incompatible/pkg/col/colserde/record_batch.go

// Copyright 2019 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.

package colserde

import (
	"encoding/binary"
	"io"

	"github.com/apache/arrow/go/arrow/array"
	"github.com/apache/arrow/go/arrow/memory"
	"github.com/cockroachdb/cockroach/pkg/col/colserde/arrowserde"
	"github.com/cockroachdb/cockroach/pkg/col/typeconv"
	"github.com/cockroachdb/cockroach/pkg/sql/types"
	"github.com/cockroachdb/errors"
	flatbuffers "github.com/google/flatbuffers/go"
)

const (
	// metadataLengthNumBytes is the number of bytes used to encode the length
	// of the metadata in bytes. These are the first bytes of any arrow IPC
	// message.
	metadataLengthNumBytes           = 4
	flatbufferBuilderInitialCapacity = 1024
)

// numBuffersForType returns how many buffers are used to represent an array
// of the given type.
func numBuffersForType(t *types.T) int {
	// Most types are represented by 3 memory.Buffers (because most types are
	// serialized into a flat bytes representation): one buffer for the null
	// bitmap, one for the values, and one for the offsets.
	numBuffers := 3
	switch typeconv.TypeFamilyToCanonicalTypeFamily(t.Family()) {
	case types.BoolFamily, types.FloatFamily, types.IntFamily:
		// These fixed-width types don't have an offsets buffer.
		numBuffers = 2
	}
	return numBuffers
}

// RecordBatchSerializer serializes RecordBatches in the standard Apache Arrow
// IPC format using flatbuffers. Note that only RecordBatch messages are
// supported. This is because the full spec would be too much to support
// (support for DictionaryBatches, Tensors, SparseTensors, and Schema messages
// would be needed) and we only need the part of the spec that allows us to
// send data.
// The IPC format is described here:
// https://arrow.apache.org/docs/format/IPC.html
type RecordBatchSerializer struct {
	// numBuffers holds the number of buffers needed to represent an arrow
	// array of the type at the corresponding index of the typs passed in to
	// NewRecordBatchSerializer.
	numBuffers []int

	builder *flatbuffers.Builder
	scratch struct {
		bufferLens     []int
		metadataLength [metadataLengthNumBytes]byte
		padding        []byte
	}
}

// NewRecordBatchSerializer creates a new RecordBatchSerializer according to
// typs. Note that Serializing or Deserializing data that does not follow the
// passed-in schema results in undefined behavior.
func NewRecordBatchSerializer(typs []*types.T) (*RecordBatchSerializer, error) {
	if len(typs) == 0 {
		return nil, errors.Errorf("zero length schema unsupported")
	}
	s := &RecordBatchSerializer{
		numBuffers: make([]int, len(typs)),
		builder:    flatbuffers.NewBuilder(flatbufferBuilderInitialCapacity),
	}
	for i, t := range typs {
		s.numBuffers[i] = numBuffersForType(t)
	}
	// s.scratch.padding is used to align metadata to an 8-byte boundary, so it
	// doesn't need to be larger than 7 bytes.
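	// For example (illustrative numbers, not from the original source): a
	// 13-byte metadata flatbuffer plus its 4-byte length prefix occupies 17
	// bytes on the wire and so needs 7 bytes of padding to reach the next
	// multiple of 8 (24); calculatePadding below computes exactly this.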
	s.scratch.padding = make([]byte, 7)
	return s, nil
}

// calculatePadding calculates how many bytes must be added to numBytes to
// round it up to the nearest multiple of 8.
func calculatePadding(numBytes int) int {
	return (8 - (numBytes & 7)) & 7
}

// Serialize serializes data as an arrow RecordBatch message and writes it to
// w. Serializing a schema that does not match the schema given in
// NewRecordBatchSerializer results in undefined behavior.
func (s *RecordBatchSerializer) Serialize(
	w io.Writer, data []*array.Data,
) (metadataLen uint32, dataLen uint64, _ error) {
	if len(data) != len(s.numBuffers) {
		return 0, 0, errors.Errorf("mismatched schema length and number of columns: %d != %d", len(s.numBuffers), len(data))
	}
	// Ensure equal data lengths and the expected number of buffers. We don't
	// support zero-length schemas, so data[0] is in bounds at this point.
	headerLength := data[0].Len()
	for i := range data {
		if data[i].Len() != headerLength {
			return 0, 0, errors.Errorf("mismatched data lengths at column %d: %d != %d", i, headerLength, data[i].Len())
		}
		if len(data[i].Buffers()) != s.numBuffers[i] {
			return 0, 0, errors.Errorf(
				"mismatched number of buffers at column %d: %d != %d", i, len(data[i].Buffers()), s.numBuffers[i],
			)
		}
	}

	// The following is a good tutorial for understanding flatbuffers (i.e.
	// what is going on here):
	// https://google.github.io/flatbuffers/flatbuffers_guide_tutorial.html

	s.builder.Reset()
	s.scratch.bufferLens = s.scratch.bufferLens[:0]
	totalBufferLen := 0

	// Encode the nodes. These are structs that represent each element in data,
	// including the length and null count.
	// When constructing flatbuffers, we start at the leaves of whatever data
	// structure we are serializing so that we can deserialize from the
	// beginning of that buffer while taking advantage of cache prefetching.
	// Vectors are serialized backwards (hence the backwards iteration) because
	// flatbuffer builders can only be prepended to, and it simplifies the spec
	// if vectors follow the same back-to-front approach as other data.
	arrowserde.RecordBatchStartNodesVector(s.builder, len(data))
	for i := len(data) - 1; i >= 0; i-- {
		col := data[i]
		arrowserde.CreateFieldNode(s.builder, int64(col.Len()), int64(col.NullN()))
		buffers := col.Buffers()
		for j := len(buffers) - 1; j >= 0; j-- {
			bufferLen := 0
			// Some value buffers can be nil if the data are all zero values.
			if buffers[j] != nil {
				bufferLen = buffers[j].Len()
			}
			s.scratch.bufferLens = append(s.scratch.bufferLens, bufferLen)
			totalBufferLen += bufferLen
		}
	}
	nodes := s.builder.EndVector(len(data))

	// Encode the buffers vector. There are many buffers for each element in
	// data, and the actual bytes will be added to the message body later. Here
	// we encode structs that hold the offset (relative to the start of the
	// body) and the length of each buffer so that the deserializer can seek to
	// the actual bytes in the body. Note that we iterate over
	// s.scratch.bufferLens forwards: the lengths were appended back to front
	// while creating the nodes vector above, which is exactly the order in
	// which the corresponding buffer structs must be prepended here.
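	// For example, with s.scratch.bufferLens = [2, 5] and totalBufferLen = 7
	// (illustrative values, not from the original source), the loop below
	// emits CreateBuffer(offset=5, len=2) and then CreateBuffer(offset=0,
	// len=5); because the builder prepends, the finished vector lists the
	// offset-0 buffer first, matching the front-to-back layout of the body.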
	arrowserde.RecordBatchStartBuffersVector(s.builder, len(s.scratch.bufferLens))
	for i, offset := 0, totalBufferLen; i < len(s.scratch.bufferLens); i++ {
		bufferLen := s.scratch.bufferLens[i]
		offset -= bufferLen
		arrowserde.CreateBuffer(s.builder, int64(offset), int64(bufferLen))
	}
	buffers := s.builder.EndVector(len(s.scratch.bufferLens))

	// Encode the RecordBatch. This is a table that holds both the nodes and
	// buffer information.
	arrowserde.RecordBatchStart(s.builder)
	arrowserde.RecordBatchAddLength(s.builder, int64(headerLength))
	arrowserde.RecordBatchAddNodes(s.builder, nodes)
	arrowserde.RecordBatchAddBuffers(s.builder, buffers)
	header := arrowserde.RecordBatchEnd(s.builder)

	// Finally, encode the Message table. This will include the RecordBatch
	// above as well as some metadata.
	arrowserde.MessageStart(s.builder)
	arrowserde.MessageAddVersion(s.builder, arrowserde.MetadataVersionV1)
	arrowserde.MessageAddHeaderType(s.builder, arrowserde.MessageHeaderRecordBatch)
	arrowserde.MessageAddHeader(s.builder, header)
	arrowserde.MessageAddBodyLength(s.builder, int64(totalBufferLen))
	s.builder.Finish(arrowserde.MessageEnd(s.builder))

	metadataBytes := s.builder.FinishedBytes()

	// Use s.scratch.padding to align metadata to an 8-byte boundary.
	s.scratch.padding = s.scratch.padding[:calculatePadding(metadataLengthNumBytes+len(metadataBytes))]

	// Write metadata + padding length as the first metadataLengthNumBytes.
	metadataLength := uint32(len(metadataBytes) + len(s.scratch.padding))
	binary.LittleEndian.PutUint32(s.scratch.metadataLength[:], metadataLength)
	if _, err := w.Write(s.scratch.metadataLength[:]); err != nil {
		return 0, 0, err
	}

	// Write metadata.
	if _, err := w.Write(metadataBytes); err != nil {
		return 0, 0, err
	}

	// Add metadata padding.
	if _, err := w.Write(s.scratch.padding); err != nil {
		return 0, 0, err
	}

	// Add the message body. The metadata holds the offsets and lengths of
	// these buffers.
	bodyLength := 0
	for i := 0; i < len(data); i++ {
		buffers := data[i].Buffers()
		for j := 0; j < len(buffers); j++ {
			var bufferBytes []byte
			if buffers[j] != nil {
				// Some value buffers can be nil if the data are all zero
				// values.
				bufferBytes = buffers[j].Bytes()
			}
			bodyLength += len(bufferBytes)
			if _, err := w.Write(bufferBytes); err != nil {
				return 0, 0, err
			}
		}
	}

	// Add body padding. The body also needs to be a multiple of 8 bytes.
	s.scratch.padding = s.scratch.padding[:calculatePadding(bodyLength)]
	_, err := w.Write(s.scratch.padding)
	bodyLength += len(s.scratch.padding)
	return metadataLength, uint64(bodyLength), err
}

// Deserialize deserializes an arrow IPC RecordBatch message contained in
// bytes into data. Deserializing a schema that does not match the schema
// given in NewRecordBatchSerializer results in undefined behavior.
func (s *RecordBatchSerializer) Deserialize(data *[]*array.Data, bytes []byte) error {
	// Read the metadata by first reading its length.
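	// Serialize above lays the message out as a 4-byte little-endian metadata
	// length (which includes any alignment padding), followed by the metadata
	// flatbuffer plus that padding, followed by the body, so the length prefix
	// is decoded first to locate the flatbuffer.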
	metadataLen := int(binary.LittleEndian.Uint32(bytes[:metadataLengthNumBytes]))
	metadata := arrowserde.GetRootAsMessage(
		bytes[metadataLengthNumBytes:metadataLengthNumBytes+metadataLen], 0,
	)

	bodyBytes := bytes[metadataLengthNumBytes+metadataLen : metadataLengthNumBytes+metadataLen+int(metadata.BodyLength())]

	// We don't check the version because we don't fully support arrow
	// serialization/deserialization, so it's not useful. Refer to the
	// RecordBatchSerializer struct comment for more information.
	_ = metadata.Version()

	if metadata.HeaderType() != arrowserde.MessageHeaderRecordBatch {
		return errors.Errorf(
			`cannot decode RecordBatch from %s message`,
			arrowserde.EnumNamesMessageHeader[metadata.HeaderType()],
		)
	}

	var (
		headerTab flatbuffers.Table
		header    arrowserde.RecordBatch
	)

	if !metadata.Header(&headerTab) {
		return errors.New(`unable to decode metadata table`)
	}

	header.Init(headerTab.Bytes, headerTab.Pos)
	if len(s.numBuffers) != header.NodesLength() {
		return errors.Errorf(
			`mismatched schema and header lengths: %d != %d`, len(s.numBuffers), header.NodesLength(),
		)
	}

	var (
		node arrowserde.FieldNode
		buf  arrowserde.Buffer
	)
	for fieldIdx, bufferIdx := 0, 0; fieldIdx < len(s.numBuffers); fieldIdx++ {
		header.Nodes(&node, fieldIdx)

		// Make sure that this node (i.e. column buffer) is the same length as
		// the length in the header, which specifies how many rows there are in
		// the message body.
		if node.Length() != header.Length() {
			return errors.Errorf(
				`mismatched field and header lengths: %d != %d`, node.Length(), header.Length(),
			)
		}

		// Decode the message body by using the offset and length information
		// in the message header.
		buffers := make([]*memory.Buffer, s.numBuffers[fieldIdx])
		for i := 0; i < s.numBuffers[fieldIdx]; i++ {
			header.Buffers(&buf, bufferIdx)
			bufData := bodyBytes[int(buf.Offset()):int(buf.Offset()+buf.Length())]
			buffers[i] = memory.NewBufferBytes(bufData)
			bufferIdx++
		}

		*data = append(
			*data,
			array.NewData(
				nil, /* dType */
				int(header.Length()),
				buffers,
				nil, /* childData. Note that we do not support types with childData */
				int(node.NullCount()),
				0, /* offset */
			),
		)
	}

	return nil
}
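
// A minimal usage sketch (not part of the original file; error handling is
// elided, and typs and data are assumed to come from the caller, e.g. by
// first converting columnar vectors into Arrow-format []*array.Data):
//
//	s, _ := NewRecordBatchSerializer(typs)
//	var buf bytes.Buffer
//	_, _, _ = s.Serialize(&buf, data)
//	var out []*array.Data
//	_ = s.Deserialize(&out, buf.Bytes())
//
// The bytes written by Serialize are exactly the bytes Deserialize expects.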