gorgonia.org/gorgonia@v0.9.17/cuda/utils.go (about) 1 package cuda 2 3 import ( 4 "fmt" 5 "log" 6 7 "github.com/pkg/errors" 8 "gorgonia.org/tensor" 9 ) 10 11 func getDenseTensor(t tensor.Tensor) (tensor.DenseTensor, error) { 12 switch tt := t.(type) { 13 case tensor.DenseTensor: 14 return tt, nil 15 case tensor.Densor: 16 return tt.Dense(), nil 17 default: 18 return nil, errors.Errorf("Tensor %T is not a DenseTensor", t) 19 } 20 } 21 22 func handleFuncOpts(expShape tensor.Shape, expType tensor.Dtype, o tensor.DataOrder, strict bool, opts ...tensor.FuncOpt) (reuse tensor.DenseTensor, safe, toReuse, incr, same bool, err error) { 23 fo := tensor.ParseFuncOpts(opts...) 24 25 reuseT, incr := fo.IncrReuse() 26 safe = fo.Safe() 27 same = fo.Same() 28 toReuse = reuseT != nil 29 30 if toReuse { 31 if reuse, err = getDenseTensor(reuseT); err != nil { 32 err = errors.Wrapf(err, "Expected a tensor.DenseTensor") 33 return 34 } 35 36 if (strict || same) && reuse.Dtype() != expType { 37 err = errors.Errorf(typeMismatch, expType, reuse.Dtype()) 38 err = errors.Wrapf(err, "Cannot use reuse") 39 return 40 } 41 42 if reuse.DataSize() != expShape.TotalSize() && !expShape.IsScalar() { 43 log.Printf("REUSE CHECK reuse shape %v, expected Shape %v", reuse.Shape(), expShape) 44 err = errors.Errorf(shapeMismatch, reuse.Shape(), expShape) 45 err = errors.Wrapf(err, "Cannot use reuse: shape mismatch - reuse.len() %v, expShape.TotalSize() %v", reuse.DataSize(), expShape.TotalSize()) 46 return 47 } 48 49 if !incr && reuse != nil { 50 // reuse.setDataOrder(o) 51 // err = reuse.reshape(expShape...) 52 } 53 54 } 55 return 56 } 57 58 func binaryCheck(a, b tensor.Tensor) (err error) { 59 at := a.Dtype() 60 bt := b.Dtype() 61 62 switch at { 63 case tensor.Float32, tensor.Float64: 64 default: 65 return errors.Errorf("Unsupported Dtype for a: %v", at) 66 } 67 68 switch bt { 69 case tensor.Float32, tensor.Float64: 70 default: 71 return errors.Errorf("Unsupported Dtype for b: %v", bt) 72 } 73 74 if at.Kind() != bt.Kind() { 75 return errors.Errorf(typeMismatch, at, bt) 76 } 77 78 if !a.Shape().Eq(b.Shape()) { 79 log.Printf("BINARY CHECK %v %v", a.Shape(), b.Shape()) 80 return errors.Errorf(shapeMismatch, b.Shape(), a.Shape()) 81 } 82 83 if a.RequiresIterator() { 84 return errors.New("unsupported operation: a requires an iterator") 85 } 86 87 if b.RequiresIterator() { 88 return errors.New("unsupported operation: b requires an iterator") 89 } 90 return nil 91 } 92 93 func unaryCheck(a tensor.Tensor) error { 94 at := a.Dtype() 95 switch at { 96 case tensor.Float32, tensor.Float64: 97 default: 98 return errors.Errorf("Unsupported Dtype for a: %v", at) 99 } 100 101 if a.RequiresIterator() { 102 return errors.New("unsupported operation: a requires an iterator") 103 } 104 return nil 105 } 106 107 func logicalSize(s tensor.Shape) int { 108 if s.IsScalar() { 109 return 1 110 } 111 return s.TotalSize() 112 } 113 114 func constructName2(a, b tensor.Tensor, fn string) (name string) { 115 dt := a.Dtype() 116 as := a.Shape() 117 bs := b.Shape() 118 switch { 119 case as.IsScalar() && bs.IsScalar(): 120 name = fmt.Sprintf("%v.%s_ss_f%d", elemBinOpMod, fn, int(dt.Size()*8)) 121 case as.IsScalar() && !bs.IsScalar(): 122 name = fmt.Sprintf("%v.%s_sv_f%d", elemBinOpMod, fn, int(dt.Size()*8)) 123 case !as.IsScalar() && bs.IsScalar(): 124 name = fmt.Sprintf("%v.%s_vs_f%d", elemBinOpMod, fn, int(dt.Size()*8)) 125 default: 126 name = fmt.Sprintf("%v.%s_vv_f%d", elemBinOpMod, fn, int(dt.Size()*8)) 127 } 128 return 129 } 130 131 func constructName1(a tensor.Tensor, leftTensor bool, fn string) (name string) { 132 dt := a.Dtype() 133 if leftTensor { 134 name = fmt.Sprintf("%v.%s_vs_f%d", elemBinOpMod, fn, int(dt.Size()*8)) 135 } else { 136 name = fmt.Sprintf("%v.%s_sv_f%d", elemBinOpMod, fn, int(dt.Size()*8)) 137 } 138 return 139 }