gorgonia.org/gorgonia@v0.9.17/ops/nn/utils.go (about) 1 package nnops 2 3 import ( 4 "hash/fnv" 5 "unsafe" 6 7 "github.com/chewxy/hm" 8 "github.com/pkg/errors" 9 "gorgonia.org/gorgonia" 10 "gorgonia.org/tensor" 11 ) 12 13 func simpleHash(op gorgonia.Op) uint32 { 14 h := fnv.New32a() 15 op.WriteHash(h) 16 return h.Sum32() 17 } 18 19 func checkArity(op gorgonia.Op, inputs int) error { 20 if inputs != op.Arity() && op.Arity() >= 0 { 21 return errors.Errorf("%v has an arity of %d. Got %d instead", op, op.Arity(), inputs) 22 } 23 return nil 24 } 25 26 // nomem is a dummy type that implements cudnn.Memory, but returns 0 for all the pointer. 27 // 28 // It's essentially "nil" for CUDA memory 29 type nomem struct{} 30 31 func (nomem) Uintptr() uintptr { return 0 } 32 func (nomem) Pointer() unsafe.Pointer { return nil } 33 func (nomem) IsNativelyAccessible() bool { return false } 34 35 func calcMemSize(dt tensor.Dtype, s tensor.Shape) uintptr { 36 var elemSize uintptr 37 if s.IsScalar() { 38 elemSize = 1 39 } else { 40 elemSize = uintptr(s.TotalSize()) 41 } 42 dtSize := dt.Size() 43 return elemSize * dtSize 44 } 45 46 func dtypeOf(t hm.Type) (retVal tensor.Dtype, err error) { 47 switch p := t.(type) { 48 case tensor.Dtype: 49 retVal = p 50 case gorgonia.TensorType: 51 return dtypeOf(p.Of) 52 case hm.TypeVariable: 53 err = errors.Errorf("instance %v does not have a dtype", p) 54 default: 55 err = errors.Errorf("Not yet implemented: %v %v", "dtypeOf", p) 56 return 57 } 58 59 return 60 } 61 62 func CheckConvolutionParams(pad, stride, dilation []int) error { 63 // checks 64 for _, s := range stride { 65 if s <= 0 { 66 return errors.Errorf("Cannot use strides of less than or equal 0: %v", stride) 67 } 68 } 69 70 for _, p := range pad { 71 if p < 0 { 72 return errors.Errorf("Cannot use padding of less than 0: %v", pad) 73 } 74 } 75 76 for _, d := range dilation { 77 if d <= 0 { 78 return errors.Errorf("Cannot use dilation less than or eq 0 %v", dilation) 79 } 80 } 81 return nil 82 }