// Copyright 2018 GRAIL, Inc. All rights reserved.
// Use of this source code is governed by the Apache-2.0
// license that can be found in the LICENSE file.

package recordio

// Utility functions for encoding and parsing keys/values in a header block.

import (
	"encoding/binary"
	"fmt"

	"github.com/grailbio/base/errors"
)

const (
	// Reserved header keywords.

	// KeyTrailer must be set to true when the recordio file contains a trailer.
	// value type: bool
	KeyTrailer = "trailer"

	// KeyTransformer defines transformer functions used to encode blocks.
	KeyTransformer = "transformer"
)

// KeyValue defines one entry stored in a recordio header block.
type KeyValue struct {
	// Key is the header key.
	Key string
	// Value is the value corresponding to Key. The value must be one of int*,
	// uint*, bool, or string type; the encoder in this file rejects every other
	// type (including float*) with an "illegal header type" error.
	//
	// Integers are widened on encoding, so a value read back from an encoded
	// header is always int64 or uint64 regardless of the width originally
	// stored.
	Value interface{}
}

// ParsedHeader is the result of parsing the recordio header block contents.
// Entries are kept in the order in which they were decoded.
type ParsedHeader []KeyValue

// One-byte type tags written in front of every encoded header value.
const (
	// headerTypeBool tags a bool, encoded as a single 0/1 byte.
	headerTypeBool uint8 = 1
	// headerTypeInt tags a signed integer, encoded as a varint.
	headerTypeInt uint8 = 2
	// headerTypeUint tags an unsigned integer, encoded as a uvarint.
	headerTypeUint uint8 = 3
	// headerTypeString tags a string, encoded as a tagged uint length followed
	// by the raw bytes.
	headerTypeString uint8 = 4
	// TODO(saito) Add more types
)

// headerEncoder is a helper for encoding key/value pairs into bytes to be
// stored in a header block. Thread compatible.
49 type headerEncoder struct { 50 data []byte 51 } 52 53 func (e *headerEncoder) grow(delta int) { 54 cur := len(e.data) 55 if cap(e.data) >= cur+delta { 56 e.data = e.data[:cur+delta] 57 } else { 58 tmp := make([]byte, cur+delta, (cur+delta)*2) 59 copy(tmp, e.data) 60 e.data = tmp 61 } 62 } 63 64 func (e *headerEncoder) putUint(v uint64) { 65 e.putRawByte(headerTypeUint) 66 cur := len(e.data) 67 e.grow(binary.MaxVarintLen64) 68 n := binary.PutUvarint(e.data[cur:], v) 69 e.data = e.data[:cur+n] 70 } 71 72 func (e *headerEncoder) putInt(v int64) { 73 e.putRawByte(headerTypeInt) 74 cur := len(e.data) 75 e.grow(binary.MaxVarintLen64) 76 n := binary.PutVarint(e.data[cur:], v) 77 e.data = e.data[:cur+n] 78 } 79 80 func (e *headerEncoder) putRawByte(b uint8) { 81 cur := len(e.data) 82 e.grow(1) 83 e.data[cur] = b 84 e.data = e.data[:cur+1] 85 } 86 87 func (e *headerEncoder) putBool(v bool) { 88 e.putRawByte(headerTypeBool) 89 if v { 90 e.putRawByte(1) 91 } else { 92 e.putRawByte(0) 93 } 94 } 95 96 func (e *headerEncoder) putString(s string) { 97 e.putRawByte(headerTypeString) 98 e.putUint(uint64(len(s))) 99 cur := len(e.data) 100 e.grow(len(s)) 101 copy(e.data[cur:], s) 102 e.data = e.data[:cur+len(s)] 103 } 104 105 func (e *headerEncoder) putKeyValue(key string, v interface{}) error { 106 e.putString(key) 107 switch v := v.(type) { 108 case bool: 109 e.putBool(v) 110 case uint: 111 e.putUint(uint64(v)) 112 case uint8: 113 e.putUint(uint64(v)) 114 case uint16: 115 e.putUint(uint64(v)) 116 case uint32: 117 e.putUint(uint64(v)) 118 case uint64: 119 e.putUint(v) 120 case int: 121 e.putInt(int64(v)) 122 case int8: 123 e.putInt(int64(v)) 124 case int16: 125 e.putInt(int64(v)) 126 case int32: 127 e.putInt(int64(v)) 128 case int64: 129 e.putInt(v) 130 case string: 131 e.putString(v) 132 default: 133 return fmt.Errorf("illegal header type %T", v) 134 } 135 return nil 136 } 137 138 // Helper for decoding header data produced by headerEncoder. Thread 139 // compatible. 
140 type headerDecoder struct { 141 err errors.Once 142 data []byte 143 } 144 145 func (d *headerDecoder) getRawByte() uint8 { 146 if len(d.data) <= 0 { 147 d.err.Set(fmt.Errorf("Failed to read byte in header")) 148 return 0 149 } 150 b := d.data[0] 151 d.data = d.data[1:] 152 return b 153 } 154 155 func (d *headerDecoder) getRawValue() interface{} { 156 vType := d.getRawByte() 157 switch vType { 158 case headerTypeBool: 159 b := d.getRawByte() 160 return b != 0 161 case headerTypeUint: 162 v, n := binary.Uvarint(d.data) 163 if n <= 0 { 164 d.err.Set(fmt.Errorf("Failed to parse uint")) 165 return 0 166 } 167 d.data = d.data[n:] 168 return v 169 case headerTypeInt: 170 v, n := binary.Varint(d.data) 171 if n <= 0 { 172 d.err.Set(fmt.Errorf("Failed to parse uint")) 173 return 0 174 } 175 d.data = d.data[n:] 176 return v 177 case headerTypeString: 178 rn := d.getRawValue() 179 if err := d.err.Err(); err != nil { 180 return "" 181 } 182 n, ok := rn.(uint64) 183 if !ok { 184 d.err.Set(fmt.Errorf("failed to read string key")) 185 return "" 186 } 187 if uint64(len(d.data)) < n { 188 d.err.Set(fmt.Errorf("header invalid string (%v)", n)) 189 return "" 190 } 191 s := string(d.data[:n]) 192 d.data = d.data[n:] 193 return s 194 default: 195 d.err.Set(fmt.Errorf("illegal header type %T", vType)) 196 return nil 197 } 198 } 199 200 func (h *ParsedHeader) marshal() ([]byte, error) { 201 e := headerEncoder{} 202 e.putUint(uint64(len(*h))) 203 for _, kv := range *h { 204 if err := e.putKeyValue(kv.Key, kv.Value); err != nil { 205 return nil, err 206 } 207 } 208 return e.data, nil 209 } 210 211 func (h *ParsedHeader) unmarshal(data []byte) error { 212 d := headerDecoder{data: data} 213 vn := d.getRawValue() 214 if err := d.err.Err(); err != nil { 215 return err 216 } 217 n, ok := vn.(uint64) 218 if !ok { 219 d.err.Set(fmt.Errorf("Failed to read # header entries")) 220 return d.err.Err() 221 } 222 for i := uint64(0); i < n; i++ { 223 vkey := d.getRawValue() 224 if d.err.Err() != nil { 
225 break 226 } 227 key, ok := vkey.(string) 228 if !ok { 229 d.err.Set(fmt.Errorf("failed to read string key")) 230 break 231 } 232 value := d.getRawValue() 233 if d.err.Err() != nil { 234 break 235 } 236 *h = append(*h, KeyValue{key, value}) 237 } 238 return d.err.Err() 239 } 240 241 // HasTrailer checks if the header has a "trailer" entry. 242 func (h *ParsedHeader) HasTrailer() bool { 243 for _, kv := range *h { 244 if kv.Key != KeyTrailer { 245 continue 246 } 247 b, ok := kv.Value.(bool) 248 if !ok || !b { 249 return false 250 } 251 return true 252 } 253 return false 254 }