gorgonia.org/gorgonia@v0.9.17/op_upsample.go (about) 1 package gorgonia 2 3 import ( 4 "fmt" 5 "hash" 6 "hash/fnv" 7 8 "gorgonia.org/tensor" 9 10 "github.com/chewxy/hm" 11 "github.com/pkg/errors" 12 ) 13 14 type upsampleOp struct { 15 stride int 16 } 17 18 func newUpsampleOp(inputShape tensor.Shape, stride int) *upsampleOp { 19 upsampleop := &upsampleOp{ 20 stride: stride, 21 } 22 return upsampleop 23 } 24 25 //Upsample2D - simply upscaling Tensor by scale factor. 26 /* 27 1, 2 28 3, 4 29 converts to 30 1,1,2,2 31 1,1,2,2 32 3,3,4,4, 33 3,3,4,4, 34 */ 35 func Upsample2D(x *Node, scale int) (*Node, error) { 36 if scale < 1 { 37 return nil, errors.Errorf("Upsample scale %v does not make sense", scale) 38 } 39 xShape := x.Shape() 40 op := newUpsampleOp(xShape, scale-1) 41 retVal, err := ApplyOp(op, x) 42 return retVal, err 43 } 44 45 func (op *upsampleOp) Arity() int { 46 47 return 1 48 } 49 func (op *upsampleOp) ReturnsPtr() bool { return false } 50 51 func (op *upsampleOp) CallsExtern() bool { return false } 52 53 func (op *upsampleOp) WriteHash(h hash.Hash) { 54 fmt.Fprintf(h, "Upsample{}(stride: (%d))", op.stride) 55 } 56 func (op *upsampleOp) Hashcode() uint32 { 57 h := fnv.New32a() 58 op.WriteHash(h) 59 return h.Sum32() 60 } 61 62 func (op *upsampleOp) String() string { 63 return fmt.Sprintf("Upsample{}(stride: (%d))", op.stride) 64 } 65 func (op *upsampleOp) InferShape(inputs ...DimSizer) (tensor.Shape, error) { 66 s := inputs[0].(tensor.Shape).Clone() 67 s[2] = s[2] * (1 + op.stride) 68 s[3] = s[3] * (1 + op.stride) 69 return s, nil 70 } 71 func (op *upsampleOp) Type() hm.Type { 72 a := hm.TypeVariable('a') 73 t := TensorType{Dims: 4, Of: a} 74 return hm.NewFnType(t, t) 75 } 76 func (op *upsampleOp) OverwritesInput() int { return -1 } 77 78 func (op *upsampleOp) checkInput(inputs ...Value) (tensor.Tensor, error) { 79 if err := checkArity(op, len(inputs)); err != nil { 80 return nil, err 81 } 82 var in tensor.Tensor 83 var ok bool 84 if in, ok = inputs[0].(tensor.Tensor); !ok { 85 return nil, errors.Errorf("Expected input to be a tensor") 86 } 87 88 if in.Shape().Dims() != 4 { 89 return nil, errors.Errorf("Expected input to have 4 dimensions") 90 } 91 return in, nil 92 } 93 94 func (op *upsampleOp) Do(inputs ...Value) (retVal Value, err error) { 95 var in tensor.Tensor 96 if in, err = op.checkInput(inputs...); err != nil { 97 return nil, err 98 } 99 inShp := in.Shape() 100 b, c, h, w := inShp[0], inShp[1], inShp[2], inShp[3] 101 102 out := tensor.New(tensor.Of(in.Dtype()), tensor.WithShape(b, c, h*(1+op.stride), w*(1+op.stride)), tensor.WithEngine(in.Engine())) 103 for bi := 0; bi < b; bi++ { 104 for ci := 0; ci < c; ci++ { 105 for hi := 0; hi < h; hi++ { 106 for wi := 0; wi < w; wi++ { 107 val, err := in.At(bi, ci, hi, wi) 108 if err != nil { 109 return nil, errors.Errorf("Error accessing input data at [%v, %v, %v, %v]", bi, ci, hi, wi) 110 } 111 hout := hi * (op.stride + 1) 112 wout := wi * (op.stride + 1) 113 for shi := 0; shi <= op.stride; shi++ { 114 for swi := 0; swi <= op.stride; swi++ { 115 out.SetAt(val, bi, ci, hout+shi, wout+swi) 116 } 117 } 118 } 119 } 120 } 121 } 122 123 return out, nil 124 } 125 126 func (op *upsampleOp) DiffWRT(inputs int) []bool { return []bool{true} } 127 128 func (op *upsampleOp) SymDiff(inputs Nodes, output, grad *Node) (retVal Nodes, err error) { 129 if err = checkArity(op, len(inputs)); err != nil { 130 return 131 } 132 input := inputs[0] 133 134 var op2 upsampleOp 135 op2 = *op 136 diff := &upsampleDiffOp{op2} 137 138 var ret *Node 139 if ret, err = ApplyOp(diff, input, output, grad); err != nil { 140 return nil, err 141 } 142 return Nodes{ret}, nil 143 } 144 145 type upsampleDiffOp struct { 146 upsampleOp 147 } 148 149 func (op *upsampleDiffOp) Arity() int { return 3 } 150 151 func (op *upsampleDiffOp) Type() hm.Type { 152 a := hm.TypeVariable('a') 153 t := TensorType{Dims: 4, Of: a} 154 return hm.NewFnType(t, t, t, t) 155 } 156 157 func (op *upsampleDiffOp) InferShape(inputs ...DimSizer) (tensor.Shape, error) { 158 return inputs[0].(tensor.Shape).Clone(), nil 159 } 160 161 func (op *upsampleDiffOp) checkInput(inputs ...Value) (in, pooled, pooledGrad tensor.Tensor, err error) { 162 if err = checkArity(op, len(inputs)); err != nil { 163 return 164 } 165 166 var ok bool 167 if in, ok = inputs[0].(tensor.Tensor); !ok { 168 err = errors.Errorf("Expected input to be a tensor") 169 return 170 } 171 if in.Shape().Dims() != 4 { 172 err = errors.Errorf("Expected input to have 4 dimensions") 173 return 174 } 175 176 if pooled, ok = inputs[1].(tensor.Tensor); !ok { 177 err = errors.Errorf("Expected pooled to be a tensor") 178 return 179 } 180 181 if pooledGrad, ok = inputs[2].(tensor.Tensor); !ok { 182 err = errors.Errorf("Expected pooledGrad to be a tensor") 183 return 184 } 185 return 186 } 187 188 func (op *upsampleDiffOp) Do(inputs ...Value) (retVal Value, err error) { 189 var gradIn tensor.Tensor 190 in, pooled, pooledGrad, err := op.checkInput(inputs...) 191 if err != nil { 192 return nil, err 193 } 194 insh := in.Shape() 195 gradIn = tensor.New(tensor.Of(in.Dtype()), tensor.WithShape(in.Shape().Clone()...), tensor.WithEngine(in.Engine())) 196 b, c, h, w := insh[0], insh[1], insh[2], insh[3] 197 for bi := 0; bi < b; bi++ { 198 for ci := 0; ci < c; ci++ { 199 for hi := 0; hi < h; hi++ { 200 for wi := 0; wi < w; wi++ { 201 summ := 0. 202 for sh := 0; sh <= op.stride; sh++ { 203 for sw := 0; sw <= op.stride; sw++ { 204 val, err := pooledGrad.At(bi, ci, hi*(op.stride+1)+sh, wi*(op.stride+1)+sw) 205 if err != nil { 206 return nil, errors.Errorf("Error accessing input data at [%v, %v, %v, %v]", bi, ci, hi, wi) 207 } 208 if pooled.Dtype() == tensor.Float32 { 209 summ += float64(val.(float32)) 210 } else if pooled.Dtype() == tensor.Float64 { 211 summ += val.(float64) 212 } 213 } 214 } 215 if pooled.Dtype() == tensor.Float32 { 216 gradIn.SetAt(float32(summ), bi, ci, hi, wi) 217 } 218 if pooled.Dtype() == tensor.Float64 { 219 gradIn.SetAt(summ, bi, ci, hi, wi) 220 } 221 } 222 } 223 } 224 } 225 return gradIn, nil 226 }