gorgonia.org/gorgonia@v0.9.17/op_softmax_test.go

package gorgonia

import (
	"io/ioutil"
	"testing"

	"github.com/pkg/errors"
	"github.com/stretchr/testify/assert"
	"gorgonia.org/tensor"
)

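// testCasesSoftMaxDo pairs input vectors with their expected softmax outputs.
// Every expected slice sums to 1; the uniform input {0.1, 0.1, 0.1}, for
// instance, maps to exactly 1/3 per element.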
var testCasesSoftMaxDo = []struct {
	input    []float64
	expected []float64
}{
	{
		[]float64{0.2094, -1.0, 0.6411, 0.0, -0.3909}, []float64{0.2382105379413429, 0.07107636737487558, 0.36681399568548617, 0.19320559786800362, 0.13069350113029174},
	},
	{
		[]float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, []float64{7.801341612780742e-05, 0.00021206245143623275, 0.0005764455082375902, 0.0015669413501390804, 0.004259388198344144, 0.0115782175399118, 0.031472858344688034, 0.08555209892803112, 0.23255471590259755, 0.6321492583604866},
	},
	{
		[]float64{0.1, 0.1, 0.1}, []float64{0.3333333333333333, 0.3333333333333333, 0.3333333333333333},
	},
	{
		[]float64{-0.1, 0.3, -1.1, 2.7}, []float64{0.05180179352659075, 0.07727919496508177, 0.019056814854240642, 0.8518621966540868},
	},
}

func TestSoftmaxDo(t *testing.T) {
	assert := assert.New(t)

	for i, testCase := range testCasesSoftMaxDo {
		tt := tensor.New(tensor.Of(tensor.Float64), tensor.WithShape(len(testCase.input)), tensor.WithBacking(testCase.input))
		op := newSoftmaxOp(tt.Shape())

		out, err := op.Do(tt)
		assert.NoError(err, "failed test case: %d", i)
		assert.True(floatsEqual64(out.Data().([]float64), testCase.expected), "failed test case: %d", i)
	}
}

func TestSoftmaxKernel(t *testing.T) {
	// This test was written while migrating softmax to a new kernel: it checks
	// that op.do produces the same results as op.Do along both axes.
	assert := assert.New(t)
	a := tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]float64{-0.1, 0.3, -1.1, 2.7, 3.14, 0.1}))
	op := newSoftmaxOp(a.Shape())
	op.axis = 0
	b0, _ := op.Do(a)
	op.axis = 1
	b1, _ := op.Do(a)

	// across axis 0
	out := make([]float64, 6)
	op.do(tensor.Shape{2, 3}, 0, a.Data().([]float64), out)
	assert.True(floatsEqual64(out, b0.Data().([]float64)))
	t.Logf("\n%v\n%v", out, b0.Data())

	// across axis 1
	out = make([]float64, 6)
	op.do(tensor.Shape{2, 3}, 1, a.Data().([]float64), out)
	assert.True(floatsEqual64(out, b1.Data().([]float64)))
	/*
		// super large
		a = tensor.New(tensor.WithShape(10, 1024, 2048, 30), tensor.WithBacking(Uniform64(-1, 1, 10, 1024, 2048, 30)))
		op = newSoftmaxOp(a.Shape())
		op.axis = 0
		b, _ := op.Do(a)

		out = make([]float64, 10*1024*2048*30)
		op.doF64s(tensor.Shape{10, 1024, 2048, 30}, 0, a.Data().([]float64), out)
		assert.True(floatsEqual64(out, b.Data().([]float64)))
	*/
}

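// oldsoftmax is the previous, graph-level softmax implementation, kept so that
// TestOld_NewSoftmax can compare its values and gradients with the current
// softmaxOp. It computes exp(a), sums along the chosen axis, and divides,
// broadcasting the sum back over that axis when the ranks differ.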
func oldsoftmax(a *Node, axes ...int) (retVal *Node, err error) {
	aShape := a.Shape()
	axis := aShape.Dims() - 1 // default: last dim
	if a.IsColVec() || (a.IsVector() && !a.IsRowVec()) {
		axis = 0
	}

	if len(axes) > 0 {
		if axes[0] >= axis+1 || axes[0] < 0 {
			return nil, errors.Errorf("Cannot perform SoftMax on axis %d. Input has shape %v", axes[0], a.Shape())
		}
		axis = axes[0]
	}

	var exp, sum *Node
	if exp, err = Exp(a); err != nil {
		return nil, errors.Wrap(err, operationError)
	}
	if sum, err = Sum(exp, axis); err != nil {
		return nil, errors.Wrap(err, operationError)
	}

	if sum.IsScalar() {
		return HadamardDiv(exp, sum)
	}

	// reshape if necessary
	ss := sum.Shape()
	diff := exp.Shape().Dims() - ss.Dims()

	// TODO: multirank softmax
	if diff > 0 {
		newShape := tensor.Shape(tensor.BorrowInts(ss.Dims() + diff))
		copy(newShape, ss)
		copy(newShape[axis+1:], newShape[axis:])
		newShape[axis] = 1

		if sum, err = Reshape(sum, newShape); err != nil {
			return nil, errors.Wrap(err, "Failed to reshape")
		}
	}

	return BroadcastHadamardDiv(exp, sum, nil, []byte{byte(axis)})
}

func TestOld_NewSoftmax(t *testing.T) {
	a := tensor.New(tensor.WithBacking([]float64{0.1, 0.1, 0.3, 0.1, 0.4}))

	g := NewGraph()
	A := NodeFromAny(g, a, WithName("A"))
	sm := Must(SoftMax(A))
	sum := Must(Sum(sm))
	if _, err := Grad(sum, A); err != nil {
		t.Fatal(err)
	}

	h := NewGraph()
	A2 := NodeFromAny(h, a, WithName("A"))
	sm2 := Must(oldsoftmax(A2))
	sum2 := Must(Sum(sm2))
	if _, err := Grad(sum2, A2); err != nil {
		t.Fatal(err)
	}

	m1 := NewTapeMachine(g, TraceExec(), BindDualValues())
	if err := m1.RunAll(); err != nil {
		t.Fatalf("m1 %v", err)
	}

	m2 := NewTapeMachine(h, TraceExec(), BindDualValues())
	if err := m2.RunAll(); err != nil {
		t.Fatalf("m2 %v", err)
	}

	Agrad, err := A.Grad()
	if err != nil {
		t.Fatalf("No grad for A %v", err)
	}

	A2grad, err := A2.Grad()
	if err != nil {
		t.Fatalf("No grad for A2 %v", err)
	}

	t.Logf("\n%v\n%v", sm.Value(), sm2.Value())
	t.Logf("\n%v\n%v", Agrad, A2grad)

	// write both graphs out as graphviz dot files
	ioutil.WriteFile("oldsm.dot", []byte(h.ToDot()), 0644)
	ioutil.WriteFile("newsm.dot", []byte(g.ToDot()), 0644)
}

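// The benchmarks below compare the original op.Do path ("Old") with the op.do
// kernel ("New") along axis 0, once on a large 4-D tensor and once on a medium
// 2-D one. To run just this group, the standard go test invocation is roughly
// `go test -run XXX -bench Softmax`.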
func BenchmarkSoftmaxLargeOldAxis0(b *testing.B) {
	b.StopTimer()
	a := tensor.New(tensor.WithShape(10, 1024, 2048, 30), tensor.WithBacking(Uniform64(-1, 1, 10, 1024, 2048, 30)))
	op := newSoftmaxOp(a.Shape())
	op.axis = 0
	var v Value

	b.ResetTimer()
	b.StartTimer()
	for i := 0; i < b.N; i++ {
		v, _ = op.Do(a)
	}
	_ = v
}

func BenchmarkSoftmaxLargeNewAxis0(b *testing.B) {
	b.StopTimer()
	a := tensor.New(tensor.WithShape(10, 1024, 2048, 30), tensor.WithBacking(Uniform64(-1, 1, 10, 1024, 2048, 30)))
	op := newSoftmaxOp(a.Shape())
	op.axis = 0
	out := make([]float64, len(a.Data().([]float64)))

	b.ResetTimer()
	b.StartTimer()
	for i := 0; i < b.N; i++ {
		op.do(a.Shape(), 0, a.Data().([]float64), out)
	}
}

func BenchmarkSoftmaxMedOldAxis0(b *testing.B) {
	b.StopTimer()
	a := tensor.New(tensor.WithShape(1200, 2500), tensor.WithBacking(Uniform64(-1, 1, 1200, 2500)))
	op := newSoftmaxOp(a.Shape())
	op.axis = 0
	var v Value

	b.ResetTimer()
	b.StartTimer()
	for i := 0; i < b.N; i++ {
		v, _ = op.Do(a)
	}
	_ = v
}

func BenchmarkSoftmaxMedNewAxis0(b *testing.B) {
	b.StopTimer()
	a := tensor.New(tensor.WithShape(1200, 2500), tensor.WithBacking(Uniform64(-1, 1, 1200, 2500)))
	op := newSoftmaxOp(a.Shape())
	op.axis = 0
	out := make([]float64, len(a.Data().([]float64)))

	b.ResetTimer()
	b.StartTimer()
	for i := 0; i < b.N; i++ {
		op.do(a.Shape(), 0, a.Data().([]float64), out)
	}
}