// github.com/wzzhu/tensor@v0.9.24/defaultengine_softmax.go

package tensor

import (
	"fmt"
	"math"
	"sync"

	"github.com/chewxy/math32"
	"github.com/pkg/errors"
)

// resolveAxis maps a possibly negative axis to its non-negative equivalent.
// For example, with dims = 2 and axis = -1 it returns 1, the last dimension.
func resolveAxis(axis int, dims int) int {
	res := axis % dims
	if (res < 0 && dims > 0) || (res > 0 && dims < 0) {
		return res + dims
	}

	return res
}

// SoftMax performs the softmax operation on the given tensor. Currently it expects the tensor to be a Dense tensor.
// Please make a pull request to support sparse tensors.
//
// The softmax function is defined as:
// σ(x)_i = e^(x_i) / Σ_j e^(x_j)
func (e StdEng) SoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
	axis = resolveAxis(axis, x.Dims())
	expectedShape := x.Shape()

	var reuse DenseTensor
	var safe, toReuse, _ bool
	if reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, x.Dtype(), x.DataOrder(), true, opts...); err != nil {
		return nil, errors.Wrap(err, "Unable to handle funcOpts")
	}
	if safe || !toReuse && reuse == nil && safe {
		// create reuse
		reuse = New(WithShape(expectedShape...), Of(x.Dtype()))
	}

	switch x.Dtype() {
	case Float32:
		if expectedShape.Dims()-1 == axis {
			e.softMaxLastDimF32(reuse, x, axis, false)
		} else {
			e.softMaxInnerDimF32(reuse, x, axis, false)
		}
	case Float64:
		if expectedShape.Dims()-1 == axis {
			e.softMaxLastDimF64(reuse, x, axis, false)
		} else {
			e.softMaxInnerDimF64(reuse, x, axis, false)
		}
	default:
		return nil, fmt.Errorf("type %v not supported", x.Dtype())
	}

	return reuse, nil
}

// SoftMaxB computes the gradient of the input `x`, given the `output = SoftMax(x)` and its associated gradient. Currently it expects the tensor to be a Dense tensor.
// Please make a pull request to support sparse tensors.
func (e StdEng) SoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
	if !output.Shape().Eq(grad.Shape()) {
		return nil, fmt.Errorf("output and grad shapes don't match")
	}

	if !output.Dtype().Eq(grad.Dtype()) {
		return nil, fmt.Errorf("output and grad types don't match")
	}

	axis = resolveAxis(axis, output.Dims())
	expectedShape := output.Shape()

	var reuse DenseTensor
	var safe, toReuse, _ bool
	if reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, output.Dtype(), output.DataOrder(), true, opts...); err != nil {
		return nil, errors.Wrap(err, "Unable to handle funcOpts")
	}
	if safe || !toReuse && reuse == nil && safe {
		// create reuse
		reuse = New(WithShape(expectedShape...), Of(output.Dtype()))
	}

	switch output.Dtype() {
	case Float32:
		if expectedShape.Dims()-1 == axis {
			e.softMaxBLastDimF32(reuse, output, grad, axis, false)
		} else {
			e.softMaxBInnerDimF32(reuse, output, grad, axis, false)
		}
	case Float64:
		if expectedShape.Dims()-1 == axis {
			e.softMaxBLastDimF64(reuse, output, grad, axis, false)
		} else {
			e.softMaxBInnerDimF64(reuse, output, grad, axis, false)
		}
	default:
		return nil, fmt.Errorf("type %v not supported", output.Dtype())
	}

	return reuse, nil
}
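
// A quick sketch of how the forward pass might be called directly on the engine. It is a
// usage sketch only, using identifiers defined in this package (New, WithShape, WithBacking,
// StdEng), and assumes a row-major float64 Dense tensor; axis -1 is resolved to the last
// dimension by resolveAxis.
//
//	t := New(WithShape(2, 3), WithBacking([]float64{1, 2, 3, 4, 5, 6}))
//	out, err := StdEng{}.SoftMax(t, -1)
//	if err != nil {
//		// handle error
//	}
//	// each row of out now sums to 1; row 0 ≈ [0.0900, 0.2447, 0.6652]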

// LogSoftMax performs softmax but in log space. This provides some amount of numerical stabilization.
// Conceptually it is the same as taking the logarithm after applying the softmax function.
// Currently it expects the tensor to be a Dense tensor.
// Please make a pull request to support sparse tensors.
func (e StdEng) LogSoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
	axis = resolveAxis(axis, x.Dims())
	expectedShape := x.Shape()

	var reuse DenseTensor
	var safe, toReuse, _ bool
	if reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, x.Dtype(), x.DataOrder(), true, opts...); err != nil {
		return nil, errors.Wrap(err, "Unable to handle funcOpts")
	}
	if safe || !toReuse && reuse == nil && safe {
		// create reuse
		reuse = New(WithShape(expectedShape...), Of(x.Dtype()))
	}

	switch x.Dtype() {
	case Float32:
		if expectedShape.Dims()-1 == axis {
			e.softMaxLastDimF32(reuse, x, axis, true)
		} else {
			e.softMaxInnerDimF32(reuse, x, axis, true)
		}
	case Float64:
		if expectedShape.Dims()-1 == axis {
			e.softMaxLastDimF64(reuse, x, axis, true)
		} else {
			e.softMaxInnerDimF64(reuse, x, axis, true)
		}
	default:
		return nil, fmt.Errorf("type %v not supported", x.Dtype())
	}

	return reuse, nil
}

// LogSoftMaxB computes the gradient of the input `x`, given the `output = LogSoftMax(x)` and its associated gradient.
// Currently it expects the tensor to be a Dense tensor.
// Please make a pull request to support sparse tensors.
func (e StdEng) LogSoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
	if !output.Shape().Eq(grad.Shape()) {
		return nil, fmt.Errorf("output and grad shapes don't match")
	}

	if !output.Dtype().Eq(grad.Dtype()) {
		return nil, fmt.Errorf("output and grad types don't match")
	}

	axis = resolveAxis(axis, output.Dims())
	expectedShape := output.Shape()

	var reuse DenseTensor
	var safe, toReuse, _ bool
	if reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, output.Dtype(), output.DataOrder(), true, opts...); err != nil {
		return nil, errors.Wrap(err, "Unable to handle funcOpts")
	}
	if safe || !toReuse && reuse == nil && safe {
		// create reuse
		reuse = New(WithShape(expectedShape...), Of(output.Dtype()))
	}

	switch output.Dtype() {
	case Float32:
		if expectedShape.Dims()-1 == axis {
			e.softMaxBLastDimF32(reuse, output, grad, axis, true)
		} else {
			e.softMaxBInnerDimF32(reuse, output, grad, axis, true)
		}
	case Float64:
		if expectedShape.Dims()-1 == axis {
			e.softMaxBLastDimF64(reuse, output, grad, axis, true)
		} else {
			e.softMaxBInnerDimF64(reuse, output, grad, axis, true)
		}
	default:
		return nil, fmt.Errorf("type %v not supported", output.Dtype())
	}

	return reuse, nil
}
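
// A numerical sanity check of the relationship documented above (values rounded): for
// x = [1, 2, 3], the shifted input is z = x - max(x) = [-2, -1, 0] and Σ e^z ≈ 1.5032, so
// LogSoftMax(x) ≈ [-2.4076, -1.4076, -0.4076]; exponentiating gives [0.0900, 0.2447, 0.6652],
// which is SoftMax(x). The kernels below compute the shifted z and the log of the sum
// directly rather than exponentiating and then taking logarithms.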

// softMaxLastDimF64 computes the (log-)softmax of x along its last, contiguous dimension,
// writing the result into output. Each outer slice is processed in its own goroutine.
func (e StdEng) softMaxLastDimF64(output Tensor, x Tensor, axis int, logSoftMax bool) {
	outputArr := getFloat64s(output)
	xArr := getFloat64s(x)

	xShape := x.Shape()

	outerSize := 1
	dimSize := xShape[axis]
	for i := 0; i < axis; i++ {
		outerSize *= xShape[i]
	}

	var wg sync.WaitGroup
	for ii := 0; ii < outerSize; ii++ {
		wg.Add(1)
		go func(ii int, wg *sync.WaitGroup) {
			// find the maximum of this slice; it is subtracted before
			// exponentiation for numerical stability
			maxInput := xArr[ii*dimSize]
			for j := 1; j < dimSize; j++ {
				i := ii*dimSize + j

				if xArr[i] > maxInput {
					maxInput = xArr[i]
				}
			}

			sumExp := float64(0.0)
			for j := 0; j < dimSize; j++ {
				i := ii*dimSize + j
				z := xArr[i] - maxInput
				exp := math.Exp(z)

				if logSoftMax {
					outputArr[i] = z
				} else {
					outputArr[i] = exp
				}

				sumExp += exp
			}

			if !logSoftMax {
				sumExp = 1 / sumExp
			}

			for j := 0; j < dimSize; j++ {
				i := ii*dimSize + j

				if logSoftMax {
					outputArr[i] -= math.Log(sumExp)
				} else {
					outputArr[i] *= sumExp
				}
			}
			wg.Done()
		}(ii, &wg)

	}
	wg.Wait()
}

// softMaxBLastDimF64 computes the gradient of (log-)softmax along the last, contiguous
// dimension, given the forward output and the incoming gradient.
func (e StdEng) softMaxBLastDimF64(inputGrad, output, grad Tensor, axis int, logSoftMax bool) {
	dx := getFloat64s(inputGrad)
	outputArr := getFloat64s(output)
	gradArr := getFloat64s(grad)

	outputShape := output.Shape()

	outerSize := 1
	dimSize := outputShape[axis]
	for i := 0; i < axis; i++ {
		outerSize *= outputShape[i]
	}

	var wg sync.WaitGroup
	for ii := 0; ii < outerSize; ii++ {
		wg.Add(1)
		if logSoftMax {
			go func(gradArr, dx []float64, ii int, wg *sync.WaitGroup) {
				sum := gradArr[ii*dimSize]
				for j := 1; j < dimSize; j++ {
					i := ii*dimSize + j

					sum += gradArr[i]
				}

				for j := 0; j < dimSize; j++ {
					i := ii*dimSize + j

					dx[i] = gradArr[i] - (math.Exp(outputArr[i]) * sum)
				}
				wg.Done()
			}(gradArr, dx, ii, &wg)

		} else {
			go func(outputArr, gradArr, dx []float64, ii int, wg *sync.WaitGroup) {
				var sum float64
				for j := 0; j < dimSize; j++ {
					i := ii*dimSize + j

					sum += outputArr[i] * gradArr[i]
				}

				for j := 0; j < dimSize; j++ {
					i := ii*dimSize + j
					dx[i] = (gradArr[i] - sum) * outputArr[i]
				}
				wg.Done()
			}(outputArr, gradArr, dx, ii, &wg)
		}
	}
	wg.Wait()
}

// softMaxInnerDimF64 computes the (log-)softmax of x along a non-last axis, walking the
// reduced dimension with a stride of innerSize.
func (e StdEng) softMaxInnerDimF64(output Tensor, x Tensor, axis int, logSoftmax bool) {
	xShape := x.Shape()

	innerSize, outerSize := 1, 1
	for i := 0; i < axis; i++ {
		outerSize *= xShape[i]
	}

	for i := axis + 1; i < xShape.Dims(); i++ {
		innerSize *= xShape[i]
	}

	dimSize := xShape[axis]
	dimStride := innerSize
	outerStride := dimSize * dimStride

	outputArr := getFloat64s(output)
	xArr := getFloat64s(x)

	var wg sync.WaitGroup
	for ii := 0; ii < innerSize*outerSize; ii++ {
		wg.Add(1)
		go func(ii int, wg *sync.WaitGroup) {
			outerIndex, innerIndex := divmod(ii, innerSize)

			inputPart := xArr[outerIndex*outerStride+innerIndex:]
			outputPart := outputArr[outerIndex*outerStride+innerIndex:]

			maxInput := inputPart[0]
			for j := 1; j < dimSize; j++ {
				i := j * dimStride

				if inputPart[i] > maxInput {
					maxInput = inputPart[i]
				}
			}

			sumExp := 0.0
			for j := 0; j < dimSize; j++ {
				i := j * dimStride

				exp := math.Exp(inputPart[i] - maxInput)

				if !logSoftmax {
					outputPart[i] = exp
				}

				sumExp += exp
			}

			if logSoftmax {
				sumExp = math.Log(sumExp)
			} else {
				sumExp = 1 / sumExp
			}

			for j := 0; j < dimSize; j++ {
				i := j * dimStride

				if logSoftmax {
					outputPart[i] = inputPart[i] - maxInput - sumExp
				} else {
					outputPart[i] *= sumExp
				}
			}
			wg.Done()
		}(ii, &wg)
	}
	wg.Wait()
}
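
// A worked example of the index arithmetic used by the inner-dim kernels above and below:
// for a tensor of shape (2, 3, 4) reduced along axis 1, outerSize = 2, dimSize = 3 and
// innerSize = 4, so dimStride = 4 and outerStride = 12. Iteration ii ranges over the
// 2*4 = 8 (outer, inner) pairs; divmod(ii, 4) recovers outerIndex = ii/4 and
// innerIndex = ii%4, the slice starts at offset outerIndex*12 + innerIndex, and the three
// elements of the reduced dimension sit at offsets +0, +4 and +8 from there.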

// softMaxBInnerDimF64 computes the gradient of (log-)softmax along a non-last axis for
// float64 tensors, walking the reduced dimension with a stride of innerSize.
func (e StdEng) softMaxBInnerDimF64(inputGrad, output, grad Tensor, axis int, logSoftmax bool) {
	dxShape := inputGrad.Shape()

	innerSize, outerSize := 1, 1
	for i := 0; i < axis; i++ {
		outerSize *= dxShape[i]
	}

	for i := axis + 1; i < dxShape.Dims(); i++ {
		innerSize *= dxShape[i]
	}

	dimSize := dxShape[axis]
	dimStride := innerSize
	outerStride := dimSize * dimStride

	dxArr := getFloat64s(inputGrad)
	outputArr := getFloat64s(output)
	gradArr := getFloat64s(grad)

	var wg sync.WaitGroup
	for ii := 0; ii < innerSize*outerSize; ii++ {
		wg.Add(1)
		go func(ii int, wg *sync.WaitGroup) {
			outerIndex, innerIndex := divmod(ii, innerSize)

			gradPart := gradArr[outerIndex*outerStride+innerIndex:]
			dxPart := dxArr[outerIndex*outerStride+innerIndex:]
			outputPart := outputArr[outerIndex*outerStride+innerIndex:]

			sum := 0.0
			for j := 0; j < dimSize; j++ {
				i := j * dimStride

				if logSoftmax {
					sum += gradPart[i]
				} else {
					sum += gradPart[i] * outputPart[i]
				}
			}

			for j := 0; j < dimSize; j++ {
				i := j * dimStride

				if logSoftmax {
					dxPart[i] = gradPart[i] - math.Exp(outputPart[i])*sum
				} else {
					dxPart[i] = outputPart[i] * (gradPart[i] - sum)
				}
			}
			wg.Done()
		}(ii, &wg)

	}
	wg.Wait()
}

// softMaxLastDimF32 computes the (log-)softmax of x along its last, contiguous dimension,
// writing the result into output. Each outer slice is processed in its own goroutine.
func (e StdEng) softMaxLastDimF32(output Tensor, x Tensor, axis int, logSoftMax bool) {
	outputArr := getFloat32s(output)
	xArr := getFloat32s(x)
	xShape := x.Shape()

	outerSize := 1
	dimSize := xShape[axis]
	for i := 0; i < axis; i++ {
		outerSize *= xShape[i]
	}

	var wg sync.WaitGroup
	for ii := 0; ii < outerSize; ii++ {
		wg.Add(1)
		go func(ii int, wg *sync.WaitGroup) {
			// find the maximum of this slice; it is subtracted before
			// exponentiation for numerical stability
			maxInput := xArr[ii*dimSize]
			for j := 1; j < dimSize; j++ {
				i := ii*dimSize + j

				if xArr[i] > maxInput {
					maxInput = xArr[i]
				}
			}

			sumExp := float32(0.0)
			for j := 0; j < dimSize; j++ {
				i := ii*dimSize + j
				z := xArr[i] - maxInput
				exp := math32.Exp(z)

				if logSoftMax {
					outputArr[i] = z
				} else {
					outputArr[i] = exp
				}

				sumExp += exp
			}

			if !logSoftMax {
				sumExp = 1 / sumExp
			}

			for j := 0; j < dimSize; j++ {
				i := ii*dimSize + j

				if logSoftMax {
					outputArr[i] -= math32.Log(sumExp)
				} else {
					outputArr[i] *= sumExp
				}
			}
			wg.Done()
		}(ii, &wg)
	}
	wg.Wait()
}

// softMaxBLastDimF32 computes the gradient of (log-)softmax along the last, contiguous
// dimension, given the forward output and the incoming gradient.
func (e StdEng) softMaxBLastDimF32(inputGrad, output, grad Tensor, axis int, logSoftMax bool) {
	dx := getFloat32s(inputGrad)
	outputArr := getFloat32s(output)
	gradArr := getFloat32s(grad)

	outputShape := output.Shape()

	outerSize := 1
	dimSize := outputShape[axis]
	for i := 0; i < axis; i++ {
		outerSize *= outputShape[i]
	}

	var wg sync.WaitGroup
	for ii := 0; ii < outerSize; ii++ {
		wg.Add(1)

		if logSoftMax {
			go func(ii int, wg *sync.WaitGroup) {
				sum := gradArr[ii*dimSize]
				for j := 1; j < dimSize; j++ {
					i := ii*dimSize + j

					sum += gradArr[i]
				}

				for j := 0; j < dimSize; j++ {
					i := ii*dimSize + j

					dx[i] = gradArr[i] - (math32.Exp(outputArr[i]) * sum)
				}
				wg.Done()
			}(ii, &wg)
		} else {
			go func(ii int, wg *sync.WaitGroup) {
				var sum float32
				for j := 0; j < dimSize; j++ {
					i := ii*dimSize + j

					sum += outputArr[i] * gradArr[i]
				}

				for j := 0; j < dimSize; j++ {
					i := ii*dimSize + j

					dx[i] = (gradArr[i] - sum) * outputArr[i]
				}
				wg.Done()
			}(ii, &wg)
		}
	}
	wg.Wait()
}
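
// The backward kernels in this file all implement the same two identities, written here
// with y = output along the reduced dimension and dy = grad:
//
//	softmax:     ∂y_k/∂x_j = y_k*(δ_kj - y_j),  so  dx_j = y_j * (dy_j - Σ_k dy_k*y_k)
//	log-softmax: ∂y_k/∂x_j = δ_kj - e^(y_j),    so  dx_j = dy_j - e^(y_j) * Σ_k dy_k
//
// which is why the non-log branches accumulate Σ dy*y while the log branches accumulate Σ dy.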

// softMaxInnerDimF32 computes the (log-)softmax of x along a non-last axis, walking the
// reduced dimension with a stride of innerSize.
func (e StdEng) softMaxInnerDimF32(output Tensor, x Tensor, axis int, logSoftmax bool) {
	xShape := x.Shape()

	innerSize, outerSize := 1, 1
	for i := 0; i < axis; i++ {
		outerSize *= xShape[i]
	}

	for i := axis + 1; i < xShape.Dims(); i++ {
		innerSize *= xShape[i]
	}

	dimSize := xShape[axis]
	dimStride := innerSize
	outerStride := dimSize * dimStride

	outputArr := getFloat32s(output)
	xArr := getFloat32s(x)

	var wg sync.WaitGroup
	for ii := 0; ii < innerSize*outerSize; ii++ {
		wg.Add(1)

		go func(ii int, wg *sync.WaitGroup) {
			outerIndex, innerIndex := divmod(ii, innerSize)

			inputPart := xArr[outerIndex*outerStride+innerIndex:]
			outputPart := outputArr[outerIndex*outerStride+innerIndex:]

			maxInput := inputPart[0]
			for j := 1; j < dimSize; j++ {
				i := j * dimStride

				if inputPart[i] > maxInput {
					maxInput = inputPart[i]
				}
			}

			sumExp := float32(0.0)
			for j := 0; j < dimSize; j++ {
				i := j * dimStride

				exp := math32.Exp(inputPart[i] - maxInput)

				if !logSoftmax {
					outputPart[i] = exp
				}

				sumExp += exp
			}

			if logSoftmax {
				sumExp = math32.Log(sumExp)
			} else {
				sumExp = 1 / sumExp
			}

			for j := 0; j < dimSize; j++ {
				i := j * dimStride

				if logSoftmax {
					outputPart[i] = inputPart[i] - maxInput - sumExp
				} else {
					outputPart[i] *= sumExp
				}
			}
			wg.Done()
		}(ii, &wg)
	}
	wg.Wait()
}

// softMaxBInnerDimF32 computes the gradient of (log-)softmax along a non-last axis for
// float32 tensors, walking the reduced dimension with a stride of innerSize.
func (e StdEng) softMaxBInnerDimF32(inputGrad, output, grad Tensor, axis int, logSoftmax bool) {
	dxShape := inputGrad.Shape()

	innerSize, outerSize := 1, 1
	for i := 0; i < axis; i++ {
		outerSize *= dxShape[i]
	}

	for i := axis + 1; i < dxShape.Dims(); i++ {
		innerSize *= dxShape[i]
	}

	dimSize := dxShape[axis]
	dimStride := innerSize
	outerStride := dimSize * dimStride

	dxArr := getFloat32s(inputGrad)
	outputArr := getFloat32s(output)
	gradArr := getFloat32s(grad)

	var wg sync.WaitGroup
	for ii := 0; ii < innerSize*outerSize; ii++ {
		wg.Add(1)

		go func(ii int, wg *sync.WaitGroup) {
			outerIndex, innerIndex := divmod(ii, innerSize)

			gradPart := gradArr[outerIndex*outerStride+innerIndex:]
			dxPart := dxArr[outerIndex*outerStride+innerIndex:]
			outputPart := outputArr[outerIndex*outerStride+innerIndex:]

			sum := float32(0.0)
			for j := 0; j < dimSize; j++ {
				i := j * dimStride

				if logSoftmax {
					sum += gradPart[i]
				} else {
					sum += gradPart[i] * outputPart[i]
				}
			}

			for j := 0; j < dimSize; j++ {
				i := j * dimStride

				if logSoftmax {
					dxPart[i] = gradPart[i] - math32.Exp(outputPart[i])*sum
				} else {
					dxPart[i] = outputPart[i] * (gradPart[i] - sum)
				}
			}
			wg.Done()
		}(ii, &wg)
	}
	wg.Wait()
}
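
// A small numeric check of the softmax backward identity (values rounded): with
// y = SoftMax([1, 2, 3]) ≈ [0.0900, 0.2447, 0.6652] and upstream gradient dy = [1, 0, 0],
// Σ dy*y ≈ 0.0900, so dx ≈ [0.0900*(1-0.0900), 0.2447*(0-0.0900), 0.6652*(0-0.0900)]
// ≈ [0.0819, -0.0220, -0.0599]. As expected for a softmax gradient, the components of dx
// sum to approximately zero.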