github.com/grailbio/base@v0.0.11/recordio/registry.go (about)

     1  // Copyright 2018 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache-2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package recordio
     6  
     7  import (
     8  	"fmt"
     9  	"strings"
    10  	"sync"
    11  
    12  	"github.com/grailbio/base/recordio/recordioiov"
    13  	"github.com/pkg/errors"
    14  )
    15  
    16  // TransformerFactory is a function that creates a new TransformerFunc given an
    17  // optional config string.
    18  type TransformerFactory func(config string) (TransformFunc, error)
    19  
    20  type registerer struct {
    21  	sync.RWMutex
    22  	transformers   map[string]TransformerFactory
    23  	untransformers map[string]TransformerFactory
    24  }
    25  
    26  var registry = &registerer{
    27  	transformers:   make(map[string]TransformerFactory),
    28  	untransformers: make(map[string]TransformerFactory),
    29  }
    30  
    31  func idTransform(scratch []byte, in [][]byte) ([]byte, error) {
    32  	out := recordioiov.Slice(scratch, recordioiov.TotalBytes(in))
    33  	n := 0
    34  	for _, b := range in {
    35  		copy(out[n:], b)
    36  		n += len(b)
    37  	}
    38  	return out, nil
    39  }
    40  
    41  func (r *registerer) registerTransformer(name string, t TransformerFactory, u TransformerFactory) {
    42  	r.Lock()
    43  	defer r.Unlock()
    44  	if _, ok := r.transformers[name]; ok {
    45  		panic(fmt.Sprintf("Transformer %s already registered", name))
    46  	}
    47  	r.transformers[name] = t
    48  	r.untransformers[name] = u
    49  }
    50  
    51  func getTransformers(h []string, factory map[string]TransformerFactory) ([]TransformFunc, error) {
    52  	var transformers []TransformFunc
    53  	for _, str := range h {
    54  		toks := strings.SplitN(str, " ", 2)
    55  		name := toks[0]
    56  		f, ok := factory[name]
    57  		if !ok {
    58  			return nil, errors.Errorf("Transformer %s not found", str)
    59  		}
    60  
    61  		var config string
    62  		if len(toks) > 1 {
    63  			config = toks[1]
    64  		}
    65  		var tr TransformFunc
    66  		var err error
    67  		if tr, err = f(config); err != nil {
    68  			return nil, err
    69  		}
    70  		transformers = append(transformers, tr)
    71  	}
    72  	return transformers, nil
    73  }
    74  
    75  func (r *registerer) getTransformer(h []string) (TransformFunc, error) {
    76  	r.RLock()
    77  	transformers, err := getTransformers(h, r.transformers)
    78  	r.RUnlock()
    79  	if err != nil {
    80  		return nil, err
    81  	}
    82  
    83  	switch {
    84  	case len(transformers) == 0:
    85  		return idTransform, nil
    86  	case len(transformers) == 1:
    87  		return transformers[0], nil
    88  	default:
    89  		combined := func(scratch []byte, data [][]byte) ([]byte, error) {
    90  			for _, tr := range transformers {
    91  				out, err := tr(scratch, data)
    92  				if err != nil {
    93  					return nil, err
    94  				}
    95  				if len(data) == 0 {
    96  					data = [][]byte{out}
    97  				} else {
    98  					data = data[:1]
    99  					data[0] = out
   100  				}
   101  				scratch = nil
   102  				// TODO(saito) Maybe reuse one of data[] as scratch?
   103  			}
   104  			if len(data) != 1 { // At least one transformer should have run.
   105  				panic(data)
   106  			}
   107  			return data[0], nil
   108  		}
   109  		return combined, nil
   110  	}
   111  }
   112  
   113  func (r *registerer) GetUntransformer(h []string) (TransformFunc, error) {
   114  	r.RLock()
   115  	transformers, err := getTransformers(h, r.untransformers)
   116  	r.RUnlock()
   117  	if err != nil {
   118  		return nil, err
   119  	}
   120  
   121  	switch {
   122  	case len(transformers) == 0:
   123  		return idTransform, nil
   124  	case len(transformers) == 1:
   125  		return transformers[0], nil
   126  	default:
   127  		combined := func(scratch []byte, data [][]byte) ([]byte, error) {
   128  			for i := len(transformers) - 1; i >= 0; i-- {
   129  				tr := transformers[i]
   130  				out, err := tr(scratch, data)
   131  				if err != nil {
   132  					return nil, err
   133  				}
   134  				if len(data) == 0 {
   135  					data = [][]byte{out}
   136  				} else {
   137  					data = data[:1]
   138  					data[0] = out
   139  				}
   140  			}
   141  			if len(data) != 1 { // At least one transformer should have run.
   142  				panic(data)
   143  			}
   144  			return data[0], nil
   145  		}
   146  		return combined, nil
   147  	}
   148  }
   149  
   150  // RegisterTransformer registers a block transformer. Factory transformer should
   151  // produce a transformer function. The factory is run by NewWriterV2.  The
   152  // transformer function is called by the writer to transform a block just before
   153  // storing it in storage.
   154  //
   155  // The untransformer factory is the reverse of the transformer factory. It is
   156  // run by NewScannerV2. The untransformer function is called by the scanner to
   157  // transform data read from storage into a block.
   158  //
   159  // This function is usually called when the process starts.
   160  //
   161  // The transformer and untransformer factories, as well as the functions
   162  // generated by these factories must be all thread safe.
   163  //
   164  // REQUIRES: A (un)transformer with the same "name" has not been registered
   165  // already.
   166  func RegisterTransformer(name string, transformer TransformerFactory, untransformer TransformerFactory) {
   167  	registry.registerTransformer(name, transformer, untransformer)
   168  }