github.com/wzzhu/tensor@v0.9.24/defaultengine_softmax.go (about)

     1  package tensor
     2  
     3  import (
     4  	"fmt"
     5  	"math"
     6  	"sync"
     7  
     8  	"github.com/chewxy/math32"
     9  	"github.com/pkg/errors"
    10  )
    11  
    12  // if dims = 2 and axis -1 it returns the last dimension. In this case 1
    13  func resolveAxis(axis int, dims int) int {
    14  	res := axis % dims
    15  	if (res < 0 && dims > 0) || (res > 0 && dims < 0) {
    16  		return res + dims
    17  	}
    18  
    19  	return res
    20  }
    21  
    22  // SoftMax performs the softmax operation on the given tensor. Currently it expects the tensor to be a Dense tensor.
    23  // Please make a pull request to support sparse tensors.
    24  //
    25  // The softmax function is defined as :
    26  //	σ(x) = e^x_i / Σ(e^x_i)
    27  func (e StdEng) SoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
    28  	axis = resolveAxis(axis, x.Dims())
    29  	expectedShape := x.Shape()
    30  
    31  	var reuse DenseTensor
    32  	var safe, toReuse, _ bool
    33  	if reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, x.Dtype(), x.DataOrder(), true, opts...); err != nil {
    34  		return nil, errors.Wrap(err, "Unable to handle funcOpts")
    35  	}
    36  	if safe || !toReuse && reuse == nil && safe {
    37  		// create reuse
    38  		reuse = New(WithShape(expectedShape...), Of(x.Dtype()))
    39  	}
    40  
    41  	switch x.Dtype() {
    42  	case Float32:
    43  		if expectedShape.Dims()-1 == axis {
    44  			e.softMaxLastDimF32(reuse, x, axis, false)
    45  		} else {
    46  			e.softMaxInnerDimF32(reuse, x, axis, false)
    47  		}
    48  	case Float64:
    49  		if expectedShape.Dims()-1 == axis {
    50  			e.softMaxLastDimF64(reuse, x, axis, false)
    51  		} else {
    52  			e.softMaxInnerDimF64(reuse, x, axis, false)
    53  		}
    54  	default:
    55  		return nil, fmt.Errorf("type %v not supported", x.Dtype())
    56  	}
    57  
    58  	return reuse, nil
    59  }
    60  
    61  // SoftMaxB computes gradient of the input `x`, given the `output = SoftMax(x)` and its associated gradient. Currently it expects the tensor to be a Dense tensor.
    62  // Please make a pull request to support sparse tensors.
    63  func (e StdEng) SoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
    64  	if !output.Shape().Eq(grad.Shape()) {
    65  		return nil, fmt.Errorf("output and grad shapes don't match")
    66  	}
    67  
    68  	if !output.Dtype().Eq(grad.Dtype()) {
    69  		return nil, fmt.Errorf("output and grad types don't match")
    70  	}
    71  
    72  	axis = resolveAxis(axis, output.Dims())
    73  	expectedShape := output.Shape()
    74  
    75  	var reuse DenseTensor
    76  	var safe, toReuse, _ bool
    77  	if reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, output.Dtype(), output.DataOrder(), true, opts...); err != nil {
    78  		return nil, errors.Wrap(err, "Unable to handle funcOpts")
    79  	}
    80  	if safe || !toReuse && reuse == nil && safe {
    81  		// create reuse
    82  		reuse = New(WithShape(expectedShape...), Of(output.Dtype()))
    83  	}
    84  
    85  	switch output.Dtype() {
    86  	case Float32:
    87  		if expectedShape.Dims()-1 == axis {
    88  			e.softMaxBLastDimF32(reuse, output, grad, axis, false)
    89  		} else {
    90  			e.softMaxBInnerDimF32(reuse, output, grad, axis, false)
    91  		}
    92  	case Float64:
    93  		if expectedShape.Dims()-1 == axis {
    94  			e.softMaxBLastDimF64(reuse, output, grad, axis, false)
    95  		} else {
    96  			e.softMaxBInnerDimF64(reuse, output, grad, axis, false)
    97  		}
    98  	default:
    99  		return nil, fmt.Errorf("type %v not supported", output.Dtype())
   100  	}
   101  
   102  	return reuse, nil
   103  }
   104  
   105  // LogSoftMax performs softmax but in log space. This provides some amount of numerical stabilization.
   106  // Conceptually it is the same as performing a logarithm after applying the softmax function.
   107  // Currently it expects the tensor to be a Dense tensor.
   108  // Please make a pull request to support sparse tensors.
   109  func (e StdEng) LogSoftMax(x Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
   110  	axis = resolveAxis(axis, x.Dims())
   111  	expectedShape := x.Shape()
   112  
   113  	var reuse DenseTensor
   114  	var safe, toReuse, _ bool
   115  	if reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, x.Dtype(), x.DataOrder(), true, opts...); err != nil {
   116  		return nil, errors.Wrap(err, "Unable to handle funcOpts")
   117  	}
   118  	if safe || !toReuse && reuse == nil && safe {
   119  		// create reuse
   120  		reuse = New(WithShape(expectedShape...), Of(x.Dtype()))
   121  	}
   122  
   123  	switch x.Dtype() {
   124  	case Float32:
   125  		if expectedShape.Dims()-1 == axis {
   126  			e.softMaxLastDimF32(reuse, x, axis, true)
   127  		} else {
   128  			e.softMaxInnerDimF32(reuse, x, axis, true)
   129  		}
   130  	case Float64:
   131  		if expectedShape.Dims()-1 == axis {
   132  			e.softMaxLastDimF64(reuse, x, axis, true)
   133  		} else {
   134  			e.softMaxInnerDimF64(reuse, x, axis, true)
   135  		}
   136  	default:
   137  		return nil, fmt.Errorf("type %v not supported", x.Dtype())
   138  	}
   139  
   140  	return reuse, nil
   141  }
   142  
   143  // LogSoftMaxB computes the gradient of the input `x`, given the `output = LogSoftmax(x)` and its associated gradient.
   144  // Currently it expects the tensor to be a Dense tensor.
   145  // Please make a pull request to support sparse tensors.
   146  func (e StdEng) LogSoftMaxB(output, grad Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) {
   147  	if !output.Shape().Eq(grad.Shape()) {
   148  		return nil, fmt.Errorf("output and grad shapes don't match")
   149  	}
   150  
   151  	if !output.Dtype().Eq(grad.Dtype()) {
   152  		return nil, fmt.Errorf("output and grad types don't match")
   153  	}
   154  
   155  	axis = resolveAxis(axis, output.Dims())
   156  	expectedShape := output.Shape()
   157  
   158  	var reuse DenseTensor
   159  	var safe, toReuse, _ bool
   160  	if reuse, safe, toReuse, _, _, err = handleFuncOpts(expectedShape, output.Dtype(), output.DataOrder(), true, opts...); err != nil {
   161  		return nil, errors.Wrap(err, "Unable to handle funcOpts")
   162  	}
   163  	if safe || !toReuse && reuse == nil && safe {
   164  		// create reuse
   165  		reuse = New(WithShape(expectedShape...), Of(output.Dtype()))
   166  	}
   167  
   168  	switch output.Dtype() {
   169  	case Float32:
   170  		if expectedShape.Dims()-1 == axis {
   171  			e.softMaxBLastDimF32(reuse, output, grad, axis, true)
   172  		} else {
   173  			e.softMaxBInnerDimF32(reuse, output, grad, axis, true)
   174  		}
   175  	case Float64:
   176  		if expectedShape.Dims()-1 == axis {
   177  			e.softMaxBLastDimF64(reuse, output, grad, axis, true)
   178  		} else {
   179  			e.softMaxBInnerDimF64(reuse, output, grad, axis, true)
   180  		}
   181  	default:
   182  		return nil, fmt.Errorf("type %v not supported", output.Dtype())
   183  	}
   184  
   185  	return reuse, nil
   186  }
   187  
   188  func (e StdEng) softMaxLastDimF64(output Tensor, x Tensor, axis int, logSoftMax bool) {
   189  	outputArr := getFloat64s(output)
   190  	xArr := getFloat64s(x)
   191  
   192  	xShape := x.Shape()
   193  
   194  	outerSize := 1
   195  	dimSize := xShape[axis]
   196  	for i := 0; i < axis; i++ {
   197  		outerSize *= xShape[i]
   198  	}
   199  
   200  	var wg sync.WaitGroup
   201  	for ii := 0; ii < outerSize; ii++ {
   202  		wg.Add(1)
   203  		go func(ii int, wg *sync.WaitGroup) {
   204  			maxInput := xArr[0]
   205  			for j := 1; j < dimSize; j++ {
   206  				i := ii*dimSize + j
   207  
   208  				if xArr[i] > maxInput {
   209  					maxInput = xArr[i]
   210  				}
   211  			}
   212  
   213  			sumExp := float64(0.0)
   214  			for j := 0; j < dimSize; j++ {
   215  				i := ii*dimSize + j
   216  				z := xArr[i] - maxInput
   217  				exp := math.Exp(z)
   218  
   219  				if logSoftMax {
   220  					outputArr[i] = z
   221  				} else {
   222  					outputArr[i] = exp
   223  				}
   224  
   225  				sumExp += exp
   226  			}
   227  
   228  			if !logSoftMax {
   229  				sumExp = 1 / sumExp
   230  			}
   231  
   232  			for j := 0; j < dimSize; j++ {
   233  				i := ii*dimSize + j
   234  
   235  				if logSoftMax {
   236  					outputArr[i] -= math.Log(sumExp)
   237  				} else {
   238  					outputArr[i] *= sumExp
   239  				}
   240  			}
   241  			wg.Done()
   242  		}(ii, &wg)
   243  
   244  	}
   245  	wg.Wait()
   246  }
   247  
   248  func (e StdEng) softMaxBLastDimF64(inputGrad, output, grad Tensor, axis int, logSoftMax bool) {
   249  	dx := getFloat64s(inputGrad)
   250  	outputArr := getFloat64s(output)
   251  	gradArr := getFloat64s(grad)
   252  
   253  	outputShape := output.Shape()
   254  
   255  	outerSize := 1
   256  	dimSize := outputShape[axis]
   257  	for i := 0; i < axis; i++ {
   258  		outerSize *= outputShape[i]
   259  	}
   260  
   261  	var wg sync.WaitGroup
   262  	for ii := 0; ii < outerSize; ii++ {
   263  		wg.Add(1)
   264  		if logSoftMax {
   265  			go func(gradArr, dx []float64, ii int, wg *sync.WaitGroup) {
   266  				sum := gradArr[ii*dimSize]
   267  				for j := 1; j < dimSize; j++ {
   268  					i := ii*dimSize + j
   269  
   270  					sum += gradArr[i]
   271  				}
   272  
   273  				for j := 0; j < dimSize; j++ {
   274  					i := ii*dimSize + j
   275  
   276  					dx[i] = gradArr[i] - (math.Exp(outputArr[i]) * sum)
   277  				}
   278  				wg.Done()
   279  			}(gradArr, dx, ii, &wg)
   280  
   281  		} else {
   282  			go func(outputArr, gradArr, dx []float64, ii int, wg *sync.WaitGroup) {
   283  				//mul := make([]float64, dimSize)
   284  				var sum float64
   285  				for j := 0; j < dimSize; j++ {
   286  					i := ii*dimSize + j
   287  
   288  					//mul[j] = outputArr[i] * gradArr[i]
   289  					sum += outputArr[i] * gradArr[i]
   290  				}
   291  
   292  				// sum := mul[0]
   293  				// for j := 1; j < dimSize; j++ {
   294  				// 	sum += mul[j]
   295  				// }
   296  
   297  				for j := 0; j < dimSize; j++ {
   298  					i := ii*dimSize + j
   299  					dx[i] = (gradArr[i] - sum) * outputArr[i]
   300  				}
   301  				wg.Done()
   302  			}(outputArr, gradArr, dx, ii, &wg)
   303  		}
   304  	}
   305  	wg.Wait()
   306  }
   307  
   308  func (e StdEng) softMaxInnerDimF64(output Tensor, x Tensor, axis int, logSoftmax bool) {
   309  	xShape := x.Shape()
   310  
   311  	innerSize, outerSize := 1, 1
   312  	for i := 0; i < axis; i++ {
   313  		outerSize *= xShape[i]
   314  	}
   315  
   316  	for i := axis + 1; i < xShape.Dims(); i++ {
   317  		innerSize *= xShape[i]
   318  	}
   319  
   320  	dimSize := xShape[axis]
   321  	dimStride := innerSize
   322  	outerStride := dimSize * dimStride
   323  
   324  	outputArr := getFloat64s(output)
   325  	xArr := getFloat64s(x)
   326  
   327  	var wg sync.WaitGroup
   328  	for ii := 0; ii < innerSize*outerSize; ii++ {
   329  		wg.Add(1)
   330  		go func(ii int, wg *sync.WaitGroup) {
   331  			outerIndex, innerIndex := divmod(ii, innerSize)
   332  
   333  			inputPart := xArr[outerIndex*outerStride+innerIndex:]
   334  			outputPart := outputArr[outerIndex*outerStride+innerIndex:]
   335  
   336  			maxInput := inputPart[0]
   337  			for j := 1; j < dimSize; j++ {
   338  				i := j * dimStride
   339  
   340  				if inputPart[i] > maxInput {
   341  					maxInput = inputPart[i]
   342  				}
   343  			}
   344  
   345  			sumExp := 0.0
   346  			for j := 0; j < dimSize; j++ {
   347  				i := j * dimStride
   348  
   349  				exp := math.Exp(inputPart[i] - maxInput)
   350  
   351  				if !logSoftmax {
   352  					outputPart[i] = exp
   353  				}
   354  
   355  				sumExp += exp
   356  			}
   357  
   358  			if logSoftmax {
   359  				sumExp = math.Log(sumExp)
   360  			} else {
   361  				sumExp = 1 / sumExp
   362  			}
   363  
   364  			for j := 0; j < dimSize; j++ {
   365  				i := j * dimStride
   366  
   367  				if logSoftmax {
   368  					outputPart[i] = inputPart[i] - maxInput - sumExp
   369  				} else {
   370  					outputPart[i] *= sumExp
   371  				}
   372  			}
   373  			wg.Done()
   374  		}(ii, &wg)
   375  	}
   376  	wg.Wait()
   377  }
   378  
   379  func (e StdEng) softMaxBInnerDimF64(inputGrad, output, grad Tensor, axis int, logSoftmax bool) {
   380  	dxShape := inputGrad.Shape()
   381  
   382  	innerSize, outerSize := 1, 1
   383  	for i := 0; i < axis; i++ {
   384  		outerSize *= dxShape[i]
   385  	}
   386  
   387  	for i := axis + 1; i < dxShape.Dims(); i++ {
   388  		innerSize *= dxShape[i]
   389  	}
   390  
   391  	dimSize := dxShape[axis]
   392  	dimStride := innerSize
   393  	outerStride := dimSize * dimStride
   394  
   395  	dxArr := getFloat64s(inputGrad)
   396  	outputArr := getFloat64s(output)
   397  	gradArr := getFloat64s(grad)
   398  
   399  	var wg sync.WaitGroup
   400  	for ii := 0; ii < innerSize*outerSize; ii++ {
   401  		wg.Add(1)
   402  		go func(ii int, wg *sync.WaitGroup) {
   403  			outerIndex, innerIndex := divmod(ii, innerSize)
   404  
   405  			gradPart := gradArr[outerIndex*outerStride+innerIndex:]
   406  			dxPart := dxArr[outerIndex*outerStride+innerIndex:]
   407  			outputPart := outputArr[outerIndex*outerStride+innerIndex:]
   408  
   409  			sum := 0.0
   410  			for j := 0; j < dimSize; j++ {
   411  				i := j * dimStride
   412  
   413  				if logSoftmax {
   414  					sum += gradPart[i]
   415  				} else {
   416  					sum += gradPart[i] * outputPart[i]
   417  				}
   418  			}
   419  
   420  			for j := 0; j < dimSize; j++ {
   421  				i := j * dimStride
   422  
   423  				if logSoftmax {
   424  					dxPart[i] = gradPart[i] - math.Exp(outputPart[i])*sum
   425  				} else {
   426  					dxPart[i] = outputPart[i] * (gradPart[i] - sum)
   427  				}
   428  			}
   429  			wg.Done()
   430  		}(ii, &wg)
   431  
   432  	}
   433  	wg.Wait()
   434  }
   435  
   436  func (e StdEng) softMaxLastDimF32(output Tensor, x Tensor, axis int, logSoftMax bool) {
   437  	outputArr := getFloat32s(output)
   438  	xArr := getFloat32s(x)
   439  	xShape := x.Shape()
   440  
   441  	outerSize := 1
   442  	dimSize := xShape[axis]
   443  	for i := 0; i < axis; i++ {
   444  		outerSize *= xShape[i]
   445  	}
   446  
   447  	var wg sync.WaitGroup
   448  	for ii := 0; ii < outerSize; ii++ {
   449  		wg.Add(1)
   450  		go func(ii int, wg *sync.WaitGroup) {
   451  			maxInput := xArr[0]
   452  			for j := 1; j < dimSize; j++ {
   453  				i := ii*dimSize + j
   454  
   455  				if xArr[i] > maxInput {
   456  					maxInput = xArr[i]
   457  				}
   458  			}
   459  
   460  			sumExp := float32(0.0)
   461  			for j := 0; j < dimSize; j++ {
   462  				i := ii*dimSize + j
   463  				z := xArr[i] - maxInput
   464  				exp := math32.Exp(z)
   465  
   466  				if logSoftMax {
   467  					outputArr[i] = z
   468  				} else {
   469  					outputArr[i] = exp
   470  				}
   471  
   472  				sumExp += exp
   473  			}
   474  
   475  			if !logSoftMax {
   476  				sumExp = 1 / sumExp
   477  			}
   478  
   479  			for j := 0; j < dimSize; j++ {
   480  				i := ii*dimSize + j
   481  
   482  				if logSoftMax {
   483  					outputArr[i] -= math32.Log(sumExp)
   484  				} else {
   485  					outputArr[i] *= sumExp
   486  				}
   487  			}
   488  			wg.Done()
   489  		}(ii, &wg)
   490  	}
   491  	wg.Wait()
   492  }
   493  
   494  func (e StdEng) softMaxBLastDimF32(inputGrad, output, grad Tensor, axis int, logSoftMax bool) {
   495  	dx := getFloat32s(inputGrad)
   496  	outputArr := getFloat32s(output)
   497  	gradArr := getFloat32s(grad)
   498  
   499  	outputShape := output.Shape()
   500  
   501  	outerSize := 1
   502  	dimSize := outputShape[axis]
   503  	for i := 0; i < axis; i++ {
   504  		outerSize *= outputShape[i]
   505  	}
   506  
   507  	var wg sync.WaitGroup
   508  	for ii := 0; ii < outerSize; ii++ {
   509  		wg.Add(1)
   510  
   511  		if logSoftMax {
   512  			go func(ii int, wg *sync.WaitGroup) {
   513  				sum := gradArr[ii*dimSize]
   514  				for j := 1; j < dimSize; j++ {
   515  					i := ii*dimSize + j
   516  
   517  					sum += gradArr[i]
   518  				}
   519  
   520  				for j := 0; j < dimSize; j++ {
   521  					i := ii*dimSize + j
   522  
   523  					dx[i] = gradArr[i] - (math32.Exp(outputArr[i]) * sum)
   524  				}
   525  				wg.Done()
   526  			}(ii, &wg)
   527  		} else {
   528  			go func(ii int, wg *sync.WaitGroup) {
   529  				//mul := make([]float32, dimSize)
   530  				var sum float32
   531  				for j := 0; j < dimSize; j++ {
   532  					i := ii*dimSize + j
   533  
   534  					//mul[j] = outputArr[i] * gradArr[i]
   535  					sum += outputArr[i] * gradArr[i]
   536  				}
   537  
   538  				// sum := mul[0]
   539  				// for j := 1; j < dimSize; j++ {
   540  				// 	sum += mul[j]
   541  				// }
   542  
   543  				for j := 0; j < dimSize; j++ {
   544  					i := ii*dimSize + j
   545  
   546  					dx[i] = (gradArr[i] - sum) * outputArr[i]
   547  				}
   548  				wg.Done()
   549  			}(ii, &wg)
   550  		}
   551  	}
   552  	wg.Wait()
   553  }
   554  
   555  func (e StdEng) softMaxInnerDimF32(output Tensor, x Tensor, axis int, logSoftmax bool) {
   556  	xShape := x.Shape()
   557  
   558  	innerSize, outerSize := 1, 1
   559  	for i := 0; i < axis; i++ {
   560  		outerSize *= xShape[i]
   561  	}
   562  
   563  	for i := axis + 1; i < xShape.Dims(); i++ {
   564  		innerSize *= xShape[i]
   565  	}
   566  
   567  	dimSize := xShape[axis]
   568  	dimStride := innerSize
   569  	outerStride := dimSize * dimStride
   570  
   571  	outputArr := getFloat32s(output)
   572  	xArr := getFloat32s(x)
   573  
   574  	var wg sync.WaitGroup
   575  	for ii := 0; ii < innerSize*outerSize; ii++ {
   576  		wg.Add(1)
   577  
   578  		go func(ii int, wg *sync.WaitGroup) {
   579  			outerIndex, innerIndex := divmod(ii, innerSize)
   580  
   581  			inputPart := xArr[outerIndex*outerStride+innerIndex:]
   582  			outputPart := outputArr[outerIndex*outerStride+innerIndex:]
   583  
   584  			maxInput := inputPart[0]
   585  			for j := 1; j < dimSize; j++ {
   586  				i := j * dimStride
   587  
   588  				if inputPart[i] > maxInput {
   589  					maxInput = inputPart[i]
   590  				}
   591  			}
   592  
   593  			sumExp := float32(0.0)
   594  			for j := 0; j < dimSize; j++ {
   595  				i := j * dimStride
   596  
   597  				exp := math32.Exp(inputPart[i] - maxInput)
   598  
   599  				if !logSoftmax {
   600  					outputPart[i] = exp
   601  				}
   602  
   603  				sumExp += exp
   604  			}
   605  
   606  			if logSoftmax {
   607  				sumExp = math32.Log(sumExp)
   608  			} else {
   609  				sumExp = 1 / sumExp
   610  			}
   611  
   612  			for j := 0; j < dimSize; j++ {
   613  				i := j * dimStride
   614  
   615  				if logSoftmax {
   616  					outputPart[i] = inputPart[i] - maxInput - sumExp
   617  				} else {
   618  					outputPart[i] *= sumExp
   619  				}
   620  			}
   621  			wg.Done()
   622  		}(ii, &wg)
   623  	}
   624  	wg.Wait()
   625  }
   626  
   627  func (e StdEng) softMaxBInnerDimF32(inputGrad, output, grad Tensor, axis int, logSoftmax bool) {
   628  	dxShape := inputGrad.Shape()
   629  
   630  	innerSize, outerSize := 1, 1
   631  	for i := 0; i < axis; i++ {
   632  		outerSize *= dxShape[i]
   633  	}
   634  
   635  	for i := axis + 1; i < dxShape.Dims(); i++ {
   636  		innerSize *= dxShape[i]
   637  	}
   638  
   639  	dimSize := dxShape[axis]
   640  	dimStride := innerSize
   641  	outerStride := dimSize * dimStride
   642  
   643  	dxArr := getFloat32s(inputGrad)
   644  	outputArr := getFloat32s(output)
   645  	gradArr := getFloat32s(grad)
   646  
   647  	var wg sync.WaitGroup
   648  	for ii := 0; ii < innerSize*outerSize; ii++ {
   649  		wg.Add(1)
   650  
   651  		go func(ii int, wg *sync.WaitGroup) {
   652  			outerIndex, innerIndex := divmod(ii, innerSize)
   653  
   654  			gradPart := gradArr[outerIndex*outerStride+innerIndex:]
   655  			dxPart := dxArr[outerIndex*outerStride+innerIndex:]
   656  			outputPart := outputArr[outerIndex*outerStride+innerIndex:]
   657  
   658  			sum := float32(0.0)
   659  			for j := 0; j < dimSize; j++ {
   660  				i := j * dimStride
   661  
   662  				if logSoftmax {
   663  					sum += gradPart[i]
   664  				} else {
   665  					sum += gradPart[i] * outputPart[i]
   666  				}
   667  			}
   668  
   669  			for j := 0; j < dimSize; j++ {
   670  				i := j * dimStride
   671  
   672  				if logSoftmax {
   673  					dxPart[i] = gradPart[i] - math32.Exp(outputPart[i])*sum
   674  				} else {
   675  					dxPart[i] = outputPart[i] * (gradPart[i] - sum)
   676  				}
   677  			}
   678  			wg.Done()
   679  		}(ii, &wg)
   680  	}
   681  	wg.Wait()
   682  }