gorgonia.org/gorgonia@v0.9.17/op_softmax.go

package gorgonia

import (
	"fmt"
	"hash"
	"math"
	"runtime"
	"sync"

	"github.com/chewxy/hm"
	"github.com/chewxy/math32"
	"github.com/pkg/errors"
	"gorgonia.org/tensor"
)

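// softmaxOp computes softmax (or log-softmax when isLog is set) along a single
// axis of its input tensor. An axis of -1 means "resolve a default axis at
// execution time": the last dimension, or 0 for plain and column vectors.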
type softmaxOp struct {
	shape tensor.Shape
	axis  int
	isLog bool
}

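// newSoftmaxOp returns a softmaxOp over the given input shape. If no axis is
// provided, the axis is left as -1 and resolved when the op is executed.
//
// Usage sketch (assumes a *Node x, e.g. of shape (batch, classes); ApplyOp is
// the generic op applier used elsewhere in this file):
//
//	op := newSoftmaxOp(x.Shape(), 1)
//	sm, err := ApplyOp(op, x)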
func newSoftmaxOp(inputShape tensor.Shape, axes ...int) *softmaxOp {
	axis := -1
	if len(axes) > 0 {
		axis = axes[0]
	}
	softmaxop := &softmaxOp{
		shape: inputShape,
		axis:  axis,
	}

	return softmaxop
}

func (op *softmaxOp) Arity() int { return 1 }

func (op *softmaxOp) ReturnsPtr() bool { return false }

func (op *softmaxOp) CallsExtern() bool { return false }

func (op *softmaxOp) WriteHash(h hash.Hash) { fmt.Fprintf(h, "Softmax{%v}()", op.axis) }

func (op *softmaxOp) Hashcode() uint32 { return simpleHash(op) }

func (op *softmaxOp) String() string { return fmt.Sprintf("Softmax{%d, %v}()", op.axis, op.isLog) }

func (op *softmaxOp) InferShape(inputs ...DimSizer) (tensor.Shape, error) {
	return inputs[0].(tensor.Shape), nil
}

func (op *softmaxOp) Type() hm.Type {
	a := hm.TypeVariable('a')
	return hm.NewFnType(a, a) // a → a
}

func (op *softmaxOp) OverwritesInput() int { return -1 }

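// checkInput verifies the arity of the op and that its single input is a tensor.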
func (op *softmaxOp) checkInput(inputs ...Value) (tensor.Tensor, error) {
	if err := checkArity(op, len(inputs)); err != nil {
		return nil, err
	}

	var (
		in tensor.Tensor
		ok bool
	)

	if in, ok = inputs[0].(tensor.Tensor); !ok {
		return nil, errors.Errorf("Expected input to be a tensor")
	}

	return in, nil
}

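// Do resolves the softmax axis (the last dimension by default, or 0 for plain
// and column vectors, unless an explicit axis was set), allocates an output
// tensor of the same shape and dtype, and fills it via op.do.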
func (op *softmaxOp) Do(inputs ...Value) (retVal Value, err error) {
	inputTensor, err := op.checkInput(inputs...)
	if err != nil {
		return nil, fmt.Errorf("Can't check Softmax input: %w", err)
	}

	aShape := inputTensor.Shape()
	axis := aShape.Dims() - 1 // default: last dim

	if aShape.IsColVec() || (aShape.IsVector() && !aShape.IsRowVec()) {
		axis = 0
	}
	if op.axis != -1 {
		axis = op.axis
	}

	ret := tensor.New(tensor.WithShape(aShape.Clone()...), tensor.Of(inputTensor.Dtype()))
	data := inputTensor.Data()
	output := ret.Data()
	op.do(aShape, axis, data, output)
	return ret, nil
}

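// UsePreallocDo is the same as Do, but writes the result into the preallocated
// value instead of allocating a new output tensor.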
func (op *softmaxOp) UsePreallocDo(prealloc Value, inputs ...Value) (Value, error) {
	inputTensor, err := op.checkInput(inputs...)
	if err != nil {
		return nil, fmt.Errorf("Can't check Softmax input: %w", err)
	}

	aShape := inputTensor.Shape()
	axis := aShape.Dims() - 1 // default: last dim

	if aShape.IsColVec() || (aShape.IsVector() && !aShape.IsRowVec()) {
		axis = 0
	}
	if op.axis != -1 {
		axis = op.axis
	}

	op.do(aShape, axis, inputTensor.Data(), prealloc.Data())
	return prealloc, nil
}

// DoDiff calculates the diff and sets its value to the output node. Implementation for ADOp interface.
func (op *softmaxOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) error {
	if len(inputs) != 1 {
		return fmt.Errorf("SoftmaxOp.DoDiff needs 1 argument")
	}

	odv := output.boundTo.(*dualValue)
	idv := inputs[0].boundTo.(*dualValue)
	idvd := idv.d.(*tensor.Dense)
	diffOp := &softmaxDiffOp{op}

	result, err := diffOp.Do(idv.Value, odv.Value, odv.d)
	if err != nil {
		return err
	}

	sum, err := idvd.Add(result.(*tensor.Dense), tensor.UseUnsafe())
	if err != nil {
		return err
	}

	odv.d = sum

	return nil
}

// SymDiff applies the diff op. Implementation for SDOp interface.
func (op *softmaxOp) SymDiff(inputs Nodes, output, grad *Node) (Nodes, error) {
	err := checkArity(op, len(inputs))
	if err != nil {
		return nil, err
	}

	diffOp := &softmaxDiffOp{op}
	nodes := make(Nodes, 1)

	nodes[0], err = ApplyOp(diffOp, inputs[0], output, grad)

	return nodes, err
}

// DiffWRT is an implementation for the SDOp interface
func (op *softmaxOp) DiffWRT(inputs int) []bool {
	if inputs != 1 {
		panic(fmt.Sprintf("softmax operator only supports one input, got %d instead", inputs))
	}

	return []bool{true}
}

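// f64skernel computes softmax over one block of float64 data. For every
// (outer, inner) position it subtracts the maximum along the softmax axis for
// numerical stability, exponentiates, and normalizes by the sum; in log mode
// it instead writes x - max - log(sum(exp(x - max))).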
func (op *softmaxOp) f64skernel(data, output []float64, inner, ostride, dimSize, dimStride int) {
	for i := 0; i < len(data); i++ {
		oi := i / inner
		ii := i % inner
		xidx := oi*ostride + ii
		yidx := oi*ostride + ii

		if xidx >= len(data) {
			continue
		}

		if yidx >= len(output) {
			continue
		}

		x := data[xidx:]
		y := output[yidx:]
		if len(x) == 0 {
			continue
		}

		max := x[0]
		for d := 1; d < dimSize && d*dimStride < len(x); d++ {
			dm := x[d*dimStride]
			if dm > max {
				max = dm
			}
		}

		var sum float64
		for d := 0; d < dimSize && d*dimStride < len(x); d++ {
			z := math.Exp(x[d*dimStride] - max)
			if !op.isLog {
				y[d*dimStride] = z
			}
			sum += z
		}

		if op.isLog {
			sum = math.Log(sum)
		} else {
			sum = 1 / sum
		}

		// set output
		for d := 0; d < dimSize && d*dimStride < len(y); d++ {
			if op.isLog {
				y[d*dimStride] = x[d*dimStride] - max - sum
			} else {
				y[d*dimStride] *= sum
			}
		}
	}
}

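// f32skernel is the float32 counterpart of f64skernel.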
func (op *softmaxOp) f32skernel(data, output []float32, inner, ostride, dimSize, dimStride int) {
	for i := 0; i < len(data); i++ {
		oi := i / inner
		ii := i % inner
		xidx := oi*ostride + ii
		yidx := oi*ostride + ii

		if xidx >= len(data) {
			continue
		}

		if yidx >= len(output) {
			continue
		}

		x := data[xidx:]
		y := output[yidx:]
		if len(x) == 0 {
			continue
		}

		max := x[0]
		for d := 1; d < dimSize && d*dimStride < len(x); d++ {
			dm := x[d*dimStride]
			if dm > max {
				max = dm
			}
		}

		var tmp float32
		for d := 0; d < dimSize && d*dimStride < len(x); d++ {
			z := math32.Exp(x[d*dimStride] - max)
			if !op.isLog {
				y[d*dimStride] = z
			}
			tmp += z
		}

		if op.isLog {
			tmp = math32.Log(tmp)
		} else {
			tmp = 1 / tmp
		}

		// set output
		for d := 0; d < dimSize && d*dimStride < len(y); d++ {
			if op.isLog {
				y[d*dimStride] = x[d*dimStride] - max - tmp
			} else {
				y[d*dimStride] *= tmp
			}
		}
	}
}

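// do splits shp into outer * dimSize * inner around the softmax axis and runs
// the per-dtype kernel over the data. Small inputs are processed on the
// calling goroutine; larger ones are cut into blocks and handed to worker
// goroutines.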
// output and data are of the same size
func (op *softmaxOp) do(shp tensor.Shape, axis int, input, output interface{}) {
	threads := runtime.GOMAXPROCS(0) + 1

	dimSize := shp[axis]
	outer := tensor.ProdInts([]int(shp[:axis]))
	inner := tensor.ProdInts([]int(shp[axis+1:]))
	if outer == 0 {
		outer = 1
	}
	if inner == 0 {
		inner = 1
	}
	dimStride := inner
	ostride := dimSize * dimStride
	axisSize := shp[axis]

	datalen := shp.TotalSize()
	blockSize := calcBlocks(datalen, threads)
	blocks := 1
	if blockSize == 0 || blockSize < axisSize {
		blockSize = datalen // process everything as a single block
	} else {
		blocks = datalen / blockSize
	}

	if blocks < minParallelBlocks {
		switch data := input.(type) {
		case []float64:
			output := output.([]float64)
			op.f64skernel(data, output, inner, ostride, dimSize, dimStride)
		case []float32:
			output := output.([]float32)
			op.f32skernel(data, output, inner, ostride, dimSize, dimStride)
		default:
			panic(fmt.Sprintf("tensor of %T not handled for softmax", data))
		}
		return
	}

	workers := workersChan()
	var wg sync.WaitGroup

	for b := 0; b < datalen; b += blockSize {
		wg.Add(1)
		switch data := input.(type) {
		case []float64:
			output := output.([]float64)
			end := b + blockSize
			if end > len(data) {
				end = len(data)
			}
			newdata := data[b:end]
			newoutput := output[b:end]
			go func(data, output []float64, dimSize, dimStride int, wg *sync.WaitGroup) {
				workers <- struct{}{}
				op.f64skernel(data, output, inner, ostride, dimSize, dimStride)
				wg.Done()
				<-workers
			}(newdata, newoutput, dimSize, dimStride, &wg)
		case []float32:
			output := output.([]float32)
			end := b + blockSize
			if end > len(data) {
				end = len(data)
			}
			newdata := data[b:end]
			newoutput := output[b:end]
			go func(data, output []float32, dimSize, dimStride int, wg *sync.WaitGroup) {
				workers <- struct{}{}
				op.f32skernel(data, output, inner, ostride, dimSize, dimStride)
				wg.Done()
				<-workers
			}(newdata, newoutput, dimSize, dimStride, &wg)
		default:
			panic(fmt.Sprintf("tensor of %T not handled for softmax", data))
		}
	}
	wg.Wait()
}

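// softmaxDiffOp computes the gradient of softmaxOp. It expects three inputs:
// the original input, the softmax output, and the gradient with respect to
// that output.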
type softmaxDiffOp struct {
	*softmaxOp
}

func (op *softmaxDiffOp) Arity() int { return 3 }

func (op *softmaxDiffOp) ReturnsPtr() bool { return false }

func (op *softmaxDiffOp) CallsExtern() bool { return false }

func (op *softmaxDiffOp) WriteHash(h hash.Hash) {
	fmt.Fprintf(h, "SoftmaxDiff{%d, %v}()", op.axis, op.isLog)
}

func (op *softmaxDiffOp) Hashcode() uint32 { return simpleHash(op) }

func (op *softmaxDiffOp) String() string {
	return fmt.Sprintf("SoftmaxDiff{%d, %v}()", op.axis, op.isLog)
}

func (op *softmaxDiffOp) InferShape(inputs ...DimSizer) (tensor.Shape, error) {
	s := inputs[0].(tensor.Shape).Clone()

	return s, nil
}

func (op *softmaxDiffOp) Type() hm.Type {
	a := hm.TypeVariable('a')

	return hm.NewFnType(a, a, a, a) // a → a → a → a
}

func (op *softmaxDiffOp) OverwritesInput() int { return -1 }

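// checkInput unwraps the three expected inputs (input, output, and output
// gradient), accepting either bare tensors or dualValues that hold tensors.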
func (op *softmaxDiffOp) checkInput(inputs ...Value) (tensor.Tensor, tensor.Tensor, tensor.Tensor, error) {
	if err := checkArity(op, len(inputs)); err != nil {
		return nil, nil, nil, err
	}

	var (
		in   tensor.Tensor
		out  tensor.Tensor
		grad tensor.Tensor
		ok   bool
	)

	switch t := inputs[0].(type) {
	case *dualValue:
		if in, ok = t.Value.(tensor.Tensor); !ok {
			return nil, nil, nil, errors.Errorf("input should be a tensor, got %T", inputs[0])
		}
	case tensor.Tensor:
		in = t
	default:
		return nil, nil, nil, errors.Errorf("input type is not supported, got %T", inputs[0])
	}

	switch t := inputs[1].(type) {
	case *dualValue:
		if out, ok = t.Value.(tensor.Tensor); !ok {
			return nil, nil, nil, errors.Errorf("output should be a tensor, got %T", inputs[1])
		}
	case tensor.Tensor:
		out = t
	default:
		return nil, nil, nil, errors.Errorf("output type is not supported, got %T", inputs[1])
	}

	switch t := inputs[2].(type) {
	case *dualValue:
		if grad, ok = t.Value.(tensor.Tensor); !ok {
			return nil, nil, nil, errors.Errorf("grad should be a tensor, got %T", inputs[2])
		}
	case tensor.Tensor:
		grad = t
	default:
		return nil, nil, nil, errors.Errorf("grad type is not supported, got %T", inputs[2])
	}

	return in, out, grad, nil
}

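// Do computes the softmax gradient. It resolves the axis the same way the
// forward op does, allocates a result tensor of the input's shape and dtype,
// and fills it via op.do.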
func (op *softmaxDiffOp) Do(inputs ...Value) (Value, error) {
	x, y, grad, err := op.checkInput(inputs...)
	if err != nil {
		return nil, fmt.Errorf("Can't check SoftmaxDiff input: %w", err)
	}

	s := x.Shape()
	axis := op.axis
	if axis == -1 {
		axis = s.Dims() - 1
	}

	ret := tensor.New(tensor.WithShape(x.Shape().Clone()...), tensor.Of(x.Dtype()))
	op.do(x.Shape(), axis, x.Data(), y.Data(), grad.Data(), ret.Data())
	return ret, nil
}

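// UsePreallocDo is the same as Do, but writes the gradient into the
// preallocated value instead of allocating a new tensor.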
func (op *softmaxDiffOp) UsePreallocDo(prealloc Value, inputs ...Value) (Value, error) {
	x, y, grad, err := op.checkInput(inputs...)
	if err != nil {
		return nil, fmt.Errorf("Can't check SoftmaxDiff input: %w", err)
	}

	s := x.Shape()
	axis := op.axis
	if axis == -1 {
		axis = s.Dims() - 1
	}

	op.do(x.Shape(), axis, x.Data(), y.Data(), grad.Data(), prealloc.Data())
	return prealloc, nil
}

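// f64Kernel computes the softmax gradient over one block of float64 data. For
// each group along the softmax axis it forms s = sum(dy * y) (softmax) or
// s = sum(dy) (log-softmax), then writes dx = y * (dy - s), or
// dx = dy - exp(y) * s in the log case.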
func (op *softmaxDiffOp) f64Kernel(input, output, yGrad, xGrad []float64, inner, ostride, dimSize, dimStride int) {
	for i := 0; i < len(input); i++ {
		oi := i / inner
		ii := i % inner
		idx := oi*ostride + ii
		if idx >= len(input) {
			continue
		}

		if idx >= len(output) {
			continue
		}
		if idx >= len(yGrad) {
			continue
		}
		if idx >= len(xGrad) {
			continue
		}

		y := output[idx:]
		dy := yGrad[idx:]
		dydx := xGrad[idx:]

		// calculate sum
		var sum float64
		for d := 0; d < dimSize; d++ {
			if op.isLog {
				sum += dy[d*dimStride]
			} else {
				sum += dy[d*dimStride] * y[d*dimStride]
			}
		}
		for d := 0; d < dimSize; d++ {
			if op.isLog {
				dydx[d*dimStride] = dy[d*dimStride] - math.Exp(y[d*dimStride])*sum
			} else {
				dydx[d*dimStride] = y[d*dimStride] * (dy[d*dimStride] - sum)
			}
		}
	}
}

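// f32Kernel is the float32 counterpart of f64Kernel.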
func (op *softmaxDiffOp) f32Kernel(input, output, yGrad, xGrad []float32, inner, ostride, dimSize, dimStride int) {
	for i := 0; i < len(input); i++ {
		oi := i / inner
		ii := i % inner
		idx := oi*ostride + ii
		if idx >= len(input) {
			continue
		}

		if idx >= len(output) {
			continue
		}
		if idx >= len(yGrad) {
			continue
		}
		if idx >= len(xGrad) {
			continue
		}

		y := output[idx:]
		dy := yGrad[idx:]
		dydx := xGrad[idx:]

		// calculate sum
		var sum float32
		for d := 0; d < dimSize; d++ {
			if op.isLog {
				sum += dy[d*dimStride]
			} else {
				sum += dy[d*dimStride] * y[d*dimStride]
			}
		}
		for d := 0; d < dimSize; d++ {
			if op.isLog {
				dydx[d*dimStride] = dy[d*dimStride] - math32.Exp(y[d*dimStride])*sum
			} else {
				dydx[d*dimStride] = y[d*dimStride] * (dy[d*dimStride] - sum)
			}
		}
	}
}

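// do splits shp into outer * dimSize * inner around the axis and dispatches to
// the kernel for the element type. The commented-out block below is a
// parallel, block-based variant that is currently disabled.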
func (op *softmaxDiffOp) do(shp tensor.Shape, axis int, x, y, dy, retVal interface{}) {
	//blocks := runtime.GOMAXPROCS(0) + 1
	dimSize := shp[axis]
	outer := tensor.ProdInts([]int(shp[:axis]))
	inner := tensor.ProdInts([]int(shp[axis+1:]))
	if outer == 0 {
		outer = 1
	}
	if inner == 0 {
		inner = 1
	}
	dimStride := inner
	ostride := dimSize * dimStride

	//datalen := shp.TotalSize()

	//if blocks < minParallelBlocks {
	switch x := x.(type) {
	case []float64:
		y := y.([]float64)
		dy := dy.([]float64)
		dydx := retVal.([]float64)
		op.f64Kernel(x, y, dy, dydx, inner, ostride, dimSize, dimStride)
	case []float32:
		y := y.([]float32)
		dy := dy.([]float32)
		dydx := retVal.([]float32)
		op.f32Kernel(x, y, dy, dydx, inner, ostride, dimSize, dimStride)
	default:
		panic(fmt.Sprintf("tensor of %T not handled for softmax diff ", x))
	}
	//}
	/*
		workers := workersChan()
		var wg sync.WaitGroup
		blockSize := datalen / blocks
		if blockSize == 0 {
			blockSize = datalen // 1 block
		}

		for b := 0; b < datalen; b += blockSize {
			wg.Add(1)
			switch xData := x.(type) {
			case []float64:
				yData := y.([]float64)
				dyData := dy.([]float64)
				dydxData := retVal.([]float64)
				end := b + blockSize
				if end > len(xData) {
					end = len(xData)
				}
				newX := xData[b:end]
				newY := yData[b:end]
				newDy := dyData[b:end]
				newDydx := dydxData[b:end]

				go func(x, y, dy, dydx []float64, dimSize, dimStride int, wg *sync.WaitGroup) {
					workers <- struct{}{}
					op.f64Kernel(x, y, dy, dydx, inner, ostride, dimSize, dimStride)
					wg.Done()
					<-workers
				}(newX, newY, newDy, newDydx, dimSize, dimStride, &wg)
			case []float32:
				yData := y.([]float32)
				dyData := dy.([]float32)
				dydxData := retVal.([]float32)
				end := b + blockSize
				if end > len(xData) {
					end = len(xData)
				}
				newX := xData[b:end]
				newY := yData[b:end]
				newDy := dyData[b:end]
				newDydx := dydxData[b:end]
				go func(x, y, dy, dydx []float32, dimSize, dimStride int, wg *sync.WaitGroup) {
					workers <- struct{}{}
					op.f32Kernel(x, y, dy, dydx, inner, ostride, dimSize, dimStride)
					wg.Done()
					<-workers
				}(newX, newY, newDy, newDydx, dimSize, dimStride, &wg)
			default:
				panic(fmt.Sprintf("tensor of %T not handled for softmax diff ", xData))

			}
		}
	*/
}

// ensure the ops comply with the Op, ADOp and SDOp interfaces
var (
	_ Op   = &softmaxOp{}
	_ ADOp = &softmaxOp{}
	_ SDOp = &softmaxOp{}

	_ Op = &softmaxDiffOp{}
)