gorgonia.org/gorgonia@v0.9.17/cmd/gencudaengine/main.go (about)

     1  package main
     2  
     3  import (
     4  	"fmt"
     5  	"log"
     6  	"os"
     7  	"os/exec"
     8  	"os/user"
     9  	"path"
    10  	"strings"
    11  	"text/template"
    12  )
    13  
    14  const genmsg = "// Code generated by gencudaengine, which is a API generation tool for Gorgonia. DO NOT EDIT."
    15  
    16  const (
    17  	arithOut = "arith.go"
    18  	unaryOut = "unary.go"
    19  	cmpOut   = "cmp.go"
    20  )
    21  
    22  var (
    23  	gopath, cudaengloc string
    24  )
    25  
    26  var funcmap = template.FuncMap{
    27  	"lower": strings.ToLower,
    28  }
    29  
    30  type binOp struct {
    31  	Method       string
    32  	ScalarMethod string
    33  }
    34  
    35  var ariths = []binOp{
    36  	{"Add", "Add"},
    37  	{"Sub", "Sub"},
    38  	{"Mul", "Mul"},
    39  	{"Div", "Div"},
    40  	{"Pow", "Pow"},
    41  	{"Mod", "Mod"},
    42  }
    43  
    44  var cmps = []binOp{
    45  	{"Lt", "Lt"},
    46  	{"Lte", "Lte"},
    47  	{"Gt", "Gt"},
    48  	{"Gte", "Gte"},
    49  	{"ElEq", "Eq"},
    50  	{"ElNe", "Ne"},
    51  }
    52  
    53  func init() {
    54  	gopath = os.Getenv("GOPATH")
    55  	if gopath == "" {
    56  		usr, err := user.Current()
    57  		if err != nil {
    58  			log.Fatal(err)
    59  		}
    60  		gopath = path.Join(usr.HomeDir, "go")
    61  		stat, err := os.Stat(gopath)
    62  		if err != nil {
    63  			log.Fatal(err)
    64  		}
    65  		if !stat.IsDir() {
    66  			log.Fatal("You need to define a $GOPATH")
    67  		}
    68  	}
    69  	cudaengloc = path.Join(gopath, "src/gorgonia.org/gorgonia/cuda")
    70  }
    71  
    72  func generateAriths() {
    73  	p := path.Join(cudaengloc, arithOut)
    74  	f, _ := os.OpenFile(p, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
    75  	fmt.Fprintf(f, "package cuda\n\n%v\n\n", genmsg)
    76  
    77  	for _, op := range ariths {
    78  		binopTmpl.Execute(f, op)
    79  	}
    80  
    81  	f.Close()
    82  	cmd := exec.Command("goimports", "-w", p)
    83  	if err := cmd.Run(); err != nil {
    84  		log.Fatalf("Go imports failed with %v for %q", err, p)
    85  	}
    86  }
    87  
    88  func generateCmps() {
    89  	p := path.Join(cudaengloc, cmpOut)
    90  	f, _ := os.OpenFile(p, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
    91  	fmt.Fprintf(f, "package cuda\n\n%v\n\n", genmsg)
    92  
    93  	for _, op := range cmps {
    94  		binopTmpl.Execute(f, op)
    95  	}
    96  
    97  	f.Close()
    98  	cmd := exec.Command("goimports", "-w", p)
    99  	if err := cmd.Run(); err != nil {
   100  		log.Fatalf("Go imports failed with %v for %q", err, p)
   101  	}
   102  }
   103  
   104  func main() {
   105  	generateAriths()
   106  	generateCmps()
   107  }