gorgonia.org/gorgonia@v0.9.17/example_symdiff_test.go (about)

     1  package gorgonia_test
     2  
     3  import (
     4  	"fmt"
     5  	"log"
     6  
     7  	. "gorgonia.org/gorgonia"
     8  )
     9  
    10  // SymbolicDiff showcases symbolic differentiation
    11  func Example_symbolicDiff() {
    12  	g := NewGraph()
    13  
    14  	var x, y, z *Node
    15  	var err error
    16  
    17  	// define the expression
    18  	x = NewScalar(g, Float64, WithName("x"))
    19  	y = NewScalar(g, Float64, WithName("y"))
    20  	if z, err = Add(x, y); err != nil {
    21  		log.Fatal(err)
    22  	}
    23  
    24  	// symbolically differentiate z with regards to x and y
    25  	// this adds the gradient nodes to the graph g
    26  	var grads Nodes
    27  	if grads, err = Grad(z, x, y); err != nil {
    28  		log.Fatal(err)
    29  	}
    30  
    31  	// create a VM to run the program on
    32  	machine := NewTapeMachine(g)
    33  	defer machine.Close()
    34  
    35  	// set initial values then run
    36  	Let(x, 2.0)
    37  	Let(y, 2.5)
    38  	if err = machine.RunAll(); err != nil {
    39  		log.Fatal(err)
    40  	}
    41  
    42  	fmt.Printf("z: %v\n", z.Value())
    43  	if xgrad, err := x.Grad(); err == nil {
    44  		fmt.Printf("dz/dx: %v | %v\n", xgrad, grads[0].Value())
    45  	}
    46  
    47  	if ygrad, err := y.Grad(); err == nil {
    48  		fmt.Printf("dz/dy: %v | %v\n", ygrad, grads[1].Value())
    49  	}
    50  
    51  	// Output:
    52  	// z: 4.5
    53  	// dz/dx: 1 | 1
    54  	// dz/dy: 1 | 1
    55  }