gorgonia.org/gorgonia@v0.9.17/vm_genera_cuda.go (about)

     1  // +build cuda
     2  
     3  package gorgonia
     4  
     5  import (
     6  	"log"
     7  
     8  	"github.com/pkg/errors"
     9  	"gorgonia.org/tensor"
    10  )
    11  
    12  func (m *lispMachine) init() error {
    13  	if err := m.prepGraph(); err != nil {
    14  		return err
    15  	}
    16  
    17  	// VERY simple data analysis (even simpler than the one used in Compile)
    18  	// using replaceWithSelf reduces the need for hashing, hence less work is required
    19  	// However this also means that CSE won't be performed
    20  	df := newdataflow()
    21  	df.replaceWithSelf(m.sorted)
    22  	df.buildIntervals(m.sorted)
    23  	df.fixIntervalDevices(m.sorted)
    24  	m.df = df
    25  
    26  	if err := m.calcMemSize(); err != nil {
    27  		log.Printf("err1")
    28  		return err
    29  	}
    30  
    31  	if len(m.gpumem) == 0 {
    32  		m.ForceCPU()
    33  		return nil
    34  	}
    35  
    36  	if err := m.ExternMetadata.init(m.gpumem); err != nil {
    37  		m.ExternMetadata.initFail()
    38  		return err
    39  	}
    40  	m.loadStdLib()
    41  
    42  	if len(m.engines) == 0 {
    43  		m.ForceCPU()
    44  	}
    45  	return nil
    46  }
    47  
    48  func finalizeLispMachine(m *lispMachine) {
    49  	m.ExternMetadata.cleanup()
    50  	m.ExternMetadata.initFail()
    51  }
    52  
    53  func (m *lispMachine) WorkAvailable() <-chan bool {
    54  	if m.ExternMetadata.WorkAvailable() == nil {
    55  		return nil
    56  	}
    57  	return m.ExternMetadata.WorkAvailable()
    58  }
    59  
// calcMemSize walks the sorted nodes and tallies how many bytes each device
// needs: cpumem for the CPU, and gpumem[d] for GPU device d. The totals are
// stored on the machine (m.cpumem, m.gpumem); an empty m.gpumem means no
// node is placed on a GPU.
func (m *lispMachine) calcMemSize() (err error) {
	compileLogf("calcmemsize")
	enterLogScope()
	defer leaveLogScope()
	var cpumem int64
	var gpumem []int64 // indexed by device ID; grown lazily as devices appear
	for _, n := range m.sorted {
		interv := m.df.intervals[n]
		dev := interv.result.device // device the dataflow analysis assigned to n's result
		compileLogf("n: %v | %v", n, interv)

		var dt tensor.Dtype
		if dt, err = dtypeOf(n.t); err != nil {
			// statements have no dtype, so a dtypeOf failure on them is expected — skip
			if n.isStmt {
				continue
			}
			return errors.Wrapf(err, "Cannot calculate memsize of n(%v)", n)
		}
		switch {
		case n.isArg():
			// arguments are always accounted to the CPU, regardless of dev
			cpumem += calcMemSize(dt, n.Shape())
		case n.isStmt:
			// statements occupy no memory
		default:
			// if !n.op.ReturnsPtr() {
			if dev != CPU {
				// ensure gpumem has a slot for this device ID
				if len(gpumem) < int(dev)+1 {
					diff := int(dev) + 1 - len(gpumem)
					gpumem = append(gpumem, make([]int64, diff)...)
				}
			}

			switch dev {
			case CPU:
				cpumem += calcMemSize(dt, n.Shape())
			default:
				compileLogf("n: %v. AddedDEF", n)
				// 4x the value size is reserved per GPU node — presumably
				// headroom for the value, its derivative, and scratch space;
				// TODO(review): confirm the rationale for the factor of 4.
				gpumem[int(dev)] += 4 * calcMemSize(dt, n.Shape())
			}
			// }
		}
	}

	m.cpumem = cpumem
	m.gpumem = gpumem
	return nil
}
   106  
// execDevTrans executes a device-transfer (devTrans) pseudo-op: it moves the
// first child's bound value (and its derivative, when the child holds a
// *dualValue) from op.from to op.to, then binds the transferred dual value
// to n.
func (m *lispMachine) execDevTrans(op devTrans, n *Node, children Nodes) (err error) {
	child := children[0] // devTrans is assumed to have exactly one child — not checked here
	m.logf("DevTrans: %v | %v | %v", op, n.boundTo, child.boundTo)

	var dv *dualValue
	var cv, cd, v, d Value
	if child.boundTo != nil {
		var ok bool
		if dv, ok = child.boundTo.(*dualValue); ok {
			// child carries both a value and a derivative
			cv = dv.Value
			cd = dv.d
		} else {
			// plain value, no derivative
			cv = child.boundTo
		}
	} else {
		err = errors.Errorf("Cannot execute transfer when there is no value in child")
		return
	}

	// device→CPU transfers are signalled synchronously below — presumably so
	// the CPU doesn't read before the device write completes.
	var synchronous bool
	if op.to == CPU && op.from != CPU {
		synchronous = true
	}

	// transfer the value itself
	if v, err = m.Transfer(op.to, op.from, cv, false); err != nil {
		return
	}

	if cd != nil {
		// transfer the derivative alongside the value
		if d, err = m.Transfer(op.to, op.from, cd, false); err != nil {
			return
		}
	} else {
		// no derivative on the child: allocate same-sized memory on the
		// destination device instead
		var mem tensor.Memory
		if mem, err = m.Get(op.to, calcMemSize(cv.Dtype(), child.shape)); err != nil {
			return
		}
		// NOTE(review): the value built from mem is discarded and d stays
		// nil, so dv.d below ends up nil on this path — it looks like only
		// the allocation/validation side effects matter here; confirm this
		// is intentional and mem is not leaked.
		if _, err = makeValueFromMem(child.t, child.shape, mem); err != nil {
			return
		}
	}

	if synchronous {
		m.Signal() // presumably flushes/synchronizes queued device work — confirm
	}

	dv = new(dualValue)
	dv.Value = v
	dv.d = d
	n.boundTo = dv

	return nil
}
   160  
   161  // loads the standardlib
   162  func (m *lispMachine) loadStdLib() {
   163  	if cudaStdLib == nil {
   164  		return
   165  	}
   166  
   167  	for _, lib := range cudaStdLib {
   168  		for i := range m.engines {
   169  			e := &m.engines[i]
   170  			if err := e.LoadCUDAFunc(lib.name, lib.data, lib.funcs); err != nil {
   171  				panic(err)
   172  			}
   173  		}
   174  	}
   175  }
   176  
   177  // ForceCPU forces the lispMachine to have the nodes run on the CPU
   178  func (m *lispMachine) ForceCPU() {
   179  	m.cleanup()
   180  	m.initFail()
   181  	m.df = nil
   182  
   183  	for _, n := range m.sorted {
   184  		n.dataOn = CPU
   185  	}
   186  
   187  	// remove devTrans if any
   188  	for i := 0; i < len(m.sorted); i++ {
   189  		n := m.sorted[i]
   190  		if _, ok := n.op.(devTrans); ok {
   191  			copy(m.sorted[i:], m.sorted[i+1:])
   192  			m.sorted = m.sorted[:len(m.sorted)-1]
   193  			i--
   194  		}
   195  	}
   196  }