gorgonia.org/gorgonia@v0.9.17/op_softmax_test.go

package gorgonia

import (
	"io/ioutil"
	"testing"

	"github.com/pkg/errors"
	"github.com/stretchr/testify/assert"
	"gorgonia.org/tensor"
)

var testCasesSoftMaxDo = []struct {
	input    []float64
	expected []float64
}{
	{
		[]float64{0.2094, -1.0, 0.6411, 0.0, -0.3909}, []float64{0.2382105379413429, 0.07107636737487558, 0.36681399568548617, 0.19320559786800362, 0.13069350113029174},
	},
	{
		[]float64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, []float64{7.801341612780742e-05, 0.00021206245143623275, 0.0005764455082375902, 0.0015669413501390804, 0.004259388198344144, 0.0115782175399118, 0.031472858344688034, 0.08555209892803112, 0.23255471590259755, 0.6321492583604866},
	},
	{
		[]float64{0.1, 0.1, 0.1}, []float64{0.3333333333333333, 0.3333333333333333, 0.3333333333333333},
	},
	{
		[]float64{-0.1, 0.3, -1.1, 2.7}, []float64{0.05180179352659075, 0.07727919496508177, 0.019056814854240642, 0.8518621966540868},
	},
}

func TestSoftmaxDo(t *testing.T) {
	assert := assert.New(t)

	for i, testCase := range testCasesSoftMaxDo {
		tt := tensor.New(tensor.Of(tensor.Float64), tensor.WithShape(len(testCase.input)), tensor.WithBacking(testCase.input))
		op := newSoftmaxOp(tt.Shape())

		out, err := op.Do(tt)
		assert.NoError(err, "failed test case: %d", i)
		assert.True(floatsEqual64(out.Data().([]float64), testCase.expected))
	}
}
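// The expected values in testCasesSoftMaxDo follow the usual definition
// softmax(x)[i] = exp(x[i]) / sum_j exp(x[j]). A minimal reference sketch for
// regenerating them is kept commented out below, following the convention of
// the "super large" block in TestSoftmaxKernel. refSoftmax64 is a hypothetical
// helper, not part of the test suite; enabling it would require importing
// "math". Subtracting the maximum before exponentiating leaves the result
// unchanged (the factor exp(-max) cancels) but avoids overflow.
/*
func refSoftmax64(xs []float64) []float64 {
	max := math.Inf(-1)
	for _, x := range xs {
		if x > max {
			max = x
		}
	}
	out := make([]float64, len(xs))
	var sum float64
	for i, x := range xs {
		out[i] = math.Exp(x - max) // shifted by max for numerical stability
		sum += out[i]
	}
	for i := range out {
		out[i] /= sum
	}
	return out
}
*/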
func TestSoftmaxKernel(t *testing.T) {
	// this test is used while migrating to a new softmax algorithm: the raw
	// op.do kernel must produce the same output as op.Do
	assert := assert.New(t)
	a := tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]float64{-0.1, 0.3, -1.1, 2.7, 3.14, 0.1}))
	op := newSoftmaxOp(a.Shape())
	op.axis = 0
	b0, _ := op.Do(a)
	op.axis = 1
	b1, _ := op.Do(a)

	// across axis 0
	out := make([]float64, 6)
	op.do(tensor.Shape{2, 3}, 0, a.Data().([]float64), out)
	assert.True(floatsEqual64(out, b0.Data().([]float64)))
	t.Logf("\n%v\n%v", out, b0.Data())

	// across axis 1
	out = make([]float64, 6)
	op.do(tensor.Shape{2, 3}, 1, a.Data().([]float64), out)
	assert.True(floatsEqual64(out, b1.Data().([]float64)))
	/*
		// super large
		a = tensor.New(tensor.WithShape(10, 1024, 2048, 30), tensor.WithBacking(Uniform64(-1, 1, 10, 1024, 2048, 30)))
		op = newSoftmaxOp(a.Shape())
		op.axis = 0
		b, _ := op.Do(a)

		out = make([]float64, 10*1024*2048*30)
		op.doF64s(tensor.Shape{10, 1024, 2048, 30}, 0, a.Data().([]float64), out)
		assert.True(floatsEqual64(out, b.Data().([]float64)))
	*/
}

func oldsoftmax(a *Node, axes ...int) (retVal *Node, err error) {
	aShape := a.Shape()
	axis := aShape.Dims() - 1 // default: last dim
	if a.IsColVec() || (a.IsVector() && !a.IsRowVec()) {
		axis = 0
	}

	if len(axes) > 0 {
		if axes[0] >= axis+1 || axes[0] < 0 {
			return nil, errors.Errorf("Cannot perform SoftMax on axis %d. Input has shape %v", axes[0], a.Shape())
		}
		axis = axes[0]
	}

	var exp, sum *Node
	if exp, err = Exp(a); err != nil {
		return nil, errors.Wrap(err, operationError)
	}
	if sum, err = Sum(exp, axis); err != nil {
		return nil, errors.Wrap(err, operationError)
	}

	if sum.IsScalar() {
		return HadamardDiv(exp, sum)
	}

	// reshape if necessary
	ss := sum.Shape()
	diff := exp.Shape().Dims() - ss.Dims()

	// TODO: multirank softmax
	if diff > 0 {
		newShape := tensor.Shape(tensor.BorrowInts(ss.Dims() + diff))
		copy(newShape, ss)
		copy(newShape[axis+1:], newShape[axis:])
		newShape[axis] = 1

		if sum, err = Reshape(sum, newShape); err != nil {
			return nil, errors.Wrap(err, "Failed to reshape")
		}
	}

	return BroadcastHadamardDiv(exp, sum, nil, []byte{byte(axis)})
}

func TestOld_NewSoftmax(t *testing.T) {
	a := tensor.New(tensor.WithBacking([]float64{0.1, 0.1, 0.3, 0.1, 0.4}))

	g := NewGraph()
	A := NodeFromAny(g, a, WithName("A"))
	sm := Must(SoftMax(A))
	sum := Must(Sum(sm))
	if _, err := Grad(sum, A); err != nil {
		t.Fatal(err)
	}

	h := NewGraph()
	A2 := NodeFromAny(h, a, WithName("A"))
	sm2 := Must(oldsoftmax(A2))
	sum2 := Must(Sum(sm2))
	if _, err := Grad(sum2, A2); err != nil {
		t.Fatal(err)
	}

	m1 := NewTapeMachine(g, TraceExec(), BindDualValues())
	if err := m1.RunAll(); err != nil {
		t.Fatalf("m1 %v", err)
	}

	m2 := NewTapeMachine(h, TraceExec(), BindDualValues())
	if err := m2.RunAll(); err != nil {
		t.Fatalf("m2 %v", err)
	}

	Agrad, err := A.Grad()
	if err != nil {
		t.Fatalf("No grad for A %v", err)
	}

	A2grad, err := A2.Grad()
	if err != nil {
		t.Fatalf("No grad for A2 %v", err)
	}

	t.Logf("\n%v\n%v", sm.Value(), sm2.Value())
	t.Logf("\n%v\n%v", Agrad, A2grad)

	ioutil.WriteFile("oldsm.dot", []byte(h.ToDot()), 0644)
	ioutil.WriteFile("newsm.dot", []byte(g.ToDot()), 0644)
}
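// TestOld_NewSoftmax above only logs the values and gradients of the two
// implementations rather than asserting that they match. A stricter variant
// is sketched below, kept commented out like the "super large" block in
// TestSoftmaxKernel. It is a hypothetical addition, not part of the original
// suite; it reuses the package's floatsEqual64 helper and assumes both
// forward values are float64-backed tensors.
/*
func TestOld_NewSoftmaxEqual(t *testing.T) {
	a := tensor.New(tensor.WithBacking([]float64{0.1, 0.1, 0.3, 0.1, 0.4}))

	g := NewGraph()
	A := NodeFromAny(g, a, WithName("A"))
	sm := Must(SoftMax(A))

	h := NewGraph()
	A2 := NodeFromAny(h, a, WithName("A"))
	sm2 := Must(oldsoftmax(A2))

	if err := NewTapeMachine(g).RunAll(); err != nil {
		t.Fatal(err)
	}
	if err := NewTapeMachine(h).RunAll(); err != nil {
		t.Fatal(err)
	}

	// the forward passes of the old and new implementations must agree
	if !floatsEqual64(sm.Value().Data().([]float64), sm2.Value().Data().([]float64)) {
		t.Errorf("old and new softmax disagree:\n%v\n%v", sm.Value(), sm2.Value())
	}
}
*/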
func BenchmarkSoftmaxLargeOldAxis0(b *testing.B) {
	b.StopTimer()
	a := tensor.New(tensor.WithShape(10, 1024, 2048, 30), tensor.WithBacking(Uniform64(-1, 1, 10, 1024, 2048, 30)))
	op := newSoftmaxOp(a.Shape())
	op.axis = 0
	var v Value

	b.ResetTimer()
	b.StartTimer()
	for i := 0; i < b.N; i++ {
		v, _ = op.Do(a)
	}
	_ = v
}

func BenchmarkSoftmaxLargeNewAxis0(b *testing.B) {
	b.StopTimer()
	a := tensor.New(tensor.WithShape(10, 1024, 2048, 30), tensor.WithBacking(Uniform64(-1, 1, 10, 1024, 2048, 30)))
	op := newSoftmaxOp(a.Shape())
	op.axis = 0
	out := make([]float64, len(a.Data().([]float64)))

	b.ResetTimer()
	b.StartTimer()
	for i := 0; i < b.N; i++ {
		op.do(a.Shape(), 0, a.Data().([]float64), out)
	}
}

func BenchmarkSoftmaxMedOldAxis0(b *testing.B) {
	b.StopTimer()
	a := tensor.New(tensor.WithShape(1200, 2500), tensor.WithBacking(Uniform64(-1, 1, 1200, 2500)))
	op := newSoftmaxOp(a.Shape())
	op.axis = 0
	var v Value

	b.ResetTimer()
	b.StartTimer()
	for i := 0; i < b.N; i++ {
		v, _ = op.Do(a)
	}
	_ = v
}

func BenchmarkSoftmaxMedNewAxis0(b *testing.B) {
	b.StopTimer()
	a := tensor.New(tensor.WithShape(1200, 2500), tensor.WithBacking(Uniform64(-1, 1, 1200, 2500)))
	op := newSoftmaxOp(a.Shape())
	op.axis = 0
	out := make([]float64, len(a.Data().([]float64)))

	b.ResetTimer()
	b.StartTimer()
	for i := 0; i < b.N; i++ {
		op.do(a.Shape(), 0, a.Data().([]float64), out)
	}
}
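// The Large benchmarks above allocate a 10×1024×2048×30 float64 tensor, which
// is roughly 4.7 GiB of backing data, so they need a machine with plenty of
// memory; the Med variants are a lighter alternative. The Old benchmarks go
// through op.Do, which allocates a fresh output each iteration, while the New
// ones exercise the raw op.do kernel writing into a caller-provided buffer.
// They can be compared with the standard Go tooling, e.g.:
//
//	go test -run '^$' -bench 'Softmax(Large|Med)'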