gorgonia.org/gorgonia@v0.9.17/ops/nn/utils.go (about)

     1  package nnops
     2  
     3  import (
     4  	"hash/fnv"
     5  	"unsafe"
     6  
     7  	"github.com/chewxy/hm"
     8  	"github.com/pkg/errors"
     9  	"gorgonia.org/gorgonia"
    10  	"gorgonia.org/tensor"
    11  )
    12  
    13  func simpleHash(op gorgonia.Op) uint32 {
    14  	h := fnv.New32a()
    15  	op.WriteHash(h)
    16  	return h.Sum32()
    17  }
    18  
    19  func checkArity(op gorgonia.Op, inputs int) error {
    20  	if inputs != op.Arity() && op.Arity() >= 0 {
    21  		return errors.Errorf("%v has an arity of %d. Got %d instead", op, op.Arity(), inputs)
    22  	}
    23  	return nil
    24  }
    25  
    26  // nomem is a dummy type that implements cudnn.Memory, but returns 0 for all the pointer.
    27  //
    28  // It's essentially "nil" for CUDA memory
    29  type nomem struct{}
    30  
    31  func (nomem) Uintptr() uintptr           { return 0 }
    32  func (nomem) Pointer() unsafe.Pointer    { return nil }
    33  func (nomem) IsNativelyAccessible() bool { return false }
    34  
    35  func calcMemSize(dt tensor.Dtype, s tensor.Shape) uintptr {
    36  	var elemSize uintptr
    37  	if s.IsScalar() {
    38  		elemSize = 1
    39  	} else {
    40  		elemSize = uintptr(s.TotalSize())
    41  	}
    42  	dtSize := dt.Size()
    43  	return elemSize * dtSize
    44  }
    45  
    46  func dtypeOf(t hm.Type) (retVal tensor.Dtype, err error) {
    47  	switch p := t.(type) {
    48  	case tensor.Dtype:
    49  		retVal = p
    50  	case gorgonia.TensorType:
    51  		return dtypeOf(p.Of)
    52  	case hm.TypeVariable:
    53  		err = errors.Errorf("instance %v does not have a dtype", p)
    54  	default:
    55  		err = errors.Errorf("Not yet implemented: %v %v", "dtypeOf", p)
    56  		return
    57  	}
    58  
    59  	return
    60  }
    61  
    62  func CheckConvolutionParams(pad, stride, dilation []int) error {
    63  	// checks
    64  	for _, s := range stride {
    65  		if s <= 0 {
    66  			return errors.Errorf("Cannot use strides of less than or equal 0: %v", stride)
    67  		}
    68  	}
    69  
    70  	for _, p := range pad {
    71  		if p < 0 {
    72  			return errors.Errorf("Cannot use padding of less than 0: %v", pad)
    73  		}
    74  	}
    75  
    76  	for _, d := range dilation {
    77  		if d <= 0 {
    78  			return errors.Errorf("Cannot use dilation less than or eq 0 %v", dilation)
    79  		}
    80  	}
    81  	return nil
    82  }