github.com/apache/arrow/go/v14@v14.0.1/parquet/internal/thrift/helpers.go (about)

     1  // Licensed to the Apache Software Foundation (ASF) under one
     2  // or more contributor license agreements.  See the NOTICE file
     3  // distributed with this work for additional information
     4  // regarding copyright ownership.  The ASF licenses this file
     5  // to you under the Apache License, Version 2.0 (the
     6  // "License"); you may not use this file except in compliance
     7  // with the License.  You may obtain a copy of the License at
     8  //
     9  // http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  
    17  // Package thrift is just some useful helpers for interacting with thrift to
    18  // make other code easier to read/write and centralize interactions.
    19  package thrift
    20  
    21  import (
    22  	"bytes"
    23  	"context"
    24  	"io"
    25  
    26  	"github.com/apache/arrow/go/v14/parquet/internal/encryption"
    27  	"github.com/apache/thrift/lib/go/thrift"
    28  )
    29  
    30  // default factory for creating thrift protocols for serialization/deserialization
    31  var protocolFactory = thrift.NewTCompactProtocolFactoryConf(&thrift.TConfiguration{})
    32  
    33  // DeserializeThrift deserializes the bytes in buf into the given thrift msg type
    34  // returns the number of remaining bytes in the buffer that weren't needed for deserialization
    35  // and any error if there was one, or nil.
    36  func DeserializeThrift(msg thrift.TStruct, buf []byte) (remain uint64, err error) {
    37  	tbuf := &thrift.TMemoryBuffer{Buffer: bytes.NewBuffer(buf)}
    38  	err = msg.Read(context.TODO(), protocolFactory.GetProtocol(tbuf))
    39  	remain = tbuf.RemainingBytes()
    40  	return
    41  }
    42  
    43  // SerializeThriftStream writes out the serialized bytes of the passed in type
    44  // to the given writer stream.
    45  func SerializeThriftStream(msg thrift.TStruct, w io.Writer) error {
    46  	return msg.Write(context.TODO(), protocolFactory.GetProtocol(thrift.NewStreamTransportW(w)))
    47  }
    48  
    49  // DeserializeThriftStream populates the given msg by reading from the provided
    50  // stream until it completes the deserialization.
    51  func DeserializeThriftStream(msg thrift.TStruct, r io.Reader) error {
    52  	return msg.Read(context.TODO(), protocolFactory.GetProtocol(thrift.NewStreamTransportR(r)))
    53  }
    54  
    55  // Serializer is an object that can stick around to provide convenience
    56  // functions and allow object reuse
    57  type Serializer struct {
    58  	thrift.TSerializer
    59  }
    60  
    61  // NewThriftSerializer constructs a serializer with a default buffer of 1024
    62  func NewThriftSerializer() *Serializer {
    63  	tbuf := thrift.NewTMemoryBufferLen(1024)
    64  	return &Serializer{thrift.TSerializer{
    65  		Transport: tbuf,
    66  		Protocol:  protocolFactory.GetProtocol(tbuf),
    67  	}}
    68  }
    69  
    70  // Serialize will serialize the given msg to the writer stream w, optionally encrypting it on the way
    71  // if enc is not nil, returning the total number of bytes written and any error received, or nil
    72  func (t *Serializer) Serialize(msg thrift.TStruct, w io.Writer, enc encryption.Encryptor) (int, error) {
    73  	b, err := t.Write(context.Background(), msg)
    74  	if err != nil {
    75  		return 0, err
    76  	}
    77  
    78  	if enc == nil {
    79  		return w.Write(b)
    80  	}
    81  
    82  	var cipherBuf bytes.Buffer
    83  	cipherBuf.Grow(enc.CiphertextSizeDelta() + len(b))
    84  	enc.Encrypt(&cipherBuf, b)
    85  	n, err := cipherBuf.WriteTo(w)
    86  	return int(n), err
    87  }