gorgonia.org/gorgonia@v0.9.17/example_operations_test.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"fmt"
     5  	"strings"
     6  
     7  	"gorgonia.org/tensor"
     8  )
     9  
    10  func ExampleSoftMax() {
    11  	g := NewGraph()
    12  	t := tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]float64{1, 3, 2, 3, 2, 1}))
    13  	u := t.Clone().(*tensor.Dense)
    14  	v := tensor.New(tensor.WithShape(2, 2, 3), tensor.WithBacking([]float64{
    15  		1, 3, 2,
    16  		4, 2, 1,
    17  
    18  		3, 5, 3,
    19  		2, 1, 5,
    20  	}))
    21  
    22  	a := NodeFromAny(g, t, WithName("a"))
    23  	b := NodeFromAny(g, u, WithName("b"))
    24  	c := NodeFromAny(g, v, WithName("c"))
    25  
    26  	sm1 := Must(SoftMax(a))
    27  	sm0 := Must(SoftMax(b, 0))
    28  	sm := Must(SoftMax(c))
    29  	m := NewTapeMachine(g)
    30  	if err := m.RunAll(); err != nil {
    31  		panic(err)
    32  	}
    33  
    34  	fmt.Printf("a:\n%v\nsoftmax(a) - along last axis (default behaviour):\n%1.2f", a.Value(), sm1.Value())
    35  	fmt.Printf("b:\n%v\nsoftmax(b) - along axis 0:\n%1.2f", b.Value(), sm0.Value())
    36  
    37  	tmp := fmt.Sprintf("c %v:\n%v\nsoftmax(c) - along last axis (default behaviour) %v:\n%1.2f", c.Value().Shape(), c.Value(), sm.Value().Shape(), sm.Value())
    38  
    39  	fmt.Println(strings.Replace(tmp, "\n\n\n", "\n\n", -1))
    40  
    41  	// the requirement to use tmp and strings.Replace is because when Go runs example tests, it strips excess newlines.
    42  
    43  	// Output:
    44  	// a:
    45  	// ⎡1  3  2⎤
    46  	// ⎣3  2  1⎦
    47  	//
    48  	// softmax(a) - along last axis (default behaviour):
    49  	// ⎡0.09  0.67  0.24⎤
    50  	// ⎣0.67  0.24  0.09⎦
    51  	// b:
    52  	// ⎡1  3  2⎤
    53  	// ⎣3  2  1⎦
    54  	//
    55  	// softmax(b) - along axis 0:
    56  	// ⎡0.12  0.73  0.73⎤
    57  	// ⎣0.88  0.27  0.27⎦
    58  	// c (2, 2, 3):
    59  	// ⎡1  3  2⎤
    60  	// ⎣4  2  1⎦
    61  	//
    62  	// ⎡3  5  3⎤
    63  	// ⎣2  1  5⎦
    64  	//
    65  	//
    66  	// softmax(c) - along last axis (default behaviour) (2, 2, 3):
    67  	// ⎡0.09  0.67  0.24⎤
    68  	// ⎣0.84  0.11  0.04⎦
    69  	//
    70  	// ⎡0.11  0.79  0.11⎤
    71  	// ⎣0.05  0.02  0.94⎦
    72  
    73  }
    74  
    75  func ExampleConcat() {
    76  	g := NewGraph()
    77  	x := NewTensor(g, Float64, 4, WithShape(2, 3, 4, 5), WithInit(RangedFrom(0)), WithName("x"))
    78  	y := NewTensor(g, Float64, 4, WithShape(2, 3, 4, 5), WithInit(RangedFrom(120)), WithName("y"))
    79  
    80  	z, err := Concat(2, x, y)
    81  	if err != nil {
    82  		panic(err)
    83  	}
    84  
    85  	m := NewTapeMachine(g)
    86  	if err := m.RunAll(); err != nil {
    87  		panic(err)
    88  	}
    89  	tmp := fmt.Sprintf("z %v\n%v", z.Value().Shape(), z.Value())
    90  	fmt.Println(strings.Replace(tmp, "\n\n", "\n", -1)) // this is because
    91  
    92  	// Output:
    93  	//z (2, 3, 8, 5)
    94  	//⎡  0    1    2    3    4⎤
    95  	//⎢  5    6    7    8    9⎥
    96  	//⎢ 10   11   12   13   14⎥
    97  	//⎢ 15   16   17   18   19⎥
    98  	//⎢120  121  122  123  124⎥
    99  	//⎢125  126  127  128  129⎥
   100  	//⎢130  131  132  133  134⎥
   101  	//⎣135  136  137  138  139⎦
   102  	//
   103  	//
   104  	//⎡ 20   21   22   23   24⎤
   105  	//⎢ 25   26   27   28   29⎥
   106  	//⎢ 30   31   32   33   34⎥
   107  	//⎢ 35   36   37   38   39⎥
   108  	//⎢140  141  142  143  144⎥
   109  	//⎢145  146  147  148  149⎥
   110  	//⎢150  151  152  153  154⎥
   111  	//⎣155  156  157  158  159⎦
   112  	//
   113  	//
   114  	//⎡ 40   41   42   43   44⎤
   115  	//⎢ 45   46   47   48   49⎥
   116  	//⎢ 50   51   52   53   54⎥
   117  	//⎢ 55   56   57   58   59⎥
   118  	//⎢160  161  162  163  164⎥
   119  	//⎢165  166  167  168  169⎥
   120  	//⎢170  171  172  173  174⎥
   121  	//⎣175  176  177  178  179⎦
   122  	//
   123  	//
   124  	//⎡ 60   61   62   63   64⎤
   125  	//⎢ 65   66   67   68   69⎥
   126  	//⎢ 70   71   72   73   74⎥
   127  	//⎢ 75   76   77   78   79⎥
   128  	//⎢180  181  182  183  184⎥
   129  	//⎢185  186  187  188  189⎥
   130  	//⎢190  191  192  193  194⎥
   131  	//⎣195  196  197  198  199⎦
   132  	//
   133  	//
   134  	//⎡ 80   81   82   83   84⎤
   135  	//⎢ 85   86   87   88   89⎥
   136  	//⎢ 90   91   92   93   94⎥
   137  	//⎢ 95   96   97   98   99⎥
   138  	//⎢200  201  202  203  204⎥
   139  	//⎢205  206  207  208  209⎥
   140  	//⎢210  211  212  213  214⎥
   141  	//⎣215  216  217  218  219⎦
   142  	//
   143  	//
   144  	//⎡100  101  102  103  104⎤
   145  	//⎢105  106  107  108  109⎥
   146  	//⎢110  111  112  113  114⎥
   147  	//⎢115  116  117  118  119⎥
   148  	//⎢220  221  222  223  224⎥
   149  	//⎢225  226  227  228  229⎥
   150  	//⎢230  231  232  233  234⎥
   151  	//⎣235  236  237  238  239⎦
   152  }
   153  
   154  func ExampleUnconcat() {
   155  	g := NewGraph()
   156  	x := NewTensor(g, Float64, 4, WithShape(2, 3, 4, 5), WithInit(RangedFrom(0)), WithName("x"))
   157  	y := NewTensor(g, Float64, 4, WithShape(2, 3, 4, 5), WithInit(RangedFrom(120)), WithName("y"))
   158  
   159  	z, err := Concat(2, x, y)
   160  	if err != nil {
   161  		panic(err)
   162  	}
   163  
   164  	unconcats, err := Unconcat(z, 2, 2)
   165  	if err != nil {
   166  		panic(err)
   167  	}
   168  	a, b := unconcats[0], unconcats[1]
   169  
   170  	m := NewTapeMachine(g)
   171  	if err := m.RunAll(); err != nil {
   172  		panic(err)
   173  	}
   174  	tmp := fmt.Sprintf("a %v\n%v\nb %v\n%v", a.Value().Shape(), a.Value(), b.Value().Shape(), b.Value())
   175  	fmt.Println(strings.Replace(tmp, "\n\n", "\n", -1))
   176  
   177  	// Output:
   178  	// a (2, 3, 4, 5)
   179  	// ⎡  0    1    2    3    4⎤
   180  	// ⎢  5    6    7    8    9⎥
   181  	// ⎢ 10   11   12   13   14⎥
   182  	// ⎣ 15   16   17   18   19⎦
   183  	//
   184  	//
   185  	// ⎡ 20   21   22   23   24⎤
   186  	// ⎢ 25   26   27   28   29⎥
   187  	// ⎢ 30   31   32   33   34⎥
   188  	// ⎣ 35   36   37   38   39⎦
   189  	//
   190  	//
   191  	// ⎡ 40   41   42   43   44⎤
   192  	// ⎢ 45   46   47   48   49⎥
   193  	// ⎢ 50   51   52   53   54⎥
   194  	// ⎣ 55   56   57   58   59⎦
   195  	//
   196  	//
   197  	// ⎡ 60   61   62   63   64⎤
   198  	// ⎢ 65   66   67   68   69⎥
   199  	// ⎢ 70   71   72   73   74⎥
   200  	// ⎣ 75   76   77   78   79⎦
   201  	//
   202  	//
   203  	// ⎡ 80   81   82   83   84⎤
   204  	// ⎢ 85   86   87   88   89⎥
   205  	// ⎢ 90   91   92   93   94⎥
   206  	// ⎣ 95   96   97   98   99⎦
   207  	//
   208  	//
   209  	// ⎡100  101  102  103  104⎤
   210  	// ⎢105  106  107  108  109⎥
   211  	// ⎢110  111  112  113  114⎥
   212  	// ⎣115  116  117  118  119⎦
   213  	//
   214  	//
   215  	//
   216  	// b (2, 3, 4, 5)
   217  	// ⎡120  121  122  123  124⎤
   218  	// ⎢125  126  127  128  129⎥
   219  	// ⎢130  131  132  133  134⎥
   220  	// ⎣135  136  137  138  139⎦
   221  	//
   222  	//
   223  	// ⎡140  141  142  143  144⎤
   224  	// ⎢145  146  147  148  149⎥
   225  	// ⎢150  151  152  153  154⎥
   226  	// ⎣155  156  157  158  159⎦
   227  	//
   228  	//
   229  	// ⎡160  161  162  163  164⎤
   230  	// ⎢165  166  167  168  169⎥
   231  	// ⎢170  171  172  173  174⎥
   232  	// ⎣175  176  177  178  179⎦
   233  	//
   234  	//
   235  	// ⎡180  181  182  183  184⎤
   236  	// ⎢185  186  187  188  189⎥
   237  	// ⎢190  191  192  193  194⎥
   238  	// ⎣195  196  197  198  199⎦
   239  	//
   240  	//
   241  	// ⎡200  201  202  203  204⎤
   242  	// ⎢205  206  207  208  209⎥
   243  	// ⎢210  211  212  213  214⎥
   244  	// ⎣215  216  217  218  219⎦
   245  	//
   246  	//
   247  	// ⎡220  221  222  223  224⎤
   248  	// ⎢225  226  227  228  229⎥
   249  	// ⎢230  231  232  233  234⎥
   250  	// ⎣235  236  237  238  239⎦
   251  }