gorgonia.org/gorgonia@v0.9.17/example_concurrent_training_test.go

package gorgonia_test

import (
	"fmt"
	"runtime"
	"sync"

	. "gorgonia.org/gorgonia"
	"gorgonia.org/tensor"
)

const (
	// rows      = 373127
	// cols      = 53

	// We'll use a nice, even-sized batch size instead of weird prime numbers.
	rows      = 30000
	cols      = 5
	batchSize = 100
	epochs    = 10
)

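// concurrentTrainer bundles an expression graph, its input nodes and a tape
// machine, so that each worker goroutine has its own VM and shares no state
// with the other workers.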
type concurrentTrainer struct {
	g    *ExprGraph
	x, y *Node
	vm   VM
	// cost Value

	batchSize int
	epoch     int // number of epochs done
}

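// newConcurrentTrainer builds a small graph that computes sum(xᵀ·y) for one
// batch, requests the gradients of that sum with respect to x and y, and
// compiles the graph into a TapeMachine with dual values bound to x and y.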
func newConcurrentTrainer() *concurrentTrainer {
	g := NewGraph()
	x := NewMatrix(g, Float64, WithShape(batchSize, cols), WithName("x"))
	y := NewVector(g, Float64, WithShape(batchSize), WithName("y"))
	xT := Must(Transpose(x))
	z := Must(Mul(xT, y))
	sz := Must(Sum(z))

	// Read(sz, &ct.cost)
	Grad(sz, x, y)
	vm := NewTapeMachine(g, BindDualValues(x, y))

	return &concurrentTrainer{
		g:  g,
		x:  x,
		y:  y,
		vm: vm,

		batchSize: batchSize,
		epoch:     -1,
	}
}

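// cost carries the nodes (with their gradients) produced by one batch back to
// the main goroutine, where the solver steps them.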
type cost struct {
	Nodes []ValueGrad
	VM    // placed for debugging purposes. In real-life use you can just use a channel of Nodes.

	// cost Value
}

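// train binds one batch to the trainer's input nodes, runs the tape machine
// (forward and backward), ships the resulting ValueGrads down costChan, and
// resets the VM for the next batch.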
func (t *concurrentTrainer) train(x, y Value, costChan chan cost, wg *sync.WaitGroup) {
	defer wg.Done()

	Let(t.x, x)
	Let(t.y, y)
	if err := t.vm.RunAll(); err != nil {
		panic(err)
	}

	costChan <- cost{
		[]ValueGrad{t.x, t.y},
		t.vm,
		// t.cost,
	}

	t.vm.Reset()
}

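// trainEpoch walks the batches in chunks, training up to threads batches
// concurrently (one goroutine per trainer) and then stepping the solver over
// every cost reported for that chunk.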
func trainEpoch(bs []batch, ts []*concurrentTrainer, threads int) {
	// costs := make([]float64, 0, len(bs))
	chunks := len(bs) / len(ts)
	for chunk := 0; chunk <= chunks; chunk++ {
		costChan := make(chan cost, len(bs))

		var wg sync.WaitGroup
		for i, t := range ts {
			idx := chunk*threads + i
			if idx >= len(bs) {
				break
			}
			b := bs[idx]

			wg.Add(1)
			go t.train(b.xs, b.ys, costChan, &wg)
		}
		wg.Wait()
		close(costChan)

		solver := NewVanillaSolver(WithLearnRate(0.01), WithBatchSize(batchSize))
		for cost := range costChan {
			// y := cost.Nodes[1].Value()
			// yG, _ := cost.Nodes[1].Grad()
			// c := cost.cost.Data().(float64)
			// costs = append(costs, c)
			solver.Step(cost.Nodes)
		}
	}

	// var avg float64
	// for _, c := range costs {
	// 	avg += c
	// }
	// avg /= float64(len(costs))
}

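// batch is a single slice of the training data: batchSize rows of x and the
// matching batchSize entries of y.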
type batch struct {
	xs Value
	ys Value
}

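// prep synthesises the full x and y tensors and slices them up into batches
// of batchSize rows each.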
func prep() (x, y Value, bs []batch) {
	xV := tensor.New(tensor.WithShape(rows, cols), tensor.WithBacking(tensor.Range(Float64, 0, cols*rows)))
	yV := tensor.New(tensor.WithShape(rows), tensor.WithBacking(tensor.Range(Float64, 0, rows)))

	// prep the data: x[r][c] = c*c/100, where c is the column ID, and y[r] = Σ_c x[r][c]
	xData := xV.Data().([]float64)
	yData := yV.Data().([]float64)
	for r := 0; r < rows; r++ {
		var sum float64
		for c := 0; c < cols; c++ {
			idx := r*cols + c
			fc := float64(c)
			v := fc * fc / 100
			xData[idx] = v
			sum += v
		}
		yData[r] = sum
	}

	// batch the examples up into slices of batchSize rows each
	for i := 0; i < rows; i += batchSize {
		xVS, _ := xV.Slice(S(i, i+batchSize))
		yVS, _ := yV.Slice(S(i, i+batchSize))
		b := batch{xVS, yVS}
		bs = append(bs, b)
	}
	return xV, yV, bs
}

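// concurrentTraining creates one trainer per CPU, then trains for es epochs
// with the batches of each epoch spread across those trainers.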
func concurrentTraining(xV, yV Value, bs []batch, es int) {
	threads := runtime.NumCPU()

	ts := make([]*concurrentTrainer, threads)
	for i := 0; i < threads; i++ {
		trainer := newConcurrentTrainer()
		ts[i] = trainer
		defer trainer.vm.Close()
	}

	for e := 0; e < es; e++ {
		trainEpoch(bs, ts, threads)
	}
}

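// nonConcurrentTraining is the single-goroutine baseline: one graph over the
// entire dataset, run and stepped once per epoch.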
func nonConcurrentTraining(xV, yV Value, es int) {
	g := NewGraph()
	x := NewMatrix(g, Float64, WithValue(xV))
	y := NewVector(g, Float64, WithValue(yV))
	xT := Must(Transpose(x))
	z := Must(Mul(xT, y))
	sz := Must(Sum(z))
	Grad(sz, x, y)
	vm := NewTapeMachine(g, BindDualValues(x, y))
	defer vm.Close()

	Let(x, xV)
	Let(y, yV)
	solver := NewVanillaSolver(WithLearnRate(0.01), WithBatchSize(batchSize))
	for i := 0; i < es; i++ {
		if err := vm.RunAll(); err != nil {
			panic(err)
		}
		solver.Step([]ValueGrad{x, y})
		vm.Reset()
		runtime.GC()
	}
}

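// Example_concurrentTraining trains the toy model concurrently and prints the
// updated x and y values.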
func Example_concurrentTraining() {
	xV, yV, bs := prep()
	concurrentTraining(xV, yV, bs, epochs)

	fmt.Printf("x:\n%1.1v", xV)
	fmt.Printf("y:\n%1.1v", yV)

	// Output:
	// x:
	// ⎡-0.0003     0.01     0.04     0.09      0.2⎤
	// ⎢-0.0003     0.01     0.04     0.09      0.2⎥
	// ⎢-0.0003     0.01     0.04     0.09      0.2⎥
	// ⎢-0.0003     0.01     0.04     0.09      0.2⎥
	// .
	// .
	// .
	// ⎢-0.0003     0.01     0.04     0.09      0.2⎥
	// ⎢-0.0003     0.01     0.04     0.09      0.2⎥
	// ⎢-0.0003     0.01     0.04     0.09      0.2⎥
	// ⎣-0.0003     0.01     0.04     0.09      0.2⎦
	// y:
	// [0.3  0.3  0.3  0.3  ... 0.3  0.3  0.3  0.3]

}

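// Example_nonConcurrentTraining trains the same toy model in a single
// goroutine; at this print precision it produces the same output as the
// concurrent version.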
func Example_nonConcurrentTraining() {
	xV, yV, _ := prep()
	nonConcurrentTraining(xV, yV, epochs)

	fmt.Printf("x:\n%1.1v", xV)
	fmt.Printf("y:\n%1.1v", yV)

	// Output:
	// x:
	// ⎡-0.0003     0.01     0.04     0.09      0.2⎤
	// ⎢-0.0003     0.01     0.04     0.09      0.2⎥
	// ⎢-0.0003     0.01     0.04     0.09      0.2⎥
	// ⎢-0.0003     0.01     0.04     0.09      0.2⎥
	// .
	// .
	// .
	// ⎢-0.0003     0.01     0.04     0.09      0.2⎥
	// ⎢-0.0003     0.01     0.04     0.09      0.2⎥
	// ⎢-0.0003     0.01     0.04     0.09      0.2⎥
	// ⎣-0.0003     0.01     0.04     0.09      0.2⎦
	// y:
	// [0.3  0.3  0.3  0.3  ... 0.3  0.3  0.3  0.3]

}