gorgonia.org/gorgonia@v0.9.17/cmd/cudagen/main.go (about) 1 package main 2 3 import ( 4 "bytes" 5 "flag" 6 "fmt" 7 "io/ioutil" 8 "log" 9 "os" 10 "os/exec" 11 "path" 12 "path/filepath" 13 "regexp" 14 "strings" 15 16 "gorgonia.org/cu" 17 ) 18 19 var debug = flag.Bool("debug", false, "compile with debug mode (-linelinfo is added to nvcc call)") 20 var sameModule = flag.Bool("same-module", false, "generate a cudamodules.go file which can be placed in the gorgonia dir instead of the application main dir") 21 22 var funcNameRegex = regexp.MustCompile("// .globl (.+?)\r?\n") 23 24 func stripExt(fullpath string) string { 25 _, filename := filepath.Split(fullpath) 26 ext := path.Ext(filename) 27 return filename[:len(filename)-len(ext)] 28 } 29 30 func compileCUDA(src string, maj, min int) ([]byte, error) { 31 target, err := ioutil.TempFile("", stripExt(src)+"_*.ptx") 32 if err != nil { 33 return nil, fmt.Errorf("failed to create temporary file for compilation output") 34 } 35 defer target.Close() 36 37 output := fmt.Sprintf("-o=%v", target.Name()) 38 arch := fmt.Sprintf("-arch=compute_%d%d", maj, min) 39 var cmd *exec.Cmd 40 if *debug { 41 cmd = exec.Command("nvcc", output, arch, "-lineinfo", "-ptx", "-Xptxas", "--allow-expensive-optimizations", "-fmad=false", "-ftz=false", "-prec-div=true", "-prec-sqrt=true", src) 42 } else { 43 cmd = exec.Command("nvcc", output, arch, "-ptx", "-Xptxas", "--allow-expensive-optimizations", "-fmad=false", "-ftz=false", "-prec-div=true", "-prec-sqrt=true", src) 44 } 45 var stderr bytes.Buffer 46 cmd.Stderr = &stderr 47 if err := cmd.Run(); err != nil || stderr.Len() != 0 { 48 return nil, fmt.Errorf("failed to compile with nvcc. Error: %v. nvcc error: %v", err, stderr.String()) 49 } 50 51 out, err := ioutil.ReadAll(target) 52 if err != nil { 53 return nil, fmt.Errorf("failed to read compilation output file. Error: %v", err) 54 } 55 if err := os.Remove(target.Name()); err != nil { 56 log.Printf("could not remove temporary file %v", target.Name()) 57 } 58 return out, nil 59 } 60 61 func packageLoc(name string) (string, error) { 62 cmd := exec.Command("go", "list", "-f", "{{.Dir}}", "-find", name) 63 var stdout, stderr bytes.Buffer 64 cmd.Stdout = &stdout 65 cmd.Stderr = &stderr 66 if err := cmd.Run(); err != nil || stderr.Len() != 0 { 67 return "", fmt.Errorf("failed to locate %v. Error: %v. go list error: %v", name, err, stderr.String()) 68 } 69 return strings.TrimSpace(stdout.String()), nil 70 } 71 72 func packageInWorkingDir() (string, error) { 73 cmd := exec.Command("go", "list", "-f", "{{.Name}}") 74 var stdout, stderr bytes.Buffer 75 cmd.Stdout = &stdout 76 cmd.Stderr = &stderr 77 if err := cmd.Run(); err != nil || stderr.Len() != 0 { 78 return "", fmt.Errorf("failed to get name of package in working directory. Error: %v. go list error: %v", err, stderr.String()) 79 } 80 return strings.TrimSpace(stdout.String()), nil 81 } 82 83 func gofmt(path string) error { 84 cmd := exec.Command("gofmt", "-w", path) 85 var stderr bytes.Buffer 86 cmd.Stderr = &stderr 87 if err := cmd.Run(); err != nil { 88 return fmt.Errorf("go imports failed with %v for %q. Error: %v", err, path, stderr.String()) 89 } 90 return nil 91 } 92 93 func main() { 94 flag.Parse() 95 96 var devices int 97 var err error 98 if devices, err = cu.NumDevices(); err != nil { 99 log.Fatalf("error while finding number of devices: %+v", err) 100 } 101 if devices == 0 { 102 log.Fatal("No CUDA-capable devices found") 103 } 104 105 // Get the lowest possible compute capability 106 major := int(^uint(0) >> 1) 107 minor := int(^uint(0) >> 1) 108 for d := 0; d < devices; d++ { 109 var dev cu.Device 110 if dev, err = cu.GetDevice(d); err != nil { 111 log.Fatalf("Unable to get GPU%d - %+v", d, err) 112 } 113 114 maj, min, err := dev.ComputeCapability() 115 if err != nil { 116 log.Fatalf("Unable to get compute compatibility of GPU%d - %v", d, err) 117 } 118 if maj > 0 && maj < major { 119 major = maj 120 minor = min 121 continue 122 } 123 124 if min > 0 && min < minor { 125 minor = min 126 } 127 } 128 129 cwd, err := os.Getwd() 130 if err != nil { 131 log.Fatal(err) 132 } 133 cudamodules := path.Join(cwd, "cudamodules.go") 134 packageName, err := packageInWorkingDir() 135 if err != nil { 136 log.Fatal(err) 137 } 138 139 gorgoniaLoc, err := packageLoc("gorgonia.org/gorgonia") 140 if err != nil { 141 log.Fatal(err) 142 } 143 cuLoc := path.Join(gorgoniaLoc, "cuda modules", "src", "*.cu") 144 145 matches, err := filepath.Glob(cuLoc) 146 if err != nil { 147 log.Fatal(err) 148 } 149 150 m := make(map[string][]byte) 151 funcs := make(map[string][]string) 152 for _, match := range matches { 153 name := stripExt(match) 154 data, err := compileCUDA(match, major, minor) 155 if err != nil { 156 log.Fatal(err) 157 } 158 m[name] = data 159 160 // Regex 161 var fns []string 162 matches := funcNameRegex.FindAllSubmatch(data, -1) 163 for _, bs := range matches { 164 fns = append(fns, string(bs[1])) 165 } 166 funcs[name] = fns 167 } 168 169 var buf bytes.Buffer 170 header := fmt.Sprintf(`// Code generated by Gorgonia cudagen. DO NOT EDIT. 171 // +build cuda 172 173 package %v 174 `, packageName) 175 buf.WriteString(header) 176 if ! *sameModule { 177 buf.WriteString("import \"gorgonia.org/gorgonia\"\n") 178 } 179 180 buf.WriteString("func init() {\n") 181 for name := range m { 182 if ! *sameModule { 183 buf.WriteString("gorgonia.") 184 } 185 buf.WriteString(fmt.Sprintf("AddToStdLib(%q, %sPTX, []string{\"%s\"})\n", name, name, strings.Join(funcs[name], "\", \""))) 186 } 187 buf.WriteString("}\n") 188 189 for name, data := range m { 190 buf.WriteString(fmt.Sprintf("const %vPTX = `", name)) 191 buf.Write(data) 192 buf.WriteString("`\n") 193 } 194 195 f, err := os.OpenFile(cudamodules, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) 196 if err != nil { 197 log.Fatal(err) 198 } 199 defer f.Close() 200 if _, err = buf.WriteTo(f); err != nil { 201 log.Fatalf("unable to write output to %v", cudamodules) 202 } 203 204 if err = gofmt(cudamodules); err != nil { 205 log.Fatal(err) 206 } 207 208 fmt.Printf("Created %v\n", cudamodules) 209 }