github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/recordio/header.go (about)

     1  // Copyright 2018 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package recordio
     6  
     7  // Utility functions for encoding and parsing keys/values in a header block.
     8  
     9  import (
    10  	"encoding/binary"
    11  	"fmt"
    12  
    13  	"github.com/Schaudge/grailbase/errors"
    14  )
    15  
const (
	// Reserved header keywords. These are the keys of KeyValue entries that
	// carry a predefined meaning in a recordio header block.

	// KeyTrailer must be set to true when the recordio file contains a trailer.
	// value type: bool
	//
	// ParsedHeader.HasTrailer treats a missing entry, a non-bool value, or a
	// false value as "no trailer".
	KeyTrailer = "trailer"

	// KeyTransformer defines transformer functions used to encode blocks.
	// NOTE(review): the value type is not visible in this file — confirm
	// against the code that writes and reads this key.
	KeyTransformer = "transformer"
)
    26  
// KeyValue defines one entry stored in a recordio header block.
type KeyValue struct {
	// Key is the header key.
	Key string
	// Value is the value corresponding to Key. The encoder in this file
	// accepts bool, int, int8..int64, uint, uint8..uint64, and string; any
	// other type (including floating-point values) is rejected with an
	// "illegal header type" error.
	//
	// When a header is parsed back, signed integers are returned as int64 and
	// unsigned integers as uint64 regardless of the width that was encoded.
	Value interface{}
}
    35  
// ParsedHeader is the result of parsing the recordio header block contents.
// It is an ordered list of key/value entries; entries are serialized and
// parsed in slice order.
type ParsedHeader []KeyValue
    38  
// Type tags. headerEncoder writes one of these bytes in front of every encoded
// value, and headerDecoder switches on it to pick the decoding. The numeric
// values end up in the serialized header bytes, so they must not be changed.
const (
	headerTypeBool   uint8 = 1 // followed by one byte: 0 is false, non-zero is true
	headerTypeInt    uint8 = 2 // followed by a signed varint
	headerTypeUint   uint8 = 3 // followed by an unsigned varint
	headerTypeString uint8 = 4 // followed by a tagged uint length, then the bytes
	// TODO(saito) Add more types
)
    46  
// Helper for encoding key/value pairs into bytes to be stored in a header
// block. Thread compatible.
//
// The zero value is ready to use; the put* methods append to data.
type headerEncoder struct {
	// data holds the bytes encoded so far.
	data []byte
}
    52  
    53  func (e *headerEncoder) grow(delta int) {
    54  	cur := len(e.data)
    55  	if cap(e.data) >= cur+delta {
    56  		e.data = e.data[:cur+delta]
    57  	} else {
    58  		tmp := make([]byte, cur+delta, (cur+delta)*2)
    59  		copy(tmp, e.data)
    60  		e.data = tmp
    61  	}
    62  }
    63  
    64  func (e *headerEncoder) putUint(v uint64) {
    65  	e.putRawByte(headerTypeUint)
    66  	cur := len(e.data)
    67  	e.grow(binary.MaxVarintLen64)
    68  	n := binary.PutUvarint(e.data[cur:], v)
    69  	e.data = e.data[:cur+n]
    70  }
    71  
    72  func (e *headerEncoder) putInt(v int64) {
    73  	e.putRawByte(headerTypeInt)
    74  	cur := len(e.data)
    75  	e.grow(binary.MaxVarintLen64)
    76  	n := binary.PutVarint(e.data[cur:], v)
    77  	e.data = e.data[:cur+n]
    78  }
    79  
    80  func (e *headerEncoder) putRawByte(b uint8) {
    81  	cur := len(e.data)
    82  	e.grow(1)
    83  	e.data[cur] = b
    84  	e.data = e.data[:cur+1]
    85  }
    86  
    87  func (e *headerEncoder) putBool(v bool) {
    88  	e.putRawByte(headerTypeBool)
    89  	if v {
    90  		e.putRawByte(1)
    91  	} else {
    92  		e.putRawByte(0)
    93  	}
    94  }
    95  
    96  func (e *headerEncoder) putString(s string) {
    97  	e.putRawByte(headerTypeString)
    98  	e.putUint(uint64(len(s)))
    99  	cur := len(e.data)
   100  	e.grow(len(s))
   101  	copy(e.data[cur:], s)
   102  	e.data = e.data[:cur+len(s)]
   103  }
   104  
   105  func (e *headerEncoder) putKeyValue(key string, v interface{}) error {
   106  	e.putString(key)
   107  	switch v := v.(type) {
   108  	case bool:
   109  		e.putBool(v)
   110  	case uint:
   111  		e.putUint(uint64(v))
   112  	case uint8:
   113  		e.putUint(uint64(v))
   114  	case uint16:
   115  		e.putUint(uint64(v))
   116  	case uint32:
   117  		e.putUint(uint64(v))
   118  	case uint64:
   119  		e.putUint(v)
   120  	case int:
   121  		e.putInt(int64(v))
   122  	case int8:
   123  		e.putInt(int64(v))
   124  	case int16:
   125  		e.putInt(int64(v))
   126  	case int32:
   127  		e.putInt(int64(v))
   128  	case int64:
   129  		e.putInt(v)
   130  	case string:
   131  		e.putString(v)
   132  	default:
   133  		return fmt.Errorf("illegal header type %T", v)
   134  	}
   135  	return nil
   136  }
   137  
// Helper for decoding header data produced by headerEncoder.  Thread
// compatible.
//
// Decoding methods consume bytes from the front of data. Failures are
// recorded in err rather than returned, so callers must check err after each
// call.
type headerDecoder struct {
	// err records a decoding failure.
	err errors.Once
	// data is the not-yet-consumed remainder of the header bytes.
	data []byte
}
   144  
   145  func (d *headerDecoder) getRawByte() uint8 {
   146  	if len(d.data) <= 0 {
   147  		d.err.Set(fmt.Errorf("Failed to read byte in header"))
   148  		return 0
   149  	}
   150  	b := d.data[0]
   151  	d.data = d.data[1:]
   152  	return b
   153  }
   154  
   155  func (d *headerDecoder) getRawValue() interface{} {
   156  	vType := d.getRawByte()
   157  	switch vType {
   158  	case headerTypeBool:
   159  		b := d.getRawByte()
   160  		return b != 0
   161  	case headerTypeUint:
   162  		v, n := binary.Uvarint(d.data)
   163  		if n <= 0 {
   164  			d.err.Set(fmt.Errorf("Failed to parse uint"))
   165  			return 0
   166  		}
   167  		d.data = d.data[n:]
   168  		return v
   169  	case headerTypeInt:
   170  		v, n := binary.Varint(d.data)
   171  		if n <= 0 {
   172  			d.err.Set(fmt.Errorf("Failed to parse uint"))
   173  			return 0
   174  		}
   175  		d.data = d.data[n:]
   176  		return v
   177  	case headerTypeString:
   178  		rn := d.getRawValue()
   179  		if err := d.err.Err(); err != nil {
   180  			return ""
   181  		}
   182  		n, ok := rn.(uint64)
   183  		if !ok {
   184  			d.err.Set(fmt.Errorf("failed to read string key"))
   185  			return ""
   186  		}
   187  		if uint64(len(d.data)) < n {
   188  			d.err.Set(fmt.Errorf("header invalid string (%v)", n))
   189  			return ""
   190  		}
   191  		s := string(d.data[:n])
   192  		d.data = d.data[n:]
   193  		return s
   194  	default:
   195  		d.err.Set(fmt.Errorf("illegal header type %T", vType))
   196  		return nil
   197  	}
   198  }
   199  
   200  func (h *ParsedHeader) marshal() ([]byte, error) {
   201  	e := headerEncoder{}
   202  	e.putUint(uint64(len(*h)))
   203  	for _, kv := range *h {
   204  		if err := e.putKeyValue(kv.Key, kv.Value); err != nil {
   205  			return nil, err
   206  		}
   207  	}
   208  	return e.data, nil
   209  }
   210  
   211  func (h *ParsedHeader) unmarshal(data []byte) error {
   212  	d := headerDecoder{data: data}
   213  	vn := d.getRawValue()
   214  	if err := d.err.Err(); err != nil {
   215  		return err
   216  	}
   217  	n, ok := vn.(uint64)
   218  	if !ok {
   219  		d.err.Set(fmt.Errorf("Failed to read # header entries"))
   220  		return d.err.Err()
   221  	}
   222  	for i := uint64(0); i < n; i++ {
   223  		vkey := d.getRawValue()
   224  		if d.err.Err() != nil {
   225  			break
   226  		}
   227  		key, ok := vkey.(string)
   228  		if !ok {
   229  			d.err.Set(fmt.Errorf("failed to read string key"))
   230  			break
   231  		}
   232  		value := d.getRawValue()
   233  		if d.err.Err() != nil {
   234  			break
   235  		}
   236  		*h = append(*h, KeyValue{key, value})
   237  	}
   238  	return d.err.Err()
   239  }
   240  
   241  // HasTrailer checks if the header has a "trailer" entry.
   242  func (h *ParsedHeader) HasTrailer() bool {
   243  	for _, kv := range *h {
   244  		if kv.Key != KeyTrailer {
   245  			continue
   246  		}
   247  		b, ok := kv.Value.(bool)
   248  		if !ok || !b {
   249  			return false
   250  		}
   251  		return true
   252  	}
   253  	return false
   254  }