gorgonia.org/gorgonia@v0.9.17/nn.go

package gorgonia

import (
	"fmt"

	"github.com/pkg/errors"
	"gorgonia.org/gorgonia/internal/encoding"
	"gorgonia.org/tensor"
)

// BinaryXent is a convenience function for doing binary crossentropy stuff.
// The formula is as below:
//	-(y * log(prob) + (1-y) * log(1 - prob))
func BinaryXent(output, target *Node) (retVal *Node, err error) {
	var one *Node
	var logO, omt, omo, tLogO *Node

	// which constant one to use?
	var dt tensor.Dtype
	if dt, err = dtypeOf(output.t); err != nil {
		return nil, errors.Wrapf(err, dtypeExtractionFail, output.t)
	}

	switch dt {
	case Float64:
		one = onef64
	case Float32:
		one = onef32
	default:
		return nil, errors.Errorf(nyiFail, "BinaryXent", dt)
	}

	if logO, err = Log(output); err != nil {
		return nil, errors.Wrap(err, operationError)
	}

	if omt, err = Sub(one, target); err != nil {
		return nil, errors.Wrap(err, operationError)
	}

	if omo, err = Sub(one, output); err != nil {
		return nil, errors.Wrap(err, operationError)
	}

	if tLogO, err = HadamardProd(target, logO); err != nil {
		return nil, errors.Wrap(err, operationError)
	}

	if retVal, err = Log(omo); err != nil {
		return nil, errors.Wrap(err, operationError)
	}

	if retVal, err = HadamardProd(omt, retVal); err != nil {
		return nil, errors.Wrap(err, operationError)
	}

	if retVal, err = Add(tLogO, retVal); err != nil {
		return nil, errors.Wrap(err, operationError)
	}

	return Neg(retVal)
}

// Dropout is a convenience function to implement dropout.
// It randomly zeroes out elements of the input with probability dropProb, using
// a mask drawn from a uniform distribution, and rescales the kept values by
// 1/(1-dropProb) (inverted dropout).
func Dropout(x *Node, dropProb float64) (retVal *Node, err error) {
	return dropout(x, dropProb, UniformRandomNode)
}

type dropoutRandFn func(g *ExprGraph, dt tensor.Dtype, low, high float64, shape ...int) *Node

func dropout(x *Node, dropProb float64, randFn dropoutRandFn) (retVal *Node, err error) {
	if dropProb == 0.0 {
		return x, nil
	}
	keepProb := 1.0 - dropProb

	var dt tensor.Dtype
	if dt, err = dtypeOf(x.t); err != nil {
		return nil, errors.Wrap(err, dtypeOfFail)
	}

	var pr Value
	switch dt {
	case Float64:
		pr, _ = anyToScalar(keepProb)
	case Float32:
		pr, _ = anyToScalar(float32(keepProb))
	default:
		return nil, errors.Errorf(nyiTypeFail, "Dropout()", dt)
	}

	p := NewConstant(pr)

	m := randFn(x.g, dt, 0, 1, x.shape...)
	if retVal, err = Lt(m, p, true); err != nil {
		return nil, errors.Wrap(err, "Less Than failed")
	}

	if retVal, err = HadamardProd(x, retVal); err != nil {
		return nil, errors.Wrap(err, mulFail)
	}

	return HadamardDiv(retVal, p)
}
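// The sketch below is a minimal, illustrative use of Dropout and BinaryXent when
// assembling a graph. dropoutXentSketch and the concrete shapes and drop
// probability are assumptions made for illustration; they are not part of this
// package's API.
func dropoutXentSketch() (loss *Node, err error) {
	g := NewGraph()
	x := NewMatrix(g, Float64, WithShape(32, 100), WithName("x"), WithInit(GlorotN(1.0)))
	y := NewMatrix(g, Float64, WithShape(32, 100), WithName("y"))

	// zero out roughly 50% of the activations during training
	var dropped *Node
	if dropped, err = Dropout(x, 0.5); err != nil {
		return nil, err
	}

	// squash into (0, 1) so the log terms in BinaryXent are well defined
	var prob *Node
	if prob, err = Sigmoid(dropped); err != nil {
		return nil, err
	}
	return BinaryXent(prob, y)
}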
// LeakyRelu returns a node whose underlying value is:
//	f(x) = alpha * x if x < 0
//	f(x) = x for x ⩾ 0
// applied elementwise.
func LeakyRelu(x *Node, alpha float64) (*Node, error) {
	var zero *Node
	var dt tensor.Dtype
	var err error
	var alphaN *Node

	// which zero to use?
	if dt, err = dtypeOf(x.t); err != nil {
		return nil, errors.Wrap(err, dtypeOfFail)
	}
	switch dt {
	case Float64:
		zero = zerof64
		alphaN = NewConstant(alpha)
	case Float32:
		zero = zerof32
		alphaN = NewConstant(float32(alpha))
	default:
		return nil, errors.Errorf(nyiFail, "LeakyRelu", dt)
	}

	gteZeroOp := newElemBinOp(gteOpType, x, zero)
	gteZeroOp.retSame = true

	xGteZeroCmp, err := ApplyOp(gteZeroOp, x, zero)
	if err != nil {
		return nil, errors.Wrap(err, applyOpFail)
	}
	ltZeroOp := newElemBinOp(ltOpType, x, zero)
	ltZeroOp.retSame = true

	xLtZeroCmp, err := ApplyOp(ltZeroOp, x, zero)
	if err != nil {
		return nil, errors.Wrap(err, applyOpFail)
	}
	xGteZero, err := HadamardProd(x, xGteZeroCmp)
	if err != nil {
		return nil, errors.Wrap(err, applyOpFail)
	}
	xLtZero, err := HadamardProd(x, xLtZeroCmp)
	if err != nil {
		return nil, errors.Wrap(err, applyOpFail)
	}
	xLtZeroAlpha, err := HadamardProd(xLtZero, alphaN)
	if err != nil {
		return nil, errors.Wrap(err, applyOpFail)
	}
	return Add(xGteZero, xLtZeroAlpha)
}

// Rectify is a convenience function for creating rectified linear unit activation functions.
// This function uses ⩾, which is the canonical version. If you want to use >, you can create
// your own by just following this.
func Rectify(x *Node) (retVal *Node, err error) {
	var zero *Node
	var dt tensor.Dtype
	group := encoding.NewGroup("Rectify")

	// which zero to use?
	if dt, err = dtypeOf(x.t); err != nil {
		return nil, errors.Wrap(err, dtypeOfFail)
	}
	switch dt {
	case Float64:
		zero = zerof64
	case Float32:
		zero = zerof32
	default:
		return nil, errors.Errorf(nyiFail, "ReLu", dt)
	}

	cmp := newElemBinOp(gteOpType, x, zero)
	cmp.retSame = true

	if retVal, err = ApplyOp(cmp, x, zero); err != nil {
		return nil, errors.Wrap(err, applyOpFail)
	}
	retVal.groups = retVal.groups.Upsert(group)

	return HadamardProd(x, retVal)
}

// Im2Col converts a BCHW image block to columns. The kernel, pad, stride and dilation
// parameters must each be a shape of size 2, no more, no less.
// This poor naming scheme clearly comes from MATLAB.
func Im2Col(n *Node, kernel, pad, stride, dilation tensor.Shape) (retVal *Node, err error) {
	if kernel.Dims() != 2 {
		return nil, errors.Errorf("kernel shape is supposed to have a dim of 2")
	}
	if pad.Dims() != 2 {
		return nil, errors.Errorf("pad is supposed to have a dim of 2")
	}
	if stride.Dims() != 2 {
		return nil, errors.Errorf("strides is supposed to have a dim of 2")
	}
	if dilation.Dims() != 2 {
		return nil, errors.Errorf("dilation is supposed to have a dim of 2")
	}

	if kernel[0] <= 0 || kernel[1] <= 0 {
		return nil, errors.Errorf("cannot have negative or 0 in kernel shape")
	}

	if stride[0] <= 0 || stride[1] <= 0 {
		return nil, errors.Errorf("cannot have negative or 0 in stride: %v", stride)
	}

	if pad[0] < 0 || pad[1] < 0 {
		return nil, errors.Errorf("cannot have negative padding")
	}

	if dilation[0] <= 0 || dilation[1] <= 0 {
		return nil, errors.Errorf("cannot have negative or 0 in dilation: %v", dilation)
	}

	op := makeIm2ColOp(kernel[0], kernel[1], pad[0], pad[1], stride[0], stride[1], dilation[0], dilation[1])
	return ApplyOp(op, n)
}
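// A small sketch of the activation helpers above; activationSketch is a
// hypothetical name used only for illustration and not part of this package's API.
func activationSketch(x *Node) (*Node, error) {
	// ordinary ReLU: max(x, 0), elementwise
	rect, err := Rectify(x)
	if err != nil {
		return nil, err
	}
	// leaky variant: x where x ⩾ 0, 0.01*x otherwise
	leaky, err := LeakyRelu(x, 0.01)
	if err != nil {
		return nil, err
	}
	// combined purely so that both results appear in the sketch
	return Add(rect, leaky)
}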
// Conv2d is a simple 2D convolution, to be used for CPU computation only.
// If CuDNN is used, use the CUDAConv2D function.
// These are the properties the inputs must fulfil:
//
//	- im: must have 4D shape. Expected format is BCHW (batch, channels, height, width)
//	- filter: must have 4D shape: (batch, kernel, height, width)
//	- kernelShape: shape of the filter kernel
//	- pad: len(pad) == 2, defaults to []int{0, 0} if nil is passed
//	- stride: len(stride) == 2, example: []int{1, 1}
//	- dilation: len(dilation) == 2, defaults to []int{1, 1} if nil is passed
func Conv2d(im, filter *Node, kernelShape tensor.Shape, pad, stride, dilation []int) (retVal *Node, err error) {
	group := encoding.NewGroup("Convolution")
	// niceness for defaults
	if pad == nil {
		pad = []int{0, 0}
	}
	if dilation == nil {
		dilation = []int{1, 1}
	}

	if im.Shape().Dims() != 4 {
		return nil, fmt.Errorf("im should have 4 dims, got %v dims", im.Shape().Dims())
	}

	if filter.Shape().Dims() != 4 {
		return nil, fmt.Errorf("filter should have 4 dims, got %v dims", filter.Shape().Dims())
	}

	// checks
	for _, s := range stride {
		if s <= 0 {
			return nil, errors.Errorf("Cannot use strides of less than or equal 0: %v", stride)
		}
	}

	for _, p := range pad {
		if p < 0 {
			return nil, errors.Errorf("Cannot use padding of less than 0: %v", pad)
		}
	}

	for _, d := range dilation {
		if d <= 0 {
			return nil, errors.Errorf("Cannot use dilation of less than or equal 0: %v", dilation)
		}
	}

	var colIm *Node
	if colIm, err = Im2Col(im, kernelShape, pad, stride, dilation); err != nil {
		return nil, fmt.Errorf("Im2Col failed: %w", err)
	}
	colIm.groups = colIm.groups.Upsert(group)

	layer := filter.Shape()[0]
	kernel := filter.Shape()[1]
	row := filter.Shape()[2]
	col := filter.Shape()[3]

	if colIm.Shape()[3] != kernel*row*col {
		return nil, fmt.Errorf("the last dimension of colIm must be %d (kernel) * %d (height) * %d (width) = %d, got %d", kernel, row, col, kernel*row*col, colIm.Shape()[3])
	}

	var flattened *Node
	if flattened, err = Reshape(filter, tensor.Shape{layer, kernel * row * col}); err != nil {
		return nil, fmt.Errorf("reshaping filter from %v to (%v, %v * %v * %v) failed: %w", filter.Shape(), layer, kernel, row, col, err)
	}
	flattened.groups = flattened.groups.Upsert(group)

	// extract patch
	batch := colIm.Shape()[0]
	m := colIm.Shape()[1]
	n := colIm.Shape()[2]
	z := colIm.Shape()[3]

	var patch, colImLayer *Node
	if patch, err = Reshape(colIm, tensor.Shape{batch * m * n, z}); err != nil {
		return nil, fmt.Errorf("reshaping colIm from %v to (%v * %v * %v, %v) failed: %w", colIm.Shape(), batch, m, n, z, err)
	}
	patch.groups = patch.groups.Upsert(group)

	op := linAlgBinOp{
		āBinaryOperator: matMulOperator,
		transA:          false,
		transB:          true,
	}

	if colImLayer, err = ApplyOp(op, patch, flattened); err != nil {
		return nil, fmt.Errorf("failed to apply op: %w", err)
	}
	colImLayer.groups = colImLayer.groups.Upsert(group)

	// now reshape and transpose the values back into the original order
	var res *Node
	if res, err = Reshape(colImLayer, tensor.Shape{batch, m, n, layer}); err != nil {
		return nil, fmt.Errorf("failed to reshape %v to (%v, %v, %v, %v): %w", colImLayer.Shape(), batch, m, n, layer, err)
	}
	res.groups = res.groups.Upsert(group)
	ret, err := Transpose(res, 0, 3, 1, 2)
	if err != nil {
		return nil, fmt.Errorf("transpose %v failed: %w", res.Shape(), err)
	}

	ret.groups = ret.groups.Upsert(group)
	return ret, nil
}
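// The sketch below wires a single Conv2d layer over a BCHW input. All concrete
// sizes (a batch of 8, one input channel, 28×28 images, 16 filters of 3×3) and
// the name conv2dSketch are assumptions made purely for illustration.
func conv2dSketch() (*Node, error) {
	g := NewGraph()
	im := NewTensor(g, Float64, 4, WithShape(8, 1, 28, 28), WithName("im"), WithInit(GlorotN(1.0)))
	// filter laid out as (output channels, input channels, kernel height, kernel width)
	filter := NewTensor(g, Float64, 4, WithShape(16, 1, 3, 3), WithName("w"), WithInit(GlorotN(1.0)))

	// 3×3 convolution with padding of 1 and unit stride and dilation
	return Conv2d(im, filter, tensor.Shape{3, 3}, []int{1, 1}, []int{1, 1}, []int{1, 1})
}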
// Conv1d is a 1D convolution. It relies on Conv2d.
func Conv1d(in, filter *Node, kernel, pad, stride, dilation int) (*Node, error) {
	return Conv2d(in, filter, tensor.Shape{1, kernel}, []int{0, pad}, []int{1, stride}, []int{1, dilation})
}

// MaxPool2D applies the kernel filter to the input node.
// The pad slice can have two different lengths.
//
//	- if len(pad) == 2, padding is assumed to be symmetric, and padding is added both before *and* after each dimension:
//		paddedOutputH = pad[0] + inputH + pad[0]
//		paddedOutputW = pad[1] + inputW + pad[1]
//
//	- if len(pad) == 4, padding is explicit and can be asymmetric:
//		paddedOutputH = pad[0] + inputH + pad[1]
//		paddedOutputW = pad[2] + inputW + pad[3]
func MaxPool2D(x *Node, kernel tensor.Shape, pad, stride []int) (*Node, error) {
	group := encoding.NewGroup("Maxpool")
	xShape := x.Shape()

	// check shape
	if xShape.Dims() != 4 {
		return nil, errors.Errorf("Expected input to have a shape with dimension 4")
	}
	if kernel.Dims() != 2 {
		return nil, errors.Errorf("Expected kernel to have a shape of dimension 2")
	}

	h, w := xShape[2], xShape[3]
	kh, kw := kernel[0], kernel[1]

	// checks
	for _, s := range stride {
		if s <= 0 {
			return nil, errors.Errorf("Cannot use strides of less than or equal 0: %v", stride)
		}
	}

	for _, p := range pad {
		if p < 0 {
			return nil, errors.Errorf("Cannot use padding of less than 0: %v", pad)
		}
	}

	padNorth := pad[0]
	padWest := pad[1]
	padSouth := pad[0]
	padEast := pad[1]
	if len(pad) == 4 {
		padNorth = pad[0]
		padSouth = pad[1]
		padWest = pad[2]
		padEast = pad[3]
	}

	if h-kh+padNorth+padSouth < 0 {
		// error
		return nil, errors.New("Impossible height/kernel/pad combination")
	}

	if w-kw+padWest+padEast < 0 {
		// error
		return nil, errors.New("Impossible width/kernel/pad combination")
	}

	op := newMaxPoolOp(xShape, kernel, pad, stride)
	retVal, err := ApplyOp(op, x)
	if err != nil {
		return nil, err
	}
	retVal.groups = retVal.groups.Upsert(group)
	return retVal, nil
}

// MaxPool1D applies a maxpool on the node x.
func MaxPool1D(x *Node, kernel, pad, stride int) (*Node, error) {
	return MaxPool2D(x, tensor.Shape{1, kernel}, []int{0, pad}, []int{1, stride})
}
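// The sketch below pools the output of a convolution with a 2×2 window and a
// stride of 2, halving the spatial dimensions. convMaxPoolSketch is a
// hypothetical name; the input is assumed to be a BCHW node such as the result
// of Conv2d above.
func convMaxPoolSketch(conv *Node) (*Node, error) {
	// no padding, 2×2 kernel, stride of 2 in both spatial dimensions
	return MaxPool2D(conv, tensor.Shape{2, 2}, []int{0, 0}, []int{2, 2})
}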
// BatchNorm applies batch normalization to the input. This operator can be used
// in a forward-only (inference) pass or for training.
// In an evaluation-only setting, the "op" output can be discarded.
// In the training phase, γ and β can be discarded and the op should be used.
// Input must be a matrix with shape (B, N) or a 4d tensor with shape (B, C, W, H).
func BatchNorm(x, scale, bias *Node, momentum, epsilon float64) (retVal, γ, β *Node, op *BatchNormOp, err error) {
	dt, err := dtypeOf(x.Type())
	if err != nil {
		return nil, nil, nil, nil, err
	}
	batches := x.Shape()[0]
	channels := x.Shape()[1]
	spatialDim := x.Shape().TotalSize() / (channels * batches)

	mean := tensor.New(tensor.Of(dt), tensor.WithShape(channels))
	variance := tensor.New(tensor.Of(dt), tensor.WithShape(channels))
	ma := tensor.New(tensor.Of(dt), tensor.WithShape(1))

	meanTmp := tensor.New(tensor.Of(dt), tensor.WithShape(channels))
	varianceTmp := tensor.New(tensor.Of(dt), tensor.WithShape(channels))
	tmp := tensor.New(tensor.Of(dt), tensor.WithShape(x.Shape().Clone()...))
	xNorm := tensor.New(tensor.Of(dt), tensor.WithShape(x.Shape().Clone()...))
	batchSumMultiplier := tensor.New(tensor.Of(dt), tensor.WithShape(batches))

	var uno interface{}
	switch dt {
	case Float64:
		uno = float64(1)
	case Float32:
		uno = float32(1)
	}
	spatialSumMultiplier := tensor.New(tensor.Of(dt), tensor.WithShape(spatialDim))
	if err = spatialSumMultiplier.Memset(uno); err != nil {
		return nil, nil, nil, nil, err
	}

	numByChans := tensor.New(tensor.Of(dt), tensor.WithShape(channels*batches))
	if err = batchSumMultiplier.Memset(uno); err != nil {
		return nil, nil, nil, nil, err
	}

	op = &BatchNormOp{
		momentum: momentum,
		epsilon:  epsilon,

		mean:     mean,
		variance: variance,
		ma:       ma,

		meanTmp:              meanTmp,
		varianceTmp:          varianceTmp,
		tmpSpace:             tmp,
		xNorm:                xNorm,
		batchSumMultiplier:   batchSumMultiplier,
		numByChans:           numByChans,
		spatialSumMultiplier: spatialSumMultiplier,

		training: true,
		dims:     x.Dims(),
	}
	g := x.Graph()
	dims := x.Shape().Dims()

	if scale == nil {
		scale = NewTensor(g, dt, dims, WithShape(x.Shape().Clone()...), WithName(x.Name()+"_γ"), WithInit(GlorotN(1.0)))
	}
	if bias == nil {
		bias = NewTensor(g, dt, dims, WithShape(x.Shape().Clone()...), WithName(x.Name()+"_β"), WithInit(GlorotN(1.0)))
	}

	if retVal, err = ApplyOp(op, x); err != nil {
		return nil, nil, nil, nil, err
	}
	if retVal, err = Auto(BroadcastHadamardProd, scale, retVal); err != nil {
		return nil, nil, nil, nil, err
	}
	retVal, err = Auto(BroadcastAdd, retVal, bias)

	return retVal, scale, bias, op, err
}

// GlobalAveragePool2D consumes an input tensor X and applies average pooling across the values in the same channel.
// The expected input shape is BCHW where B is the batch size, C is the number of channels, and H and W are the height and the width of the data.
func GlobalAveragePool2D(x *Node) (*Node, error) {
	return ApplyOp(&globalAveragePoolOp{}, x)
}
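// The sketch below shows how BatchNorm and GlobalAveragePool2D might be composed
// at the tail of a small convolutional stack. batchNormPoolSketch and the
// momentum/epsilon values are assumptions made for illustration; scale and bias
// are left nil so that BatchNorm creates its own learnable γ and β.
func batchNormPoolSketch(conv *Node) (*Node, error) {
	normed, _, _, _, err := BatchNorm(conv, nil, nil, 0.9, 1e-5)
	if err != nil {
		return nil, err
	}
	// reduce each (H, W) feature map to a single value per channel
	return GlobalAveragePool2D(normed)
}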