github.com/grailbio/base@v0.0.11/recordio/registry.go (about) 1 // Copyright 2018 GRAIL, Inc. All rights reserved. 2 // Use of this source code is governed by the Apache-2.0 3 // license that can be found in the LICENSE file. 4 5 package recordio 6 7 import ( 8 "fmt" 9 "strings" 10 "sync" 11 12 "github.com/grailbio/base/recordio/recordioiov" 13 "github.com/pkg/errors" 14 ) 15 16 // TransformerFactory is a function that creates a new TransformerFunc given an 17 // optional config string. 18 type TransformerFactory func(config string) (TransformFunc, error) 19 20 type registerer struct { 21 sync.RWMutex 22 transformers map[string]TransformerFactory 23 untransformers map[string]TransformerFactory 24 } 25 26 var registry = ®isterer{ 27 transformers: make(map[string]TransformerFactory), 28 untransformers: make(map[string]TransformerFactory), 29 } 30 31 func idTransform(scratch []byte, in [][]byte) ([]byte, error) { 32 out := recordioiov.Slice(scratch, recordioiov.TotalBytes(in)) 33 n := 0 34 for _, b := range in { 35 copy(out[n:], b) 36 n += len(b) 37 } 38 return out, nil 39 } 40 41 func (r *registerer) registerTransformer(name string, t TransformerFactory, u TransformerFactory) { 42 r.Lock() 43 defer r.Unlock() 44 if _, ok := r.transformers[name]; ok { 45 panic(fmt.Sprintf("Transformer %s already registered", name)) 46 } 47 r.transformers[name] = t 48 r.untransformers[name] = u 49 } 50 51 func getTransformers(h []string, factory map[string]TransformerFactory) ([]TransformFunc, error) { 52 var transformers []TransformFunc 53 for _, str := range h { 54 toks := strings.SplitN(str, " ", 2) 55 name := toks[0] 56 f, ok := factory[name] 57 if !ok { 58 return nil, errors.Errorf("Transformer %s not found", str) 59 } 60 61 var config string 62 if len(toks) > 1 { 63 config = toks[1] 64 } 65 var tr TransformFunc 66 var err error 67 if tr, err = f(config); err != nil { 68 return nil, err 69 } 70 transformers = append(transformers, tr) 71 } 72 return transformers, nil 73 } 74 75 func (r *registerer) getTransformer(h []string) (TransformFunc, error) { 76 r.RLock() 77 transformers, err := getTransformers(h, r.transformers) 78 r.RUnlock() 79 if err != nil { 80 return nil, err 81 } 82 83 switch { 84 case len(transformers) == 0: 85 return idTransform, nil 86 case len(transformers) == 1: 87 return transformers[0], nil 88 default: 89 combined := func(scratch []byte, data [][]byte) ([]byte, error) { 90 for _, tr := range transformers { 91 out, err := tr(scratch, data) 92 if err != nil { 93 return nil, err 94 } 95 if len(data) == 0 { 96 data = [][]byte{out} 97 } else { 98 data = data[:1] 99 data[0] = out 100 } 101 scratch = nil 102 // TODO(saito) Maybe reuse one of data[] as scratch? 103 } 104 if len(data) != 1 { // At least one transformer should have run. 105 panic(data) 106 } 107 return data[0], nil 108 } 109 return combined, nil 110 } 111 } 112 113 func (r *registerer) GetUntransformer(h []string) (TransformFunc, error) { 114 r.RLock() 115 transformers, err := getTransformers(h, r.untransformers) 116 r.RUnlock() 117 if err != nil { 118 return nil, err 119 } 120 121 switch { 122 case len(transformers) == 0: 123 return idTransform, nil 124 case len(transformers) == 1: 125 return transformers[0], nil 126 default: 127 combined := func(scratch []byte, data [][]byte) ([]byte, error) { 128 for i := len(transformers) - 1; i >= 0; i-- { 129 tr := transformers[i] 130 out, err := tr(scratch, data) 131 if err != nil { 132 return nil, err 133 } 134 if len(data) == 0 { 135 data = [][]byte{out} 136 } else { 137 data = data[:1] 138 data[0] = out 139 } 140 } 141 if len(data) != 1 { // At least one transformer should have run. 142 panic(data) 143 } 144 return data[0], nil 145 } 146 return combined, nil 147 } 148 } 149 150 // RegisterTransformer registers a block transformer. Factory transformer should 151 // produce a transformer function. The factory is run by NewWriterV2. The 152 // transformer function is called by the writer to transform a block just before 153 // storing it in storage. 154 // 155 // The untransformer factory is the reverse of the transformer factory. It is 156 // run by NewScannerV2. The untransformer function is called by the scanner to 157 // transform data read from storage into a block. 158 // 159 // This function is usually called when the process starts. 160 // 161 // The transformer and untransformer factories, as well as the functions 162 // generated by these factories must be all thread safe. 163 // 164 // REQUIRES: A (un)transformer with the same "name" has not been registered 165 // already. 166 func RegisterTransformer(name string, transformer TransformerFactory, untransformer TransformerFactory) { 167 registry.registerTransformer(name, transformer, untransformer) 168 }