gorgonia.org/gorgonia@v0.9.17/op_upsample.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"fmt"
     5  	"hash"
     6  	"hash/fnv"
     7  
     8  	"gorgonia.org/tensor"
     9  
    10  	"github.com/chewxy/hm"
    11  	"github.com/pkg/errors"
    12  )
    13  
// upsampleOp implements 2D upsampling by element replication (see Upsample2D).
// stride is the scale factor minus one, i.e. the number of *extra* copies of
// each element along each spatial axis.
type upsampleOp struct {
	stride int
}
    17  
    18  func newUpsampleOp(inputShape tensor.Shape, stride int) *upsampleOp {
    19  	upsampleop := &upsampleOp{
    20  		stride: stride,
    21  	}
    22  	return upsampleop
    23  }
    24  
    25  //Upsample2D -  simply upscaling Tensor by scale factor.
    26  /*
    27  	1, 2
    28  	3, 4
    29  	converts to
    30  	1,1,2,2
    31  	1,1,2,2
    32  	3,3,4,4,
    33  	3,3,4,4,
    34  */
    35  func Upsample2D(x *Node, scale int) (*Node, error) {
    36  	if scale < 1 {
    37  		return nil, errors.Errorf("Upsample scale %v does not make sense", scale)
    38  	}
    39  	xShape := x.Shape()
    40  	op := newUpsampleOp(xShape, scale-1)
    41  	retVal, err := ApplyOp(op, x)
    42  	return retVal, err
    43  }
    44  
// Arity returns 1: the op takes a single input tensor.
func (op *upsampleOp) Arity() int {

	return 1
}
// ReturnsPtr is false: Do allocates and returns a fresh tensor rather than a view of its input.
func (op *upsampleOp) ReturnsPtr() bool { return false }
    50  
// CallsExtern is false: the op is implemented in pure Go and calls no external code.
func (op *upsampleOp) CallsExtern() bool { return false }
    52  
    53  func (op *upsampleOp) WriteHash(h hash.Hash) {
    54  	fmt.Fprintf(h, "Upsample{}(stride: (%d))", op.stride)
    55  }
    56  func (op *upsampleOp) Hashcode() uint32 {
    57  	h := fnv.New32a()
    58  	op.WriteHash(h)
    59  	return h.Sum32()
    60  }
    61  
    62  func (op *upsampleOp) String() string {
    63  	return fmt.Sprintf("Upsample{}(stride: (%d))", op.stride)
    64  }
    65  func (op *upsampleOp) InferShape(inputs ...DimSizer) (tensor.Shape, error) {
    66  	s := inputs[0].(tensor.Shape).Clone()
    67  	s[2] = s[2] * (1 + op.stride)
    68  	s[3] = s[3] * (1 + op.stride)
    69  	return s, nil
    70  }
    71  func (op *upsampleOp) Type() hm.Type {
    72  	a := hm.TypeVariable('a')
    73  	t := TensorType{Dims: 4, Of: a}
    74  	return hm.NewFnType(t, t)
    75  }
// OverwritesInput returns -1: the op does not overwrite any of its inputs.
func (op *upsampleOp) OverwritesInput() int { return -1 }
    77  
    78  func (op *upsampleOp) checkInput(inputs ...Value) (tensor.Tensor, error) {
    79  	if err := checkArity(op, len(inputs)); err != nil {
    80  		return nil, err
    81  	}
    82  	var in tensor.Tensor
    83  	var ok bool
    84  	if in, ok = inputs[0].(tensor.Tensor); !ok {
    85  		return nil, errors.Errorf("Expected input to be a tensor")
    86  	}
    87  
    88  	if in.Shape().Dims() != 4 {
    89  		return nil, errors.Errorf("Expected input to have 4 dimensions")
    90  	}
    91  	return in, nil
    92  }
    93  
    94  func (op *upsampleOp) Do(inputs ...Value) (retVal Value, err error) {
    95  	var in tensor.Tensor
    96  	if in, err = op.checkInput(inputs...); err != nil {
    97  		return nil, err
    98  	}
    99  	inShp := in.Shape()
   100  	b, c, h, w := inShp[0], inShp[1], inShp[2], inShp[3]
   101  
   102  	out := tensor.New(tensor.Of(in.Dtype()), tensor.WithShape(b, c, h*(1+op.stride), w*(1+op.stride)), tensor.WithEngine(in.Engine()))
   103  	for bi := 0; bi < b; bi++ {
   104  		for ci := 0; ci < c; ci++ {
   105  			for hi := 0; hi < h; hi++ {
   106  				for wi := 0; wi < w; wi++ {
   107  					val, err := in.At(bi, ci, hi, wi)
   108  					if err != nil {
   109  						return nil, errors.Errorf("Error accessing input data at [%v, %v, %v, %v]", bi, ci, hi, wi)
   110  					}
   111  					hout := hi * (op.stride + 1)
   112  					wout := wi * (op.stride + 1)
   113  					for shi := 0; shi <= op.stride; shi++ {
   114  						for swi := 0; swi <= op.stride; swi++ {
   115  							out.SetAt(val, bi, ci, hout+shi, wout+swi)
   116  						}
   117  					}
   118  				}
   119  			}
   120  		}
   121  	}
   122  
   123  	return out, nil
   124  }
   125  
// DiffWRT reports that the op is differentiable with respect to its single input.
func (op *upsampleOp) DiffWRT(inputs int) []bool { return []bool{true} }
   127  
   128  func (op *upsampleOp) SymDiff(inputs Nodes, output, grad *Node) (retVal Nodes, err error) {
   129  	if err = checkArity(op, len(inputs)); err != nil {
   130  		return
   131  	}
   132  	input := inputs[0]
   133  
   134  	var op2 upsampleOp
   135  	op2 = *op
   136  	diff := &upsampleDiffOp{op2}
   137  
   138  	var ret *Node
   139  	if ret, err = ApplyOp(diff, input, output, grad); err != nil {
   140  		return nil, err
   141  	}
   142  	return Nodes{ret}, nil
   143  }
   144  
// upsampleDiffOp is the gradient op of upsampleOp. It embeds upsampleOp to
// reuse its stride field and the hashing/printing behaviour.
type upsampleDiffOp struct {
	upsampleOp
}
   148  
// Arity returns 3: the diff op takes (input, output, outputGrad).
func (op *upsampleDiffOp) Arity() int { return 3 }
   150  
   151  func (op *upsampleDiffOp) Type() hm.Type {
   152  	a := hm.TypeVariable('a')
   153  	t := TensorType{Dims: 4, Of: a}
   154  	return hm.NewFnType(t, t, t, t)
   155  }
   156  
   157  func (op *upsampleDiffOp) InferShape(inputs ...DimSizer) (tensor.Shape, error) {
   158  	return inputs[0].(tensor.Shape).Clone(), nil
   159  }
   160  
// checkInput validates arity and unpacks the three inputs as tensors:
// in (the forward input, which must be 4D), pooled (the forward output), and
// pooledGrad (the gradient flowing into the forward output). Note that only
// in's rank is checked; pooled/pooledGrad shape errors surface later in Do.
func (op *upsampleDiffOp) checkInput(inputs ...Value) (in, pooled, pooledGrad tensor.Tensor, err error) {
	if err = checkArity(op, len(inputs)); err != nil {
		return
	}

	var ok bool
	if in, ok = inputs[0].(tensor.Tensor); !ok {
		err = errors.Errorf("Expected input to be a tensor")
		return
	}
	if in.Shape().Dims() != 4 {
		err = errors.Errorf("Expected input to have 4 dimensions")
		return
	}

	if pooled, ok = inputs[1].(tensor.Tensor); !ok {
		err = errors.Errorf("Expected pooled to be a tensor")
		return
	}

	if pooledGrad, ok = inputs[2].(tensor.Tensor); !ok {
		err = errors.Errorf("Expected pooledGrad to be a tensor")
		return
	}
	return
}
   187  
   188  func (op *upsampleDiffOp) Do(inputs ...Value) (retVal Value, err error) {
   189  	var gradIn tensor.Tensor
   190  	in, pooled, pooledGrad, err := op.checkInput(inputs...)
   191  	if err != nil {
   192  		return nil, err
   193  	}
   194  	insh := in.Shape()
   195  	gradIn = tensor.New(tensor.Of(in.Dtype()), tensor.WithShape(in.Shape().Clone()...), tensor.WithEngine(in.Engine()))
   196  	b, c, h, w := insh[0], insh[1], insh[2], insh[3]
   197  	for bi := 0; bi < b; bi++ {
   198  		for ci := 0; ci < c; ci++ {
   199  			for hi := 0; hi < h; hi++ {
   200  				for wi := 0; wi < w; wi++ {
   201  					summ := 0.
   202  					for sh := 0; sh <= op.stride; sh++ {
   203  						for sw := 0; sw <= op.stride; sw++ {
   204  							val, err := pooledGrad.At(bi, ci, hi*(op.stride+1)+sh, wi*(op.stride+1)+sw)
   205  							if err != nil {
   206  								return nil, errors.Errorf("Error accessing input data at [%v, %v, %v, %v]", bi, ci, hi, wi)
   207  							}
   208  							if pooled.Dtype() == tensor.Float32 {
   209  								summ += float64(val.(float32))
   210  							} else if pooled.Dtype() == tensor.Float64 {
   211  								summ += val.(float64)
   212  							}
   213  						}
   214  					}
   215  					if pooled.Dtype() == tensor.Float32 {
   216  						gradIn.SetAt(float32(summ), bi, ci, hi, wi)
   217  					}
   218  					if pooled.Dtype() == tensor.Float64 {
   219  						gradIn.SetAt(summ, bi, ci, hi, wi)
   220  					}
   221  				}
   222  			}
   223  		}
   224  	}
   225  	return gradIn, nil
   226  }