gorgonia.org/gorgonia@v0.9.17/cmd/gencudaengine/main.go (about) 1 package main 2 3 import ( 4 "fmt" 5 "log" 6 "os" 7 "os/exec" 8 "os/user" 9 "path" 10 "strings" 11 "text/template" 12 ) 13 14 const genmsg = "// Code generated by gencudaengine, which is a API generation tool for Gorgonia. DO NOT EDIT." 15 16 const ( 17 arithOut = "arith.go" 18 unaryOut = "unary.go" 19 cmpOut = "cmp.go" 20 ) 21 22 var ( 23 gopath, cudaengloc string 24 ) 25 26 var funcmap = template.FuncMap{ 27 "lower": strings.ToLower, 28 } 29 30 type binOp struct { 31 Method string 32 ScalarMethod string 33 } 34 35 var ariths = []binOp{ 36 {"Add", "Add"}, 37 {"Sub", "Sub"}, 38 {"Mul", "Mul"}, 39 {"Div", "Div"}, 40 {"Pow", "Pow"}, 41 {"Mod", "Mod"}, 42 } 43 44 var cmps = []binOp{ 45 {"Lt", "Lt"}, 46 {"Lte", "Lte"}, 47 {"Gt", "Gt"}, 48 {"Gte", "Gte"}, 49 {"ElEq", "Eq"}, 50 {"ElNe", "Ne"}, 51 } 52 53 func init() { 54 gopath = os.Getenv("GOPATH") 55 if gopath == "" { 56 usr, err := user.Current() 57 if err != nil { 58 log.Fatal(err) 59 } 60 gopath = path.Join(usr.HomeDir, "go") 61 stat, err := os.Stat(gopath) 62 if err != nil { 63 log.Fatal(err) 64 } 65 if !stat.IsDir() { 66 log.Fatal("You need to define a $GOPATH") 67 } 68 } 69 cudaengloc = path.Join(gopath, "src/gorgonia.org/gorgonia/cuda") 70 } 71 72 func generateAriths() { 73 p := path.Join(cudaengloc, arithOut) 74 f, _ := os.OpenFile(p, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) 75 fmt.Fprintf(f, "package cuda\n\n%v\n\n", genmsg) 76 77 for _, op := range ariths { 78 binopTmpl.Execute(f, op) 79 } 80 81 f.Close() 82 cmd := exec.Command("goimports", "-w", p) 83 if err := cmd.Run(); err != nil { 84 log.Fatalf("Go imports failed with %v for %q", err, p) 85 } 86 } 87 88 func generateCmps() { 89 p := path.Join(cudaengloc, cmpOut) 90 f, _ := os.OpenFile(p, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) 91 fmt.Fprintf(f, "package cuda\n\n%v\n\n", genmsg) 92 93 for _, op := range cmps { 94 binopTmpl.Execute(f, op) 95 } 96 97 f.Close() 98 cmd := exec.Command("goimports", "-w", p) 99 if err := cmd.Run(); err != nil { 100 log.Fatalf("Go imports failed with %v for %q", err, p) 101 } 102 } 103 104 func main() { 105 generateAriths() 106 generateCmps() 107 }