gorgonia.org/gorgonia@v0.9.17/op_sparsemax.go

package gorgonia

import (
	"fmt"
	"hash"
	"math"
	"sort"

	"github.com/chewxy/hm"
	"github.com/pkg/errors"
	"gorgonia.org/tensor"
)

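// sparsemaxOp implements the forward sparsemax operation.
// axis is the axis the projection is applied along; -1 means the last axis.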
type sparsemaxOp struct {
	axis int
}

func newSparsemaxOp(axes ...int) *sparsemaxOp {
	axis := -1
	if len(axes) > 0 {
		axis = axes[0]
	}

	sparsemaxop := &sparsemaxOp{
		axis: axis,
	}

	return sparsemaxop
}

// Sparsemax implements the sparsemax operation described in "From Softmax to
// Sparsemax" (Martins & Astudillo, 2016): http://proceedings.mlr.press/v48/martins16.pdf
func Sparsemax(x *Node, axes ...int) (*Node, error) {
	op := newSparsemaxOp(axes...)

	return ApplyOp(op, x)
}
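
// A minimal usage sketch (illustrative names; assumes a graph g and a 2D
// float64 input node x):
//
//	g := NewGraph()
//	x := NewMatrix(g, tensor.Float64, WithShape(2, 3), WithName("x"))
//	sm, err := Sparsemax(x)
//	// after running the graph (e.g. with a TapeMachine), sm.Value() holds
//	// the row-wise sparsemax of x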

func (op *sparsemaxOp) Arity() int {
	return 1
}

func (op *sparsemaxOp) ReturnsPtr() bool { return false }

func (op *sparsemaxOp) CallsExtern() bool { return false }

func (op *sparsemaxOp) WriteHash(h hash.Hash) {
	fmt.Fprintf(h, "Sparsemax{}()")
}

func (op *sparsemaxOp) Hashcode() uint32 { return simpleHash(op) }

func (op *sparsemaxOp) String() string {
	return "Sparsemax{}()"
}

func (op *sparsemaxOp) InferShape(inputs ...DimSizer) (tensor.Shape, error) {
	s := inputs[0].(tensor.Shape).Clone()
	return s, nil
}

func (op *sparsemaxOp) Type() hm.Type {
	a := hm.TypeVariable('a')
	return hm.NewFnType(a, a)
}

func (op *sparsemaxOp) OverwritesInput() int { return -1 }

func (op *sparsemaxOp) checkInput(inputs ...Value) (tensor.Tensor, error) {
	if err := checkArity(op, len(inputs)); err != nil {
		return nil, err
	}

	var in tensor.Tensor
	var ok bool

	if in, ok = inputs[0].(tensor.Tensor); !ok {
		return nil, errors.Errorf("Expected input to be a tensor, got %T", inputs[0])
	}

	return in, nil
}

func (op *sparsemaxOp) Do(inputs ...Value) (Value, error) {
	inputTensor, err := op.checkInput(inputs...)
	if err != nil {
		return nil, fmt.Errorf("Can't check Sparsemax input: %w", err)
	}

	var axes []int

	if op.axis != -1 {
		// build a valid permutation: the identity with op.axis and the last
		// axis swapped, so the op always works along the last axis
		axes = make([]int, inputTensor.Dims())
		for i := range axes {
			axes[i] = i
		}
		axes[op.axis], axes[len(axes)-1] = axes[len(axes)-1], axes[op.axis]

		inputTensor, err = tensor.Transpose(inputTensor, axes...)
		if err != nil {
			return nil, fmt.Errorf("error transposing the input tensor: %w", err)
		}
	}

	var output interface{}

	switch inputTensor.Dtype() {
	case tensor.Float64:
		output = op.float64sparseMax(inputTensor)
	case tensor.Float32:
		output = op.float32sparseMax(inputTensor)
	default:
		return nil, fmt.Errorf("invalid input type for Sparsemax, expected float64 or float32, got: %v", inputTensor.Dtype())
	}

	ret := tensor.New(tensor.Of(inputTensor.Dtype()), tensor.WithShape(inputTensor.Shape().Clone()...), tensor.WithEngine(inputTensor.Engine()), tensor.WithBacking(output))

	if axes != nil {
		// a swap is its own inverse, so the same permutation restores the
		// original axis order
		out, err := tensor.Transpose(ret, axes...)
		if err != nil {
			return nil, fmt.Errorf("error transposing the output tensor: %w", err)
		}

		return out, nil
	}

	return ret, nil
}

// FIXME: go2 generics
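// float32sparseMax applies sparsemax row-wise along the last axis: each row z
// is projected onto the probability simplex, producing max(z_i - t, 0) with
// the threshold t chosen so that the non-zero outputs sum to 1.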
func (op *sparsemaxOp) float32sparseMax(inputTensor tensor.Tensor) interface{} {
	inputData := inputTensor.Data().([]float32)
	dims := inputTensor.Dims()
	it := 0

	to := inputTensor.Shape()[dims-1]
	from := tensor.Shape(inputTensor.Shape()[0 : dims-1]).TotalSize()
	if from == 0 {
		from = 1
	}

	maxValues := make([]float32, from)

	for i := 0; i < from; i++ {
		maxValue := float32(-math.MaxFloat32)

		for j := 0; j < to; j++ {
			if inputData[it] > maxValue {
				maxValue = inputData[it]
			}

			it++
		}

		maxValues[i] = maxValue
	}

	// subtracting the per-row max is a standard trick for numerical
	// stability; sparsemax is shift-invariant, so the output is unchanged
	stableInput := make([]float32, len(inputData))
	it = 0

	for i := 0; i < from; i++ {
		for j := 0; j < to; j++ {
			stableInput[it] = inputData[it] - maxValues[i]
			it++
		}
	}

	sortedData := make([]float32, len(inputData))
	copy(sortedData, stableInput)

	// each row is sorted in descending order independently of the others
	for i := 0; i < from; i++ {
		row := sortedData[i*to : (i+1)*to]
		sort.Slice(row, func(a, b int) bool {
			return row[a] > row[b]
		})
	}

	thresholds := make([]float32, from)
	it = 0

	for i := 0; i < from; i++ {
		cumSum := float32(0.0)
		prevCum := float32(0.0)
		maxIndex := 0

		for j := 0; j < to; j++ {
			k := 1 + float32(j+1)*sortedData[it]

			prevCum += sortedData[it]

			if k > prevCum {
				maxIndex = j + 1

				cumSum += sortedData[it]
			}

			it++
		}

		thresholds[i] = (cumSum - 1) / float32(maxIndex)
	}
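
	// Worked example for a single, already max-shifted row [-2, -1, 0]
	// (sorted descending: [0, -1, -2]):
	//   j=0: k = 1 + 1*0    =  1 > prevCum = 0  -> in support, cumSum = 0
	//   j=1: k = 1 + 2*(-1) = -1 > prevCum = -1 -> condition fails, support ends
	// so maxIndex = 1 and thresholds[0] = (0 - 1) / 1 = -1; the loop below
	// then yields max(stableInput - (-1), 0) = [0, 0, 1].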

	output := make([]float32, len(stableInput))
	it = 0

	for i := 0; i < from; i++ {
		for j := 0; j < to; j++ {
			vF := stableInput[it]

			if vF-thresholds[i] > 0 {
				output[it] = vF - thresholds[i]
			}

			it++
		}
	}

	return output
}

// float64sparseMax is the float64 twin of float32sparseMax; see the comments
// there for the algorithm.
func (op *sparsemaxOp) float64sparseMax(inputTensor tensor.Tensor) interface{} {
	inputData := inputTensor.Data().([]float64)
	dims := inputTensor.Dims()
	it := 0

	to := inputTensor.Shape()[dims-1]
	from := tensor.Shape(inputTensor.Shape()[0 : dims-1]).TotalSize()
	if from == 0 {
		from = 1
	}

	maxValues := make([]float64, from)

	for i := 0; i < from; i++ {
		maxValue := -math.MaxFloat64

		for j := 0; j < to; j++ {
			if inputData[it] > maxValue {
				maxValue = inputData[it]
			}

			it++
		}

		maxValues[i] = maxValue
	}

	// subtracting the per-row max is a standard trick for numerical
	// stability; sparsemax is shift-invariant, so the output is unchanged
	stableInput := make([]float64, len(inputData))
	it = 0

	for i := 0; i < from; i++ {
		for j := 0; j < to; j++ {
			stableInput[it] = inputData[it] - maxValues[i]
			it++
		}
	}

	sortedData := make([]float64, len(inputData))
	copy(sortedData, stableInput)

	// each row is sorted in descending order independently of the others
	for i := 0; i < from; i++ {
		row := sortedData[i*to : (i+1)*to]
		sort.Slice(row, func(a, b int) bool {
			return row[a] > row[b]
		})
	}

	thresholds := make([]float64, from)
	it = 0

	for i := 0; i < from; i++ {
		cumSum := 0.0
		prevCum := 0.0
		maxIndex := 0

		for j := 0; j < to; j++ {
			k := 1 + float64(j+1)*sortedData[it]

			prevCum += sortedData[it]

			if k > prevCum {
				maxIndex = j + 1

				cumSum += sortedData[it]
			}

			it++
		}

		thresholds[i] = (cumSum - 1) / float64(maxIndex)
	}

	output := make([]float64, len(stableInput))
	it = 0

	for i := 0; i < from; i++ {
		for j := 0; j < to; j++ {
			vF := stableInput[it]

			if vF-thresholds[i] > 0 {
				output[it] = vF - thresholds[i]
			}

			it++
		}
	}

	return output
}

// DoDiff computes the gradient and accumulates it into the output node's
// bound dual value. It implements the ADOp interface.
func (op *sparsemaxOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) error {
	if len(inputs) != 2 {
		return fmt.Errorf("SparsemaxOp.DoDiff needs 2 arguments")
	}

	odv := output.boundTo.(*dualValue)
	odvd := odv.Value.(tensor.Tensor)
	diffOp := &sparsemaxDiffOp{}

	result, err := diffOp.Do(inputs[0].boundTo, inputs[1].boundTo)
	if err != nil {
		return err
	}

	err = result.(*tensor.Dense).Reshape(odvd.Shape()...)
	if err != nil {
		return err
	}

	sum, err := odvd.(*tensor.Dense).Add(result.(*tensor.Dense), tensor.UseUnsafe())
	if err != nil {
		return err
	}

	odv.d = sum

	return nil
}

// SymDiff produces the symbolic gradient by applying sparsemaxDiffOp to the
// input and the incoming gradient. It implements the SDOp interface.
func (op *sparsemaxOp) SymDiff(inputs Nodes, output, grad *Node) (Nodes, error) {
	err := checkArity(op, len(inputs))
	if err != nil {
		return nil, err
	}

	t := inputs[0]

	diffOp := &sparsemaxDiffOp{}
	nodes := make(Nodes, 1)

	nodes[0], err = ApplyOp(diffOp, t, grad)

	return nodes, err
}

// DiffWRT reports which inputs the op is differentiable with respect to. It
// implements the SDOp interface.
func (op *sparsemaxOp) DiffWRT(inputs int) []bool {
	if inputs != 1 {
		panic(fmt.Sprintf("sparsemax operator only supports one input, got %d instead", inputs))
	}

	return []bool{true}
}

type sparsemaxDiffOp struct{}

func newSparsemaxOpDiff() *sparsemaxDiffOp {
	return &sparsemaxDiffOp{}
}

func (op *sparsemaxDiffOp) Arity() int {
	return 2
}

func (op *sparsemaxDiffOp) ReturnsPtr() bool { return false }

func (op *sparsemaxDiffOp) CallsExtern() bool { return false }

func (op *sparsemaxDiffOp) WriteHash(h hash.Hash) {
	fmt.Fprintf(h, "SparsemaxDiff{}()")
}

func (op *sparsemaxDiffOp) Hashcode() uint32 { return simpleHash(op) }

func (op *sparsemaxDiffOp) String() string {
	return "SparsemaxDiff{}()"
}

func (op *sparsemaxDiffOp) InferShape(inputs ...DimSizer) (tensor.Shape, error) {
	s := inputs[0].(tensor.Shape).Clone()

	return s, nil
}

func (op *sparsemaxDiffOp) Type() hm.Type {
	a := hm.TypeVariable('a')
	return hm.NewFnType(a, a, a)
}

func (op *sparsemaxDiffOp) OverwritesInput() int { return -1 }

func (op *sparsemaxDiffOp) checkInput(inputs ...Value) (*tensor.Dense, *tensor.Dense, error) {
	if err := checkArity(op, len(inputs)); err != nil {
		return nil, nil, err
	}

	var (
		in *tensor.Dense

		gradient *tensor.Dense
		ok       bool
	)

	switch t := inputs[0].(type) {
	case *dualValue:
		if in, ok = t.Value.(*tensor.Dense); !ok {
			return nil, nil, errors.Errorf("input should be a tensor.Tensor, got %T", inputs[0])
		}
	case *tensor.Dense:
		in = t
	default:
		return nil, nil, errors.Errorf("input type is not supported, got %T", inputs[0])
	}

	switch t := inputs[1].(type) {
	case *dualValue:
		if gradient, ok = t.Value.(*tensor.Dense); !ok {
			return nil, nil, errors.Errorf("gradient should be a tensor, got %T", inputs[1])
		}
	case *tensor.Dense:
		gradient = t
	default:
		return nil, nil, errors.Errorf("gradient type is not supported, got %T", inputs[1])
	}

	return in, gradient, nil
}

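// mul multiplies a and b elementwise when their ranks match, and falls back
// to the outer product when they differ.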
func (op *sparsemaxDiffOp) mul(a tensor.Tensor, b tensor.Tensor) (tensor.Tensor, error) {
	if a.Dims() != b.Dims() {
		return tensor.Outer(a, b)
	}

	return tensor.Mul(a, b)
}

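// Do computes the sparsemax Jacobian-vector product. Writing s for the 0/1
// indicator of the support (the non-zero entries of the forward output) and
// v for the incoming gradient, each row of the result is
// s * (v - mean of v over the support), which is the gradient form derived
// in the sparsemax paper linked above.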
func (op *sparsemaxDiffOp) Do(inputs ...Value) (Value, error) {
	inputTensor, gradTensor, err := op.checkInput(inputs...)
	if err != nil {
		return nil, fmt.Errorf("Can't check SparsemaxDiff input: %w", err)
	}

	if inputTensor.Size() != gradTensor.Size() {
		return nil, fmt.Errorf("sparsemaxDiffOp.Do input sizes must be equal")
	}

	var zero interface{}

	if inputTensor.Dtype() == tensor.Float32 {
		zero = float32(0.0)
	} else {
		zero = float64(0.0)
	}

	nonZeros, err := inputTensor.ElNeScalar(zero, false, tensor.AsSameType())
	if err != nil {
		return nil, fmt.Errorf("sparsemaxDiffOp.Do failed to get non-zeros: %w", err)
	}

	mul, err := op.mul(nonZeros, gradTensor)
	if err != nil {
		return nil, fmt.Errorf("sparsemaxDiffOp.Do failed to mul grad tensor: %w", err)
	}

	a, err := tensor.Sum(mul, 1)
	if err != nil {
		return nil, err
	}

	b, err := tensor.Sum(nonZeros, 1)
	if err != nil {
		return nil, err
	}

	sum, err := tensor.Div(a, b)
	if err != nil {
		return nil, err
	}

	if sum.Dims() == 1 && gradTensor.Dims() == 2 {
		err := sum.Reshape(sum.Shape()[0], 1)
		if err != nil {
			return nil, err
		}

		// broadcast the per-row mean back to the gradient's shape
		sum, err = tensor.Repeat(sum, 1, gradTensor.Shape()[1])
		if err != nil {
			return nil, err
		}
	}

	sub, err := tensor.Sub(gradTensor, sum)
	if err != nil {
		return nil, err
	}

	result, err := op.mul(nonZeros, sub)
	if err != nil {
		return nil, err
	}

	return result, nil
}

// ensure the ops implement the Op, SDOp and ADOp interfaces
var (
	_ Op = &sparsemaxDiffOp{}

	_ Op   = &sparsemaxOp{}
	_ SDOp = &sparsemaxOp{}
	_ ADOp = &sparsemaxOp{}
)