gorgonia.org/gorgonia@v0.9.17/example_operations_test.go (about) 1 package gorgonia 2 3 import ( 4 "fmt" 5 "strings" 6 7 "gorgonia.org/tensor" 8 ) 9 10 func ExampleSoftMax() { 11 g := NewGraph() 12 t := tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]float64{1, 3, 2, 3, 2, 1})) 13 u := t.Clone().(*tensor.Dense) 14 v := tensor.New(tensor.WithShape(2, 2, 3), tensor.WithBacking([]float64{ 15 1, 3, 2, 16 4, 2, 1, 17 18 3, 5, 3, 19 2, 1, 5, 20 })) 21 22 a := NodeFromAny(g, t, WithName("a")) 23 b := NodeFromAny(g, u, WithName("b")) 24 c := NodeFromAny(g, v, WithName("c")) 25 26 sm1 := Must(SoftMax(a)) 27 sm0 := Must(SoftMax(b, 0)) 28 sm := Must(SoftMax(c)) 29 m := NewTapeMachine(g) 30 if err := m.RunAll(); err != nil { 31 panic(err) 32 } 33 34 fmt.Printf("a:\n%v\nsoftmax(a) - along last axis (default behaviour):\n%1.2f", a.Value(), sm1.Value()) 35 fmt.Printf("b:\n%v\nsoftmax(b) - along axis 0:\n%1.2f", b.Value(), sm0.Value()) 36 37 tmp := fmt.Sprintf("c %v:\n%v\nsoftmax(c) - along last axis (default behaviour) %v:\n%1.2f", c.Value().Shape(), c.Value(), sm.Value().Shape(), sm.Value()) 38 39 fmt.Println(strings.Replace(tmp, "\n\n\n", "\n\n", -1)) 40 41 // the requirement to use tmp and strings.Replace is because when Go runs example tests, it strips excess newlines. 42 43 // Output: 44 // a: 45 // ⎡1 3 2⎤ 46 // ⎣3 2 1⎦ 47 // 48 // softmax(a) - along last axis (default behaviour): 49 // ⎡0.09 0.67 0.24⎤ 50 // ⎣0.67 0.24 0.09⎦ 51 // b: 52 // ⎡1 3 2⎤ 53 // ⎣3 2 1⎦ 54 // 55 // softmax(b) - along axis 0: 56 // ⎡0.12 0.73 0.73⎤ 57 // ⎣0.88 0.27 0.27⎦ 58 // c (2, 2, 3): 59 // ⎡1 3 2⎤ 60 // ⎣4 2 1⎦ 61 // 62 // ⎡3 5 3⎤ 63 // ⎣2 1 5⎦ 64 // 65 // 66 // softmax(c) - along last axis (default behaviour) (2, 2, 3): 67 // ⎡0.09 0.67 0.24⎤ 68 // ⎣0.84 0.11 0.04⎦ 69 // 70 // ⎡0.11 0.79 0.11⎤ 71 // ⎣0.05 0.02 0.94⎦ 72 73 } 74 75 func ExampleConcat() { 76 g := NewGraph() 77 x := NewTensor(g, Float64, 4, WithShape(2, 3, 4, 5), WithInit(RangedFrom(0)), WithName("x")) 78 y := NewTensor(g, Float64, 4, WithShape(2, 3, 4, 5), WithInit(RangedFrom(120)), WithName("y")) 79 80 z, err := Concat(2, x, y) 81 if err != nil { 82 panic(err) 83 } 84 85 m := NewTapeMachine(g) 86 if err := m.RunAll(); err != nil { 87 panic(err) 88 } 89 tmp := fmt.Sprintf("z %v\n%v", z.Value().Shape(), z.Value()) 90 fmt.Println(strings.Replace(tmp, "\n\n", "\n", -1)) // this is because 91 92 // Output: 93 //z (2, 3, 8, 5) 94 //⎡ 0 1 2 3 4⎤ 95 //⎢ 5 6 7 8 9⎥ 96 //⎢ 10 11 12 13 14⎥ 97 //⎢ 15 16 17 18 19⎥ 98 //⎢120 121 122 123 124⎥ 99 //⎢125 126 127 128 129⎥ 100 //⎢130 131 132 133 134⎥ 101 //⎣135 136 137 138 139⎦ 102 // 103 // 104 //⎡ 20 21 22 23 24⎤ 105 //⎢ 25 26 27 28 29⎥ 106 //⎢ 30 31 32 33 34⎥ 107 //⎢ 35 36 37 38 39⎥ 108 //⎢140 141 142 143 144⎥ 109 //⎢145 146 147 148 149⎥ 110 //⎢150 151 152 153 154⎥ 111 //⎣155 156 157 158 159⎦ 112 // 113 // 114 //⎡ 40 41 42 43 44⎤ 115 //⎢ 45 46 47 48 49⎥ 116 //⎢ 50 51 52 53 54⎥ 117 //⎢ 55 56 57 58 59⎥ 118 //⎢160 161 162 163 164⎥ 119 //⎢165 166 167 168 169⎥ 120 //⎢170 171 172 173 174⎥ 121 //⎣175 176 177 178 179⎦ 122 // 123 // 124 //⎡ 60 61 62 63 64⎤ 125 //⎢ 65 66 67 68 69⎥ 126 //⎢ 70 71 72 73 74⎥ 127 //⎢ 75 76 77 78 79⎥ 128 //⎢180 181 182 183 184⎥ 129 //⎢185 186 187 188 189⎥ 130 //⎢190 191 192 193 194⎥ 131 //⎣195 196 197 198 199⎦ 132 // 133 // 134 //⎡ 80 81 82 83 84⎤ 135 //⎢ 85 86 87 88 89⎥ 136 //⎢ 90 91 92 93 94⎥ 137 //⎢ 95 96 97 98 99⎥ 138 //⎢200 201 202 203 204⎥ 139 //⎢205 206 207 208 209⎥ 140 //⎢210 211 212 213 214⎥ 141 //⎣215 216 217 218 219⎦ 142 // 143 // 144 //⎡100 101 102 103 104⎤ 145 //⎢105 106 107 108 109⎥ 146 //⎢110 111 112 113 114⎥ 147 //⎢115 116 117 118 119⎥ 148 //⎢220 221 222 223 224⎥ 149 //⎢225 226 227 228 229⎥ 150 //⎢230 231 232 233 234⎥ 151 //⎣235 236 237 238 239⎦ 152 } 153 154 func ExampleUnconcat() { 155 g := NewGraph() 156 x := NewTensor(g, Float64, 4, WithShape(2, 3, 4, 5), WithInit(RangedFrom(0)), WithName("x")) 157 y := NewTensor(g, Float64, 4, WithShape(2, 3, 4, 5), WithInit(RangedFrom(120)), WithName("y")) 158 159 z, err := Concat(2, x, y) 160 if err != nil { 161 panic(err) 162 } 163 164 unconcats, err := Unconcat(z, 2, 2) 165 if err != nil { 166 panic(err) 167 } 168 a, b := unconcats[0], unconcats[1] 169 170 m := NewTapeMachine(g) 171 if err := m.RunAll(); err != nil { 172 panic(err) 173 } 174 tmp := fmt.Sprintf("a %v\n%v\nb %v\n%v", a.Value().Shape(), a.Value(), b.Value().Shape(), b.Value()) 175 fmt.Println(strings.Replace(tmp, "\n\n", "\n", -1)) 176 177 // Output: 178 // a (2, 3, 4, 5) 179 // ⎡ 0 1 2 3 4⎤ 180 // ⎢ 5 6 7 8 9⎥ 181 // ⎢ 10 11 12 13 14⎥ 182 // ⎣ 15 16 17 18 19⎦ 183 // 184 // 185 // ⎡ 20 21 22 23 24⎤ 186 // ⎢ 25 26 27 28 29⎥ 187 // ⎢ 30 31 32 33 34⎥ 188 // ⎣ 35 36 37 38 39⎦ 189 // 190 // 191 // ⎡ 40 41 42 43 44⎤ 192 // ⎢ 45 46 47 48 49⎥ 193 // ⎢ 50 51 52 53 54⎥ 194 // ⎣ 55 56 57 58 59⎦ 195 // 196 // 197 // ⎡ 60 61 62 63 64⎤ 198 // ⎢ 65 66 67 68 69⎥ 199 // ⎢ 70 71 72 73 74⎥ 200 // ⎣ 75 76 77 78 79⎦ 201 // 202 // 203 // ⎡ 80 81 82 83 84⎤ 204 // ⎢ 85 86 87 88 89⎥ 205 // ⎢ 90 91 92 93 94⎥ 206 // ⎣ 95 96 97 98 99⎦ 207 // 208 // 209 // ⎡100 101 102 103 104⎤ 210 // ⎢105 106 107 108 109⎥ 211 // ⎢110 111 112 113 114⎥ 212 // ⎣115 116 117 118 119⎦ 213 // 214 // 215 // 216 // b (2, 3, 4, 5) 217 // ⎡120 121 122 123 124⎤ 218 // ⎢125 126 127 128 129⎥ 219 // ⎢130 131 132 133 134⎥ 220 // ⎣135 136 137 138 139⎦ 221 // 222 // 223 // ⎡140 141 142 143 144⎤ 224 // ⎢145 146 147 148 149⎥ 225 // ⎢150 151 152 153 154⎥ 226 // ⎣155 156 157 158 159⎦ 227 // 228 // 229 // ⎡160 161 162 163 164⎤ 230 // ⎢165 166 167 168 169⎥ 231 // ⎢170 171 172 173 174⎥ 232 // ⎣175 176 177 178 179⎦ 233 // 234 // 235 // ⎡180 181 182 183 184⎤ 236 // ⎢185 186 187 188 189⎥ 237 // ⎢190 191 192 193 194⎥ 238 // ⎣195 196 197 198 199⎦ 239 // 240 // 241 // ⎡200 201 202 203 204⎤ 242 // ⎢205 206 207 208 209⎥ 243 // ⎢210 211 212 213 214⎥ 244 // ⎣215 216 217 218 219⎦ 245 // 246 // 247 // ⎡220 221 222 223 224⎤ 248 // ⎢225 226 227 228 229⎥ 249 // ⎢230 231 232 233 234⎥ 250 // ⎣235 236 237 238 239⎦ 251 }