github.com/wzzhu/tensor@v0.9.24/dense_apply_test.go (about)

     1  package tensor
     2  
     3  import (
     4  	"math/rand"
     5  	"testing"
     6  	"testing/quick"
     7  	"time"
     8  	"unsafe"
     9  )
    10  
    11  func getMutateVal(dt Dtype) interface{} {
    12  	switch dt {
    13  	case Int:
    14  		return int(1)
    15  	case Int8:
    16  		return int8(1)
    17  	case Int16:
    18  		return int16(1)
    19  	case Int32:
    20  		return int32(1)
    21  	case Int64:
    22  		return int64(1)
    23  	case Uint:
    24  		return uint(1)
    25  	case Uint8:
    26  		return uint8(1)
    27  	case Uint16:
    28  		return uint16(1)
    29  	case Uint32:
    30  		return uint32(1)
    31  	case Uint64:
    32  		return uint64(1)
    33  	case Float32:
    34  		return float32(1)
    35  	case Float64:
    36  		return float64(1)
    37  	case Complex64:
    38  		var c complex64 = 1
    39  		return c
    40  	case Complex128:
    41  		var c complex128 = 1
    42  		return c
    43  	case Bool:
    44  		return true
    45  	case String:
    46  		return "Hello World"
    47  	case Uintptr:
    48  		return uintptr(0xdeadbeef)
    49  	case UnsafePointer:
    50  		return unsafe.Pointer(uintptr(0xdeadbeef))
    51  	}
    52  	return nil
    53  }
    54  
    55  func getMutateFn(dt Dtype) interface{} {
    56  	switch dt {
    57  	case Int:
    58  		return mutateI
    59  	case Int8:
    60  		return mutateI8
    61  	case Int16:
    62  		return mutateI16
    63  	case Int32:
    64  		return mutateI32
    65  	case Int64:
    66  		return mutateI64
    67  	case Uint:
    68  		return mutateU
    69  	case Uint8:
    70  		return mutateU8
    71  	case Uint16:
    72  		return mutateU16
    73  	case Uint32:
    74  		return mutateU32
    75  	case Uint64:
    76  		return mutateU64
    77  	case Float32:
    78  		return mutateF32
    79  	case Float64:
    80  		return mutateF64
    81  	case Complex64:
    82  		return mutateC64
    83  	case Complex128:
    84  		return mutateC128
    85  	case Bool:
    86  		return mutateB
    87  	case String:
    88  		return mutateStr
    89  	case Uintptr:
    90  		return mutateUintptr
    91  	case UnsafePointer:
    92  		return mutateUnsafePointer
    93  	}
    94  	return nil
    95  }
    96  
    97  func TestDense_Apply(t *testing.T) {
    98  	var r *rand.Rand
    99  	mut := func(q *Dense) bool {
   100  		var mutVal interface{}
   101  		if mutVal = getMutateVal(q.Dtype()); mutVal == nil {
   102  			return true // we'll temporarily skip those we cannot mutate/get a mutation value
   103  		}
   104  		var fn interface{}
   105  		if fn = getMutateFn(q.Dtype()); fn == nil {
   106  			return true // we'll skip those that we cannot mutate
   107  		}
   108  
   109  		we, eqFail := willerr(q, nil, nil)
   110  		_, ok := q.Engine().(Mapper)
   111  		we = we || !ok
   112  
   113  		a := q.Clone().(*Dense)
   114  		correct := q.Clone().(*Dense)
   115  		correct.Memset(mutVal)
   116  		ret, err := a.Apply(fn)
   117  		if err, retEarly := qcErrCheck(t, "Apply", a, nil, we, err); retEarly {
   118  			if err != nil {
   119  				return false
   120  			}
   121  			return true
   122  		}
   123  		if !qcEqCheck(t, a.Dtype(), eqFail, correct.Data(), ret.Data()) {
   124  			return false
   125  		}
   126  
   127  		// wrong fn type/illogical values
   128  		if _, err = a.Apply(getMutateFn); err == nil {
   129  			t.Error("Expected an error")
   130  			return false
   131  		}
   132  		return true
   133  	}
   134  	r = rand.New(rand.NewSource(time.Now().UnixNano()))
   135  	if err := quick.Check(mut, &quick.Config{Rand: r}); err != nil {
   136  		t.Errorf("Applying mutation function failed %v", err)
   137  	}
   138  }
   139  
   140  func TestDense_Apply_unsafe(t *testing.T) {
   141  	var r *rand.Rand
   142  	mut := func(q *Dense) bool {
   143  		var mutVal interface{}
   144  		if mutVal = getMutateVal(q.Dtype()); mutVal == nil {
   145  			return true // we'll temporarily skip those we cannot mutate/get a mutation value
   146  		}
   147  		var fn interface{}
   148  		if fn = getMutateFn(q.Dtype()); fn == nil {
   149  			return true // we'll skip those that we cannot mutate
   150  		}
   151  
   152  		we, eqFail := willerr(q, nil, nil)
   153  		_, ok := q.Engine().(Mapper)
   154  		we = we || !ok
   155  
   156  		a := q.Clone().(*Dense)
   157  		correct := q.Clone().(*Dense)
   158  		correct.Memset(mutVal)
   159  		ret, err := a.Apply(fn, UseUnsafe())
   160  		if err, retEarly := qcErrCheck(t, "Apply", a, nil, we, err); retEarly {
   161  			if err != nil {
   162  				return false
   163  			}
   164  			return true
   165  		}
   166  		if !qcEqCheck(t, a.Dtype(), eqFail, correct.Data(), ret.Data()) {
   167  			return false
   168  		}
   169  		if ret != a {
   170  			t.Error("Expected ret == correct (Unsafe option was used)")
   171  			return false
   172  		}
   173  		return true
   174  	}
   175  	r = rand.New(rand.NewSource(time.Now().UnixNano()))
   176  	if err := quick.Check(mut, &quick.Config{Rand: r}); err != nil {
   177  		t.Errorf("Applying mutation function failed %v", err)
   178  	}
   179  }
   180  
   181  func TestDense_Apply_reuse(t *testing.T) {
   182  	var r *rand.Rand
   183  	mut := func(q *Dense) bool {
   184  		var mutVal interface{}
   185  		if mutVal = getMutateVal(q.Dtype()); mutVal == nil {
   186  			return true // we'll temporarily skip those we cannot mutate/get a mutation value
   187  		}
   188  		var fn interface{}
   189  		if fn = getMutateFn(q.Dtype()); fn == nil {
   190  			return true // we'll skip those that we cannot mutate
   191  		}
   192  
   193  		we, eqFail := willerr(q, nil, nil)
   194  		_, ok := q.Engine().(Mapper)
   195  		we = we || !ok
   196  
   197  		a := q.Clone().(*Dense)
   198  		reuse := q.Clone().(*Dense)
   199  		reuse.Zero()
   200  		correct := q.Clone().(*Dense)
   201  		correct.Memset(mutVal)
   202  		ret, err := a.Apply(fn, WithReuse(reuse))
   203  		if err, retEarly := qcErrCheck(t, "Apply", a, nil, we, err); retEarly {
   204  			if err != nil {
   205  				return false
   206  			}
   207  			return true
   208  		}
   209  		if !qcEqCheck(t, a.Dtype(), eqFail, correct.Data(), ret.Data()) {
   210  			return false
   211  		}
   212  		if ret != reuse {
   213  			t.Error("Expected ret == correct (Unsafe option was used)")
   214  			return false
   215  		}
   216  		return true
   217  	}
   218  	r = rand.New(rand.NewSource(time.Now().UnixNano()))
   219  	if err := quick.Check(mut, &quick.Config{Rand: r}); err != nil {
   220  		t.Errorf("Applying mutation function failed %v", err)
   221  	}
   222  }