gorgonia.org/gorgonia@v0.9.17/example_tensordot_test.go (about)

     1  package gorgonia
     2  
     3  import (
     4  	"fmt"
     5  )
     6  
     7  /*
     8  func ExampleTensordot_scalar() {
     9  	// Scalars
    10  	g := NewGraph()
    11  	a := NewScalar(g, Float64, WithValue(2.0), WithName("a"))
    12  	b := NewScalar(g, Float64, WithValue(21.0), WithName("b"))
    13  	c, err := Tensordot([]int{0}, []int{0}, a, b)
    14  	if err != nil {
    15  		fmt.Printf("Cannot call Tensordot. Error: %v\n", err)
    16  		return
    17  	}
    18  
    19  	vm := NewTapeMachine(g)
    20  	if err := vm.RunAll(); err != nil {
    21  		fmt.Printf("Cannot perform scalars. Error %v\n", err)
    22  	}
    23  	fmt.Printf("c: %v (%v) of %v", c.Value(), c.Value().Dtype(), c.Value().Shape())
    24  
    25  	// Output:
    26  	//...
    27  }
    28  */
    29  func ExampleTensordot_vectors() {
    30  	g := NewGraph()
    31  	a := NewVector(g, Float64, WithName("a"), WithShape(2), WithInit(RangedFrom(2)))
    32  	b := NewVector(g, Float64, WithName("b"), WithShape(2), WithInit(RangedFrom(21)))
    33  
    34  	c, err := Tensordot([]int{0}, []int{0}, a, b)
    35  	if err != nil {
    36  		fmt.Printf("Cannot call Tensordot. Error: %v\n", err)
    37  		return
    38  	}
    39  
    40  	vm := NewTapeMachine(g)
    41  	if err := vm.RunAll(); err != nil {
    42  		fmt.Printf("Cannot perform tensordot on vectors. Error %v\n", err)
    43  	}
    44  	fmt.Printf("a %v b %v ", a.Value(), b.Value())
    45  	fmt.Printf("c: %v (%v) of %v", c.Value(), c.Type(), c.Value().Shape())
    46  
    47  	// Output:
    48  	// a [2  3] b [21  22] c: [108] (float64) of (1)
    49  
    50  }