github.com/gorgonia/agogo@v0.1.1/dualnet/ermahagerdmonards.go

package dual

import (
	"fmt"

	"github.com/pkg/errors"
	G "gorgonia.org/gorgonia"
	nnops "gorgonia.org/gorgonia/ops/nn"
	"gorgonia.org/tensor"
)

// maebe is a tiny error "monad": it records the first error encountered while
// building the graph, and every subsequent operation becomes a no-op until the
// error is inspected.
type maebe struct {
	err error
}

// batchNormOp is the part of the batch normalization operator's behaviour that
// the dual network needs: switching between training and testing modes, and
// resetting the accumulated statistics.
type batchNormOp interface {
	SetTraining()
	SetTesting()
	Reset() error
}

// generic monad... may be useful
// do runs f and records (with stack) any error it returns; once an error has
// been recorded, do returns nil without calling f.
func (m *maebe) do(f func() (*G.Node, error)) (retVal *G.Node) {
	if m.err != nil {
		return nil
	}
	if retVal, m.err = f(); m.err != nil {
		m.err = errors.WithStack(m.err)
	}
	return
}
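
// Illustrative sketch, not part of the original file: how the maebe monad is
// threaded through graph construction. Once a step fails, the recorded error
// makes every later step a no-op, so the error is checked once at the end.
// The function name and the Tanh activation are arbitrary choices.
func buildExample(x, y *G.Node) (*G.Node, error) {
	var m maebe
	sum := m.do(func() (*G.Node, error) { return G.Add(x, y) })
	out := m.do(func() (*G.Node, error) { return G.Tanh(sum) })
	if m.err != nil {
		return nil, m.err
	}
	return out, nil
}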

// conv builds a convolution layer: a Glorot-initialized
// filterCount×featureCount×size×size filter applied to the input (assumed to
// be in NCHW layout) with stride 1, "same" padding and no dilation.
func (m *maebe) conv(input *G.Node, filterCount, size int, name string) (retVal *G.Node) {
	if m.err != nil {
		return nil
	}
	featureCount := input.Shape()[1]
	padding := findPadding(input.Shape()[2], input.Shape()[3], size, size)
	filter := G.NewTensor(input.Graph(), Float, 4, G.WithShape(filterCount, featureCount, size, size), G.WithName("Filter"+name), G.WithInit(G.GlorotU(1.0)))

	// assume well behaved images
	if retVal, m.err = nnops.Conv2d(input, filter, []int{size, size}, padding, []int{1, 1}, []int{1, 1}); m.err != nil {
		m.err = errors.WithStack(m.err)
	}
	return
}

// batchnorm applies batch normalization (momentum 0.997, epsilon 1e-5) to the
// input. The batch norm operator is returned alongside the node so callers can
// later switch it between training and testing modes.
func (m *maebe) batchnorm(input *G.Node) (retVal *G.Node, retOp batchNormOp) {
	if m.err != nil {
		return nil, nil
	}
	// note: the scale and biases will still be created
	// and they will still be backpropagated
	if retVal, _, _, retOp, m.err = nnops.BatchNorm(input, nil, nil, 0.997, 1e-5); m.err != nil {
		m.err = errors.WithStack(m.err)
	}
	return
}
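
// Illustrative sketch, not part of the original file: the batchNormOps returned
// while building the network can later be flipped to testing mode in one sweep,
// e.g. before inference. The function name is hypothetical.
func setTestingExample(ops []batchNormOp) {
	for _, op := range ops {
		op.SetTesting() // use accumulated statistics rather than per-batch statistics
	}
}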

// res builds a single conv → batchnorm → ReLU unit with 3×3 filters.
func (m *maebe) res(input *G.Node, filterCount int, name string) (*G.Node, batchNormOp) {
	convolved := m.conv(input, filterCount, 3, name)
	normalized, op := m.batchnorm(convolved)
	retVal := m.rectify(normalized)
	return retVal, op
}

// share builds one shared layer of the network: two conv-batchnorm-ReLU
// branches, both reading the same input, whose outputs are summed and passed
// through a final ReLU. Both batch norm operators are returned.
func (m *maebe) share(input *G.Node, filterCount, layer int) (*G.Node, batchNormOp, batchNormOp) {
	layer1, l1Op := m.res(input, filterCount, fmt.Sprintf("Layer1 of Shared Layer %d", layer))
	layer2, l2Op := m.res(input, filterCount, fmt.Sprintf("Layer2 of Shared Layer %d", layer))
	added := m.do(func() (*G.Node, error) { return G.Add(layer1, layer2) })
	retVal := m.rectify(added)
	return retVal, l1Op, l2Op
}
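
// Illustrative sketch, not part of the original file: how share could be
// stacked into a tower of shared layers, collecting the batch norm operators
// along the way. The function name and loop structure are hypothetical.
func buildTowerExample(m *maebe, input *G.Node, filterCount, layers int) (*G.Node, []batchNormOp) {
	var ops []batchNormOp
	out := input
	for i := 0; i < layers; i++ {
		var op1, op2 batchNormOp
		out, op1, op2 = m.share(out, filterCount, i)
		ops = append(ops, op1, op2)
	}
	return out, ops
}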

// linear builds a fully connected layer xw+b, where w is Glorot-initialized and
// b is zero-initialized with the same shape as the product xw.
func (m *maebe) linear(input *G.Node, units int, name string) *G.Node {
	if m.err != nil {
		return nil
	}
	// figure out size
	w := G.NewTensor(input.Graph(), Float, 2, G.WithShape(input.Shape()[1], units), G.WithInit(G.GlorotN(1.0)), G.WithName(name+"_w"))
	xw := m.do(func() (*G.Node, error) { return G.Mul(input, w) })
	if m.err != nil {
		return nil // Mul failed; xw is nil, so don't dereference it below
	}
	b := G.NewTensor(xw.Graph(), Float, xw.Shape().Dims(), G.WithShape(xw.Shape().Clone()...), G.WithName(name+"_b"), G.WithInit(G.Zeroes()))
	return m.do(func() (*G.Node, error) { return G.Add(xw, b) })
}

// rectify applies the ReLU activation to the input.
func (m *maebe) rectify(input *G.Node) (retVal *G.Node) {
	if m.err != nil {
		return nil
	}
	if retVal, m.err = nnops.Rectify(input); m.err != nil {
		m.err = errors.WithStack(m.err)
	}
	return
}

// reshape reshapes the input node to the given shape.
func (m *maebe) reshape(input *G.Node, to tensor.Shape) (retVal *G.Node) {
	if m.err != nil {
		return nil
	}
	if retVal, m.err = G.Reshape(input, to); m.err != nil {
		m.err = errors.WithStack(m.err)
	}
	return
}
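
// Illustrative sketch, not part of the original file: flattening a 4D tower
// output and projecting it to a fixed-width vector with reshape and linear.
// The function name, the "headExample" layer name and the width parameter are
// hypothetical.
func headExample(m *maebe, tower *G.Node, batch, width int) *G.Node {
	if m.err != nil {
		return nil
	}
	flat := m.reshape(tower, tensor.Shape{batch, tower.Shape().TotalSize() / batch})
	return m.linear(flat, width, "headExample")
}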

// xent computes -mean(target ⊙ output + (1-target) ⊙ (1-output)), where ⊙ is
// the elementwise (Hadamard) product. Unlike the textbook binary cross entropy,
// no logarithm is taken of the outputs.
func (m *maebe) xent(output, target *G.Node) (retVal *G.Node) {
	if m.err != nil {
		return nil
	}
	var one *G.Node
	switch Float {
	case G.Float32:
		one = G.NewConstant(float32(1))
	case G.Float64:
		one = G.NewConstant(float64(1))
	}
	var omy, omout *G.Node
	if omy, m.err = G.Sub(one, target); m.err != nil {
		m.err = errors.WithStack(m.err)
		return nil
	}

	if omout, m.err = G.Sub(one, output); m.err != nil {
		m.err = errors.WithStack(m.err)
		return nil
	}

	var fst, snd *G.Node
	if fst, m.err = G.HadamardProd(target, output); m.err != nil {
		m.err = errors.WithStack(m.err)
		return nil
	}
	if snd, m.err = G.HadamardProd(omy, omout); m.err != nil {
		m.err = errors.WithStack(m.err)
		return nil
	}

	if retVal, m.err = G.Add(fst, snd); m.err != nil {
		m.err = errors.WithStack(m.err)
		return nil
	}
	if retVal, m.err = G.Neg(retVal); m.err != nil {
		m.err = errors.WithStack(m.err)
		return nil
	}
	if retVal, m.err = G.Mean(retVal); m.err != nil {
		m.err = errors.WithStack(m.err)
	}
	return
}
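
// Illustrative sketch, not part of the original file: the same arithmetic that
// xent wires into the graph, spelled out on plain float64 slices for exposition.
// The function name is hypothetical; it assumes output and target have equal length.
func xentByHandExample(output, target []float64) float64 {
	var sum float64
	for i := range output {
		sum += target[i]*output[i] + (1-target[i])*(1-output[i])
	}
	return -sum / float64(len(output))
}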

// findPadding returns the padding that keeps the spatial dimensions unchanged
// under a stride-1 convolution ("same" padding); the input terms cancel, so
// each entry reduces to (kernel-1)/2.
func findPadding(inputX, inputY, kernelX, kernelY int) []int {
	return []int{
		(inputX - 1 - inputX + kernelX) / 2,
		(inputY - 1 - inputY + kernelY) / 2,
	}
}
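
// Illustrative sketch, not part of the original file: checking the "same"
// padding arithmetic for a 19×19 input and a 3×3 kernel. The output size of a
// stride-1 convolution is in + 2*pad - kernel + 1, which stays at 19.
func samePaddingExample() {
	pad := findPadding(19, 19, 3, 3) // {1, 1}
	out := 19 + 2*pad[0] - 3 + 1     // 19: spatial size preserved
	fmt.Println(pad, out)
}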