github.com/leanovate/gopter@v0.2.9/bi_mapper.go (about)

     1  package gopter
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  )
     7  
     8  // BiMapper is a bi-directional (or bijective) mapper of a tuple of values (up)
     9  // to another tuple of values (down).
    10  type BiMapper struct {
    11  	UpTypes    []reflect.Type
    12  	DownTypes  []reflect.Type
    13  	Downstream reflect.Value
    14  	Upstream   reflect.Value
    15  }
    16  
    17  // NewBiMapper creates a BiMapper of two functions `downstream` and its
    18  // inverse `upstream`.
    19  // That is: The return values of `downstream` must match the parameters of
    20  // `upstream` and vice versa.
    21  func NewBiMapper(downstream interface{}, upstream interface{}) *BiMapper {
    22  	downstreamVal := reflect.ValueOf(downstream)
    23  	if downstreamVal.Kind() != reflect.Func {
    24  		panic("downstream has to be a function")
    25  	}
    26  	upstreamVal := reflect.ValueOf(upstream)
    27  	if upstreamVal.Kind() != reflect.Func {
    28  		panic("upstream has to be a function")
    29  	}
    30  
    31  	downstreamType := downstreamVal.Type()
    32  	upTypes := make([]reflect.Type, downstreamType.NumIn())
    33  	for i := 0; i < len(upTypes); i++ {
    34  		upTypes[i] = downstreamType.In(i)
    35  	}
    36  	downTypes := make([]reflect.Type, downstreamType.NumOut())
    37  	for i := 0; i < len(downTypes); i++ {
    38  		downTypes[i] = downstreamType.Out(i)
    39  	}
    40  
    41  	upstreamType := upstreamVal.Type()
    42  	if len(upTypes) != upstreamType.NumOut() {
    43  		panic(fmt.Sprintf("upstream is expected to have %d return values", len(upTypes)))
    44  	}
    45  	for i, upType := range upTypes {
    46  		if upstreamType.Out(i) != upType {
    47  			panic(fmt.Sprintf("upstream has wrong return type %d: %v != %v", i, upstreamType.Out(i), upType))
    48  		}
    49  	}
    50  	if len(downTypes) != upstreamType.NumIn() {
    51  		panic(fmt.Sprintf("upstream is expected to have %d parameters", len(downTypes)))
    52  	}
    53  	for i, downType := range downTypes {
    54  		if upstreamType.In(i) != downType {
    55  			panic(fmt.Sprintf("upstream has wrong parameter type %d: %v != %v", i, upstreamType.In(i), downType))
    56  		}
    57  	}
    58  
    59  	return &BiMapper{
    60  		UpTypes:    upTypes,
    61  		DownTypes:  downTypes,
    62  		Downstream: downstreamVal,
    63  		Upstream:   upstreamVal,
    64  	}
    65  }
    66  
    67  // ConvertUp calls the Upstream function on the arguments in the down array
    68  // and returns the results.
    69  func (b *BiMapper) ConvertUp(down []interface{}) []interface{} {
    70  	if len(down) != len(b.DownTypes) {
    71  		panic(fmt.Sprintf("Expected %d values != %d", len(b.DownTypes), len(down)))
    72  	}
    73  	downVals := make([]reflect.Value, len(b.DownTypes))
    74  	for i, val := range down {
    75  		if val == nil {
    76  			downVals[i] = reflect.Zero(b.DownTypes[i])
    77  		} else {
    78  			downVals[i] = reflect.ValueOf(val)
    79  		}
    80  	}
    81  	upVals := b.Upstream.Call(downVals)
    82  	up := make([]interface{}, len(upVals))
    83  	for i, upVal := range upVals {
    84  		up[i] = upVal.Interface()
    85  	}
    86  
    87  	return up
    88  }
    89  
    90  // ConvertDown calls the Downstream function on the elements of the up array
    91  // and returns the results.
    92  func (b *BiMapper) ConvertDown(up []interface{}) []interface{} {
    93  	if len(up) != len(b.UpTypes) {
    94  		panic(fmt.Sprintf("Expected %d values != %d", len(b.UpTypes), len(up)))
    95  	}
    96  	upVals := make([]reflect.Value, len(b.UpTypes))
    97  	for i, val := range up {
    98  		if val == nil {
    99  			upVals[i] = reflect.Zero(b.UpTypes[i])
   100  		} else {
   101  			upVals[i] = reflect.ValueOf(val)
   102  		}
   103  	}
   104  	downVals := b.Downstream.Call(upVals)
   105  	down := make([]interface{}, len(downVals))
   106  	for i, downVal := range downVals {
   107  		down[i] = downVal.Interface()
   108  	}
   109  
   110  	return down
   111  }