gorgonia.org/gorgonia@v0.9.17/op_softmax.go (about) 1 package gorgonia 2 3 import ( 4 "fmt" 5 "hash" 6 "math" 7 "runtime" 8 "sync" 9 10 "github.com/chewxy/hm" 11 "github.com/chewxy/math32" 12 "github.com/pkg/errors" 13 "gorgonia.org/tensor" 14 ) 15 16 type softmaxOp struct { 17 shape tensor.Shape 18 axis int 19 isLog bool 20 } 21 22 func newSoftmaxOp(inputShape tensor.Shape, axes ...int) *softmaxOp { 23 axis := -1 24 if len(axes) > 0 { 25 axis = axes[0] 26 } 27 softmaxop := &softmaxOp{ 28 shape: inputShape, 29 axis: axis, 30 } 31 32 return softmaxop 33 } 34 35 func (op *softmaxOp) Arity() int { return 1 } 36 37 func (op *softmaxOp) ReturnsPtr() bool { return false } 38 39 func (op *softmaxOp) CallsExtern() bool { return false } 40 41 func (op *softmaxOp) WriteHash(h hash.Hash) { fmt.Fprintf(h, "Softmax{%v}()", op.axis) } 42 43 func (op *softmaxOp) Hashcode() uint32 { return simpleHash(op) } 44 45 func (op *softmaxOp) String() string { return fmt.Sprintf("Softmax{%d, %v}()", op.axis, op.isLog) } 46 47 func (op *softmaxOp) InferShape(inputs ...DimSizer) (tensor.Shape, error) { 48 return inputs[0].(tensor.Shape), nil 49 } 50 51 func (op *softmaxOp) Type() hm.Type { 52 a := hm.TypeVariable('a') 53 return hm.NewFnType(a, a) // f(float64) float64 54 } 55 56 func (op *softmaxOp) OverwritesInput() int { return -1 } 57 58 func (op *softmaxOp) checkInput(inputs ...Value) (tensor.Tensor, error) { 59 if err := checkArity(op, len(inputs)); err != nil { 60 return nil, err 61 } 62 63 var ( 64 in tensor.Tensor 65 ok bool 66 ) 67 68 if in, ok = inputs[0].(tensor.Tensor); !ok { 69 return nil, errors.Errorf("Expected input to be a tensor") 70 } 71 72 return in, nil 73 } 74 75 func (op *softmaxOp) Do(inputs ...Value) (retVal Value, err error) { 76 inputTensor, err := op.checkInput(inputs...) 77 if err != nil { 78 return nil, fmt.Errorf("Can't check Softmax input: %w", err) 79 } 80 81 aShape := inputTensor.Shape() 82 axis := aShape.Dims() - 1 // default: last dim 83 84 if aShape.IsColVec() || (aShape.IsVector() && !aShape.IsRowVec()) { 85 axis = 0 86 } 87 if op.axis != -1 { 88 axis = op.axis 89 } 90 91 ret := tensor.New(tensor.WithShape(aShape.Clone()...), tensor.Of(inputTensor.Dtype())) 92 data := inputTensor.Data() 93 output := ret.Data() 94 op.do(aShape, axis, data, output) 95 return ret, nil 96 97 } 98 99 func (op *softmaxOp) UsePreallocDo(prealloc Value, inputs ...Value) (Value, error) { 100 inputTensor, err := op.checkInput(inputs...) 101 if err != nil { 102 return nil, fmt.Errorf("Can't check Softmax input: %w", err) 103 } 104 105 aShape := inputTensor.Shape() 106 axis := aShape.Dims() - 1 // default: last dim 107 108 if aShape.IsColVec() || (aShape.IsVector() && !aShape.IsRowVec()) { 109 axis = 0 110 } 111 if op.axis != -1 { 112 axis = op.axis 113 } 114 115 op.do(aShape, axis, inputTensor.Data(), prealloc.Data()) 116 return prealloc, nil 117 } 118 119 // DoDiff calculates the diff and sets its value to the output node. Implementation for ADOp interface. 120 func (op *softmaxOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) error { 121 if len(inputs) != 1 { 122 return fmt.Errorf("SoftmaxOp.DoDiff needs 1 arguments") 123 } 124 125 odv := output.boundTo.(*dualValue) 126 idv := inputs[0].boundTo.(*dualValue) 127 idvd := idv.d.(*tensor.Dense) 128 diffOp := &softmaxDiffOp{op} 129 130 result, err := diffOp.Do(idv.Value, odv.Value, odv.d) 131 if err != nil { 132 return err 133 } 134 135 sum, err := idvd.Add(result.(*tensor.Dense), tensor.UseUnsafe()) 136 if err != nil { 137 return err 138 } 139 140 odv.d = sum 141 142 return nil 143 } 144 145 // SymDiff applies the diff op. Implementation for SDOp interface. 146 func (op *softmaxOp) SymDiff(inputs Nodes, output, grad *Node) (Nodes, error) { 147 err := checkArity(op, len(inputs)) 148 if err != nil { 149 return nil, err 150 } 151 152 diffOp := &softmaxDiffOp{op} 153 nodes := make(Nodes, 1) 154 155 nodes[0], err = ApplyOp(diffOp, inputs[0], output, grad) 156 157 return nodes, err 158 } 159 160 // DiffWRT is an implementation for the SDOp interface 161 func (op *softmaxOp) DiffWRT(inputs int) []bool { 162 if inputs != 1 { 163 panic(fmt.Sprintf("softmax operator only supports one input, got %d instead", inputs)) 164 } 165 166 return []bool{true} 167 } 168 169 func (op *softmaxOp) f64skernel(data, output []float64, inner, ostride, dimSize, dimStride int) { 170 for i := 0; i < len(data); i++ { 171 oi := i / inner 172 ii := i % inner 173 xidx := oi*ostride + ii 174 yidx := oi*ostride + ii 175 176 if xidx >= len(data) { 177 continue 178 } 179 180 if yidx >= len(data) { 181 continue 182 } 183 184 x := data[xidx:] 185 y := output[yidx:] 186 if len(x) == 0 { 187 continue 188 } 189 190 max := x[0] 191 for d := 1; d < dimSize && d*dimStride < len(x); d++ { 192 dm := x[d*dimStride] 193 if dm > max { 194 max = dm 195 } 196 } 197 198 var sum float64 199 for d := 0; d < dimSize && d*dimStride < len(x); d++ { 200 z := math.Exp(x[d*dimStride] - max) 201 if !op.isLog { 202 y[d*dimStride] = z 203 } 204 sum += z 205 } 206 207 if op.isLog { 208 sum = math.Log(sum) 209 } else { 210 sum = 1 / sum 211 } 212 213 // set output 214 for d := 0; d < dimSize && d*dimStride < len(y); d++ { 215 if op.isLog { 216 y[d*dimStride] = x[d*dimStride] - max - sum 217 } else { 218 y[d*dimStride] *= sum 219 } 220 } 221 222 } 223 } 224 225 func (op *softmaxOp) f32skernel(data, output []float32, inner, ostride, dimSize, dimStride int) { 226 for i := 0; i < len(data); i++ { 227 oi := i / inner 228 ii := i % inner 229 xidx := oi*ostride + ii 230 yidx := oi*ostride + ii 231 232 if xidx >= len(data) { 233 continue 234 } 235 236 if yidx >= len(output) { 237 continue 238 } 239 240 x := data[xidx:] 241 y := output[yidx:] 242 if len(x) == 0 { 243 continue 244 } 245 246 max := x[0] 247 for d := 1; d < dimSize && d*dimStride < len(x); d++ { 248 dm := x[d*dimStride] 249 if dm > max { 250 max = dm 251 } 252 } 253 254 var tmp float32 255 for d := 0; d < dimSize && d*dimStride < len(x); d++ { 256 z := math32.Exp(x[d*dimStride] - max) 257 if !op.isLog { 258 y[d*dimStride] = z 259 } 260 tmp += z 261 } 262 263 if op.isLog { 264 tmp = math32.Log(tmp) 265 } else { 266 tmp = 1 / tmp 267 } 268 269 // set output 270 for d := 0; d < dimSize && d*dimStride < len(y); d++ { 271 if op.isLog { 272 y[d*dimStride] = x[d*dimStride] - max - tmp 273 } else { 274 y[d*dimStride] *= tmp 275 } 276 } 277 } 278 } 279 280 // output and data are of the same size 281 func (op *softmaxOp) do(shp tensor.Shape, axis int, input, output interface{}) { 282 threads := runtime.GOMAXPROCS(0) + 1 283 284 dimSize := shp[axis] 285 outer := tensor.ProdInts([]int(shp[:axis])) 286 inner := tensor.ProdInts([]int(shp[axis+1:])) 287 if outer == 0 { 288 outer = 1 289 } 290 if inner == 0 { 291 inner = 1 292 } 293 dimStride := inner 294 ostride := dimSize * dimStride 295 axisSize := shp[axis] 296 297 datalen := shp.TotalSize() 298 blockSize := calcBlocks(datalen, threads) 299 blocks := datalen / blockSize 300 if blockSize == 0 || blockSize < axisSize { 301 blockSize = datalen 302 blocks = 1 303 } 304 305 if blocks < minParallelBlocks { 306 switch data := input.(type) { 307 case []float64: 308 output := output.([]float64) 309 op.f64skernel(data, output, inner, ostride, dimSize, dimStride) 310 case []float32: 311 output := output.([]float32) 312 op.f32skernel(data, output, inner, ostride, dimSize, dimStride) 313 default: 314 panic(fmt.Sprintf("tensor of %T not handled for softmax diff ", data)) 315 } 316 return 317 } 318 319 workers := workersChan() 320 var wg sync.WaitGroup 321 322 for b := 0; b < datalen; b += blockSize { 323 wg.Add(1) 324 switch data := input.(type) { 325 case []float64: 326 output := output.([]float64) 327 end := b + blockSize 328 if end > len(data) { 329 end = len(data) 330 } 331 newdata := data[b:end] 332 newoutput := output[b:end] 333 go func(data, output []float64, dimSize, dimStride int, wg *sync.WaitGroup) { 334 workers <- struct{}{} 335 op.f64skernel(data, output, inner, ostride, dimSize, dimStride) 336 wg.Done() 337 <-workers 338 }(newdata, newoutput, dimSize, dimStride, &wg) 339 case []float32: 340 output := output.([]float32) 341 end := b + blockSize 342 if end > len(data) { 343 end = len(data) 344 } 345 newdata := data[b:end] 346 newoutput := output[b:end] 347 go func(data, output []float32, dimSize, dimStride int, wg *sync.WaitGroup) { 348 workers <- struct{}{} 349 op.f32skernel(data, output, inner, ostride, dimSize, dimStride) 350 wg.Done() 351 <-workers 352 }(newdata, newoutput, dimSize, dimStride, &wg) 353 default: 354 panic(fmt.Sprintf("tensor of %T not handled for softmax diff ", data)) 355 356 } 357 } 358 wg.Wait() 359 } 360 361 type softmaxDiffOp struct { 362 *softmaxOp 363 } 364 365 func (op *softmaxDiffOp) Arity() int { return 3 } 366 367 func (op *softmaxDiffOp) ReturnsPtr() bool { return false } 368 369 func (op *softmaxDiffOp) CallsExtern() bool { return false } 370 371 func (op *softmaxDiffOp) WriteHash(h hash.Hash) { 372 fmt.Fprintf(h, "SoftmaxDiff{%d, %v}()", op.axis, op.isLog) 373 } 374 375 func (op *softmaxDiffOp) Hashcode() uint32 { return simpleHash(op) } 376 377 func (op *softmaxDiffOp) String() string { 378 return fmt.Sprintf("SoftmaxDiff{%d, %v}()", op.axis, op.isLog) 379 } 380 381 func (op *softmaxDiffOp) InferShape(inputs ...DimSizer) (tensor.Shape, error) { 382 s := inputs[0].(tensor.Shape).Clone() 383 384 return s, nil 385 } 386 387 func (op *softmaxDiffOp) Type() hm.Type { 388 a := hm.TypeVariable('a') 389 390 return hm.NewFnType(a, a, a, a) // f(float64) float64 391 } 392 393 func (op *softmaxDiffOp) OverwritesInput() int { return -1 } 394 395 func (op *softmaxDiffOp) checkInput(inputs ...Value) (tensor.Tensor, tensor.Tensor, tensor.Tensor, error) { 396 if err := checkArity(op, len(inputs)); err != nil { 397 return nil, nil, nil, err 398 } 399 400 var ( 401 in tensor.Tensor 402 out tensor.Tensor 403 grad tensor.Tensor 404 ok bool 405 ) 406 407 switch t := inputs[0].(type) { 408 case *dualValue: 409 if in, ok = t.Value.(tensor.Tensor); !ok { 410 return nil, nil, nil, errors.Errorf("input should be a tensor, got %T", inputs[0]) 411 } 412 case tensor.Tensor: 413 in = t 414 default: 415 return nil, nil, nil, errors.Errorf("input type is not supported, got %T", inputs[0]) 416 } 417 418 switch t := inputs[1].(type) { 419 case *dualValue: 420 if out, ok = t.Value.(tensor.Tensor); !ok { 421 return nil, nil, nil, errors.Errorf("output should be a tensor, got %T", inputs[1]) 422 } 423 case tensor.Tensor: 424 out = t 425 default: 426 return nil, nil, nil, errors.Errorf("output type is not supported, got %T", inputs[1]) 427 } 428 429 switch t := inputs[2].(type) { 430 case *dualValue: 431 if grad, ok = t.Value.(tensor.Tensor); !ok { 432 return nil, nil, nil, errors.Errorf("grad should be a tensor, got %T", inputs[1]) 433 } 434 case tensor.Tensor: 435 grad = t 436 default: 437 return nil, nil, nil, errors.Errorf("grad type is not supported, got %T", inputs[1]) 438 } 439 440 return in, out, grad, nil 441 } 442 443 func (op *softmaxDiffOp) Do(inputs ...Value) (Value, error) { 444 x, y, grad, err := op.checkInput(inputs...) 445 if err != nil { 446 return nil, fmt.Errorf("Can't check SoftmaxDiff input: %w", err) 447 } 448 449 s := x.Shape() 450 axis := op.axis 451 if axis == -1 { 452 axis = s.Dims() - 1 453 } 454 455 ret := tensor.New(tensor.WithShape(x.Shape().Clone()...), tensor.Of(x.Dtype())) 456 op.do(x.Shape(), axis, x.Data(), y.Data(), grad.Data(), ret.Data()) 457 return ret, nil 458 459 } 460 461 func (op *softmaxDiffOp) UsePreallocDo(prealloc Value, inputs ...Value) (Value, error) { 462 x, y, grad, err := op.checkInput(inputs...) 463 if err != nil { 464 return nil, fmt.Errorf("Can't check SoftmaxDiff input: %w", err) 465 } 466 467 s := x.Shape() 468 axis := op.axis 469 if axis == -1 { 470 axis = s.Dims() - 1 471 } 472 473 op.do(x.Shape(), axis, x.Data(), y.Data(), grad.Data(), prealloc.Data()) 474 return prealloc, nil 475 } 476 477 func (op *softmaxDiffOp) f64Kernel(input, output, yGrad, xGrad []float64, inner, ostride, dimSize, dimStride int) { 478 for i := 0; i < len(input); i++ { 479 oi := i / inner 480 ii := i % inner 481 idx := oi*ostride + ii 482 if idx >= len(input) { 483 continue 484 } 485 486 if idx >= len(output) { 487 continue 488 } 489 if idx >= len(yGrad) { 490 continue 491 } 492 if idx >= len(xGrad) { 493 continue 494 } 495 496 y := output 497 dy := yGrad 498 dydx := xGrad 499 500 // calculate sum 501 var sum float64 502 for d := 0; d < dimSize; d++ { 503 if op.isLog { 504 sum += dy[d*dimStride] 505 } else { 506 sum += dy[d*dimStride] * y[d*dimStride] 507 } 508 } 509 for d := 0; d < dimSize; d++ { 510 if op.isLog { 511 dydx[d*dimStride] = dy[d*dimStride] - math.Exp(y[d*dimStride])*sum 512 } else { 513 dydx[d*dimStride] = y[d*dimStride] * (dy[d*dimStride] - sum) 514 } 515 } 516 } 517 } 518 519 func (op *softmaxDiffOp) f32Kernel(input, output, yGrad, xGrad []float32, inner, ostride, dimSize, dimStride int) { 520 for i := 0; i < len(input); i++ { 521 oi := i / inner 522 ii := i % inner 523 idx := oi*ostride + ii 524 if idx >= len(input) { 525 continue 526 } 527 528 if idx >= len(output) { 529 continue 530 } 531 if idx >= len(yGrad) { 532 continue 533 } 534 if idx >= len(xGrad) { 535 continue 536 } 537 538 y := output[idx:] 539 dy := yGrad[idx:] 540 dydx := xGrad[idx:] 541 542 // calculate sum 543 var sum float32 544 for d := 0; d < dimSize; d++ { 545 546 if op.isLog { 547 sum += dy[d*dimStride] 548 } else { 549 sum += dy[d*dimStride] * y[d*dimStride] 550 } 551 } 552 for d := 0; d < dimSize; d++ { 553 if op.isLog { 554 dydx[d*dimStride] = dy[d*dimStride] - math32.Exp(y[d*dimStride])*sum 555 } else { 556 dydx[d*dimStride] = y[d*dimStride] * (dy[d*dimStride] - sum) 557 } 558 } 559 560 } 561 } 562 563 func (op *softmaxDiffOp) do(shp tensor.Shape, axis int, x, y, dy, retVal interface{}) { 564 //blocks := runtime.GOMAXPROCS(0) + 1 565 dimSize := shp[axis] 566 outer := tensor.ProdInts([]int(shp[:axis])) 567 inner := tensor.ProdInts([]int(shp[axis+1:])) 568 if outer == 0 { 569 outer = 1 570 } 571 if inner == 0 { 572 inner = 1 573 } 574 dimStride := inner 575 ostride := dimSize * dimStride 576 577 //datalen := shp.TotalSize() 578 579 //if blocks < minParallelBlocks { 580 switch x := x.(type) { 581 case []float64: 582 y := y.([]float64) 583 dy := dy.([]float64) 584 dydx := retVal.([]float64) 585 op.f64Kernel(x, y, dy, dydx, inner, ostride, dimSize, dimStride) 586 case []float32: 587 y := y.([]float32) 588 dy := dy.([]float32) 589 dydx := retVal.([]float32) 590 op.f32Kernel(x, y, dy, dydx, inner, ostride, dimSize, dimStride) 591 default: 592 panic(fmt.Sprintf("tensor of %T not handled for softmax diff ", x)) 593 594 } 595 //} 596 /* 597 workers := workersChan() 598 var wg sync.WaitGroup 599 blockSize := datalen / blocks 600 if blockSize == 0 { 601 blockSize = datalen // 1 block 602 } 603 604 for b := 0; b < datalen; b += blockSize { 605 wg.Add(1) 606 switch xData := x.(type) { 607 case []float64: 608 yData := y.([]float64) 609 dyData := dy.([]float64) 610 dydxData := retVal.([]float64) 611 end := b + blockSize 612 if end > len(xData) { 613 end = len(xData) 614 } 615 newX := xData[b:end] 616 newY := yData[b:end] 617 newDy := dyData[b:end] 618 newDydx := dydxData[b:end] 619 620 go func(x, y, dy, dydx []float64, dimSize, dimStride int, wg *sync.WaitGroup) { 621 workers <- struct{}{} 622 op.f64Kernel(x, y, dy, dydx, inner, ostride, dimSize, dimStride) 623 wg.Done() 624 <-workers 625 }(newX, newY, newDy, newDydx, dimSize, dimStride, &wg) 626 case []float32: 627 yData := y.([]float32) 628 dyData := dy.([]float32) 629 dydxData := retVal.([]float32) 630 end := b + blockSize 631 if end > len(xData) { 632 end = len(xData) 633 } 634 newX := xData[b:end] 635 newY := yData[b:end] 636 newDy := dyData[b:end] 637 newDydx := dydxData[b:end] 638 go func(x, y, dy, dydx []float32, dimSize, dimStride int, wg *sync.WaitGroup) { 639 workers <- struct{}{} 640 op.f32Kernel(x, y, dy, dydx, inner, ostride, dimSize, dimStride) 641 wg.Done() 642 <-workers 643 }(newX, newY, newDy, newDydx, dimSize, dimStride, &wg) 644 default: 645 panic(fmt.Sprintf("tensor of %T not handled for softmax diff ", xData)) 646 647 } 648 } 649 */ 650 } 651 652 // ensure it complies with the Op interface 653 var ( 654 _ Op = &softmaxOp{} 655 _ ADOp = &softmaxOp{} 656 _ SDOp = &softmaxOp{} 657 658 _ Op = &softmaxDiffOp{} 659 )