gorgonia.org/gorgonia@v0.9.17/cmd/cudagen/main.go (about)

     1  package main
     2  
     3  import (
     4  	"bytes"
     5  	"flag"
     6  	"fmt"
     7  	"io/ioutil"
     8  	"log"
     9  	"os"
    10  	"os/exec"
    11  	"path"
    12  	"path/filepath"
    13  	"regexp"
    14  	"strings"
    15  
    16  	"gorgonia.org/cu"
    17  )
    18  
    19  var debug = flag.Bool("debug", false, "compile with debug mode (-linelinfo is added to nvcc call)")
    20  var sameModule = flag.Bool("same-module", false, "generate a cudamodules.go file which can be placed in the gorgonia dir instead of the application main dir")
    21  
    22  var funcNameRegex = regexp.MustCompile("// .globl	(.+?)\r?\n")
    23  
    24  func stripExt(fullpath string) string {
    25  	_, filename := filepath.Split(fullpath)
    26  	ext := path.Ext(filename)
    27  	return filename[:len(filename)-len(ext)]
    28  }
    29  
    30  func compileCUDA(src string, maj, min int) ([]byte, error) {
    31  	target, err := ioutil.TempFile("", stripExt(src)+"_*.ptx")
    32  	if err != nil {
    33  		return nil, fmt.Errorf("failed to create temporary file for compilation output")
    34  	}
    35  	defer target.Close()
    36  
    37  	output := fmt.Sprintf("-o=%v", target.Name())
    38  	arch := fmt.Sprintf("-arch=compute_%d%d", maj, min)
    39  	var cmd *exec.Cmd
    40  	if *debug {
    41  		cmd = exec.Command("nvcc", output, arch, "-lineinfo", "-ptx", "-Xptxas", "--allow-expensive-optimizations", "-fmad=false", "-ftz=false", "-prec-div=true", "-prec-sqrt=true", src)
    42  	} else {
    43  		cmd = exec.Command("nvcc", output, arch, "-ptx", "-Xptxas", "--allow-expensive-optimizations", "-fmad=false", "-ftz=false", "-prec-div=true", "-prec-sqrt=true", src)
    44  	}
    45  	var stderr bytes.Buffer
    46  	cmd.Stderr = &stderr
    47  	if err := cmd.Run(); err != nil || stderr.Len() != 0 {
    48  		return nil, fmt.Errorf("failed to compile with nvcc. Error: %v. nvcc error: %v", err, stderr.String())
    49  	}
    50  
    51  	out, err := ioutil.ReadAll(target)
    52  	if err != nil {
    53  		return nil, fmt.Errorf("failed to read compilation output file. Error: %v", err)
    54  	}
    55  	if err := os.Remove(target.Name()); err != nil {
    56  		log.Printf("could not remove temporary file %v", target.Name())
    57  	}
    58  	return out, nil
    59  }
    60  
    61  func packageLoc(name string) (string, error) {
    62  	cmd := exec.Command("go", "list", "-f", "{{.Dir}}", "-find", name)
    63  	var stdout, stderr bytes.Buffer
    64  	cmd.Stdout = &stdout
    65  	cmd.Stderr = &stderr
    66  	if err := cmd.Run(); err != nil || stderr.Len() != 0 {
    67  		return "", fmt.Errorf("failed to locate %v. Error: %v. go list error: %v", name, err, stderr.String())
    68  	}
    69  	return strings.TrimSpace(stdout.String()), nil
    70  }
    71  
    72  func packageInWorkingDir() (string, error) {
    73  	cmd := exec.Command("go", "list", "-f", "{{.Name}}")
    74  	var stdout, stderr bytes.Buffer
    75  	cmd.Stdout = &stdout
    76  	cmd.Stderr = &stderr
    77  	if err := cmd.Run(); err != nil || stderr.Len() != 0 {
    78  		return "", fmt.Errorf("failed to get name of package in working directory. Error: %v. go list error: %v", err, stderr.String())
    79  	}
    80  	return strings.TrimSpace(stdout.String()), nil
    81  }
    82  
    83  func gofmt(path string) error {
    84  	cmd := exec.Command("gofmt", "-w", path)
    85  	var stderr bytes.Buffer
    86  	cmd.Stderr = &stderr
    87  	if err := cmd.Run(); err != nil {
    88  		return fmt.Errorf("go imports failed with %v for %q. Error: %v", err, path, stderr.String())
    89  	}
    90  	return nil
    91  }
    92  
    93  func main() {
    94  	flag.Parse()
    95  
    96  	var devices int
    97  	var err error
    98  	if devices, err = cu.NumDevices(); err != nil {
    99  		log.Fatalf("error while finding number of devices: %+v", err)
   100  	}
   101  	if devices == 0 {
   102  		log.Fatal("No CUDA-capable devices found")
   103  	}
   104  
   105  	// Get the lowest possible compute capability
   106  	major := int(^uint(0) >> 1)
   107  	minor := int(^uint(0) >> 1)
   108  	for d := 0; d < devices; d++ {
   109  		var dev cu.Device
   110  		if dev, err = cu.GetDevice(d); err != nil {
   111  			log.Fatalf("Unable to get GPU%d - %+v", d, err)
   112  		}
   113  
   114  		maj, min, err := dev.ComputeCapability()
   115  		if err != nil {
   116  			log.Fatalf("Unable to get compute compatibility of GPU%d - %v", d, err)
   117  		}
   118  		if maj > 0 && maj < major {
   119  			major = maj
   120  			minor = min
   121  			continue
   122  		}
   123  
   124  		if min > 0 && min < minor {
   125  			minor = min
   126  		}
   127  	}
   128  
   129  	cwd, err := os.Getwd()
   130  	if err != nil {
   131  		log.Fatal(err)
   132  	}
   133  	cudamodules := path.Join(cwd, "cudamodules.go")
   134  	packageName, err := packageInWorkingDir()
   135  	if err != nil {
   136  		log.Fatal(err)
   137  	}
   138  
   139  	gorgoniaLoc, err := packageLoc("gorgonia.org/gorgonia")
   140  	if err != nil {
   141  		log.Fatal(err)
   142  	}
   143  	cuLoc := path.Join(gorgoniaLoc, "cuda modules", "src", "*.cu")
   144  
   145  	matches, err := filepath.Glob(cuLoc)
   146  	if err != nil {
   147  		log.Fatal(err)
   148  	}
   149  
   150  	m := make(map[string][]byte)
   151  	funcs := make(map[string][]string)
   152  	for _, match := range matches {
   153  		name := stripExt(match)
   154  		data, err := compileCUDA(match, major, minor)
   155  		if err != nil {
   156  			log.Fatal(err)
   157  		}
   158  		m[name] = data
   159  
   160  		// Regex
   161  		var fns []string
   162  		matches := funcNameRegex.FindAllSubmatch(data, -1)
   163  		for _, bs := range matches {
   164  			fns = append(fns, string(bs[1]))
   165  		}
   166  		funcs[name] = fns
   167  	}
   168  
   169  	var buf bytes.Buffer
   170  	header := fmt.Sprintf(`// Code generated by Gorgonia cudagen. DO NOT EDIT.
   171  // +build cuda
   172  
   173  package %v
   174  `, packageName)
   175  	buf.WriteString(header)
   176  	if ! *sameModule {
   177  		buf.WriteString("import \"gorgonia.org/gorgonia\"\n")
   178  	}
   179  
   180  	buf.WriteString("func init() {\n")
   181  	for name := range m {
   182  		if ! *sameModule {
   183  			buf.WriteString("gorgonia.")
   184  		}
   185  		buf.WriteString(fmt.Sprintf("AddToStdLib(%q, %sPTX, []string{\"%s\"})\n", name, name, strings.Join(funcs[name], "\", \"")))
   186  	}
   187  	buf.WriteString("}\n")
   188  
   189  	for name, data := range m {
   190  		buf.WriteString(fmt.Sprintf("const %vPTX = `", name))
   191  		buf.Write(data)
   192  		buf.WriteString("`\n")
   193  	}
   194  
   195  	f, err := os.OpenFile(cudamodules, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644)
   196  	if err != nil {
   197  		log.Fatal(err)
   198  	}
   199  	defer f.Close()
   200  	if _, err = buf.WriteTo(f); err != nil {
   201  		log.Fatalf("unable to write output to %v", cudamodules)
   202  	}
   203  
   204  	if err = gofmt(cudamodules); err != nil {
   205  		log.Fatal(err)
   206  	}
   207  
   208  	fmt.Printf("Created %v\n", cudamodules)
   209  }