github.com/wzzhu/tensor@v0.9.24/genlib2/main.go (about)

     1  package main
     2  
     3  import (
     4  	"io"
     5  	"io/ioutil"
     6  	"log"
     7  	"os"
     8  	"os/exec"
     9  	"os/user"
    10  	"path"
    11  	"path/filepath"
    12  	"reflect"
    13  	"runtime"
    14  	"strings"
    15  )
    16  
    17  const genmsg = "Code generated by genlib2. DO NOT EDIT."
    18  
    19  var (
    20  	gopath, tensorPkgLoc, nativePkgLoc, execLoc, storageLoc string
    21  )
    22  
    23  type Kinds struct {
    24  	Kinds []reflect.Kind
    25  }
    26  
    27  func init() {
    28  	gopath = os.Getenv("GOPATH")
    29  
    30  	// now that go can have a default gopath, this checks that path
    31  	if gopath == "" {
    32  		usr, err := user.Current()
    33  		if err != nil {
    34  			log.Fatal(err)
    35  		}
    36  		gopath = path.Join(usr.HomeDir, "go")
    37  		stat, err := os.Stat(gopath)
    38  		if err != nil {
    39  			log.Fatal(err)
    40  		}
    41  		if !stat.IsDir() {
    42  			log.Fatal("You need to define a $GOPATH")
    43  		}
    44  	}
    45  	tensorPkgLoc = path.Join(gopath, "src/github.com/wzzhu/tensor")
    46  	nativePkgLoc = path.Join(gopath, "src/github.com/wzzhu/tensor/native")
    47  	execLoc = path.Join(gopath, "src/github.com/wzzhu/tensor/internal/execution")
    48  	storageLoc = path.Join(gopath, "src/github.com/wzzhu/tensor/internal/storage")
    49  }
    50  
    51  func main() {
    52  	pregenerate()
    53  
    54  	// storage
    55  	pipeline(storageLoc, "consts.go", Kinds{allKinds}, generateReflectTypes)
    56  	pipeline(storageLoc, "getset.go", Kinds{allKinds}, generateHeaderGetSet)
    57  	pipeline(tensorPkgLoc, "array_getset.go", Kinds{allKinds}, generateArrayMethods)
    58  
    59  	// execution
    60  	pipeline(execLoc, "generic_arith_vv.go", Kinds{allKinds}, generateGenericVecVecArith)
    61  	pipeline(execLoc, "generic_arith_mixed.go", Kinds{allKinds}, generateGenericMixedArith)
    62  	// pipeline(execLoc, "generic_arith.go", Kinds{allKinds}, generateGenericScalarScalarArith) // generate once and manually edit later
    63  	pipeline(execLoc, "generic_cmp_vv.go", Kinds{allKinds}, generateGenericVecVecCmp)
    64  	pipeline(execLoc, "generic_cmp_mixed.go", Kinds{allKinds}, generateGenericMixedCmp)
    65  	pipeline(execLoc, "generic_minmax.go", Kinds{allKinds}, generateMinMax)
    66  	pipeline(execLoc, "generic_map.go", Kinds{allKinds}, generateGenericMap)
    67  	pipeline(execLoc, "generic_unary.go", Kinds{allKinds}, generateGenericUncondUnary, generateGenericCondUnary, generateSpecialGenericUnaries)
    68  	pipeline(execLoc, "generic_reduce.go", Kinds{allKinds}, generateGenericReduce)
    69  	pipeline(execLoc, "generic_argmethods.go", Kinds{allKinds}, generateGenericArgMethods)
    70  	pipeline(tensorPkgLoc, "generic_utils.go", Kinds{allKinds}, generateUtils)
    71  
    72  	// level 1 aggregation
    73  	pipeline(execLoc, "eng_arith.go", Kinds{allKinds}, generateEArith)
    74  	pipeline(execLoc, "eng_map.go", Kinds{allKinds}, generateEMap)
    75  	pipeline(execLoc, "eng_cmp.go", Kinds{allKinds}, generateECmp)
    76  	pipeline(execLoc, "eng_minmaxbetween.go", Kinds{allKinds}, generateEMinMaxBetween)
    77  	pipeline(execLoc, "eng_reduce.go", Kinds{allKinds}, generateEReduce)
    78  	pipeline(execLoc, "eng_unary.go", Kinds{allKinds}, generateUncondEUnary, generateCondEUnary, generateSpecialEUnaries)
    79  	pipeline(execLoc, "reduction_specialization.go", Kinds{allKinds}, generateReductionSpecialization)
    80  	pipeline(execLoc, "eng_argmethods.go", Kinds{allKinds}, generateInternalEngArgmethods)
    81  
    82  	// level 2 aggregation
    83  	pipeline(tensorPkgLoc, "defaultengine_arith.go", Kinds{allKinds}, generateStdEngArith)
    84  	pipeline(tensorPkgLoc, "defaultengine_cmp.go", Kinds{allKinds}, generateStdEngCmp)
    85  	pipeline(tensorPkgLoc, "defaultengine_unary.go", Kinds{allKinds}, generateStdEngUncondUnary, generateStdEngCondUnary)
    86  	pipeline(tensorPkgLoc, "defaultengine_minmax.go", Kinds{allKinds}, generateStdEngMinMax)
    87  
    88  	// level 3 aggregation
    89  	pipeline(tensorPkgLoc, "dense_arith.go", Kinds{allKinds}, generateDenseArith)
    90  	pipeline(tensorPkgLoc, "dense_cmp.go", Kinds{allKinds}, generateDenseCmp) // generate once, manually edit later
    91  
    92  	// level 4 aggregation
    93  	pipeline(tensorPkgLoc, "api_unary.go", Kinds{allKinds}, generateUncondUnaryAPI, generateCondUnaryAPI, generateSpecialUnaryAPI)
    94  
    95  	// dense methods (old genlib style)
    96  	pipeline(tensorPkgLoc, "dense_generated.go", Kinds{allKinds}, generateDenseConstructionFns)
    97  	pipeline(tensorPkgLoc, "dense_io.go", Kinds{allKinds}, generateDenseIO)
    98  	pipeline(tensorPkgLoc, "dense_compat.go", Kinds{allKinds}, generateDenseCompat)
    99  	pipeline(tensorPkgLoc, "dense_maskcmp_methods.go", Kinds{allKinds}, generateDenseMaskedMethods)
   100  
   101  	// tests
   102  	pipeline(tensorPkgLoc, "test_test.go", Kinds{allKinds}, generateTestUtils)
   103  	pipeline(tensorPkgLoc, "dense_argmethods_test.go", Kinds{allKinds}, generateArgmethodsTests)
   104  	pipeline(tensorPkgLoc, "dense_getset_test.go", Kinds{allKinds}, generateDenseGetSetTests)
   105  
   106  	// old-genlib style tests
   107  	pipeline(tensorPkgLoc, "dense_reduction_test.go", Kinds{allKinds}, generateDenseReductionTests, generateDenseReductionMethodsTests)
   108  	pipeline(tensorPkgLoc, "dense_compat_test.go", Kinds{allKinds}, generateDenseCompatTests)
   109  	pipeline(tensorPkgLoc, "dense_generated_test.go", Kinds{allKinds}, generateDenseConsTests)
   110  	pipeline(tensorPkgLoc, "dense_maskcmp_methods_test.go", Kinds{allKinds}, generateMaskCmpMethodsTests)
   111  
   112  	// qc-style tests
   113  	pipeline(tensorPkgLoc, "api_arith_generated_test.go", Kinds{allKinds}, generateAPIArithTests, generateAPIArithScalarTests)
   114  	pipeline(tensorPkgLoc, "dense_arith_test.go", Kinds{allKinds}, generateDenseMethodArithTests, generateDenseMethodScalarTests)
   115  	pipeline(tensorPkgLoc, "api_unary_generated_test.go", Kinds{allKinds}, generateAPIUnaryTests)
   116  	pipeline(tensorPkgLoc, "api_cmp_generated_test.go", Kinds{allKinds}, generateAPICmpTests, generateAPICmpMixedTests)
   117  	pipeline(tensorPkgLoc, "dense_cmp_test.go", Kinds{allKinds}, generateDenseMethodCmpTests, generateDenseMethodCmpMixedTests)
   118  
   119  	// native iterators
   120  	pipeline(nativePkgLoc, "iterator_native.go", Kinds{allKinds}, generateNativeIterators)
   121  	pipeline(nativePkgLoc, "iterator_native_test.go", Kinds{allKinds}, generateNativeIteratorTests)
   122  	pipeline(nativePkgLoc, "iterator_native2.go", Kinds{allKinds}, generateNativeSelect)
   123  	pipeline(nativePkgLoc, "iterator_native2_test.go", Kinds{allKinds}, generateNativeSelectTests)
   124  }
   125  
   126  func pipeline(pkg, filename string, kinds Kinds, fns ...func(io.Writer, Kinds)) {
   127  	fullpath := path.Join(pkg, filename)
   128  	f, err := os.Create(fullpath)
   129  	if err != nil {
   130  		log.Printf("fullpath %q", fullpath)
   131  		log.Fatal(err)
   132  	}
   133  	defer f.Close()
   134  	writePkgName(f, pkg)
   135  
   136  	for _, fn := range fns {
   137  		fn(f, kinds)
   138  	}
   139  
   140  	// gofmt and goimports this stuff
   141  	cmd := exec.Command("goimports", "-w", fullpath)
   142  	if err = cmd.Run(); err != nil {
   143  		log.Fatalf("Go imports failed with %v for %q", err, fullpath)
   144  	}
   145  
   146  	// account for differences in the postix from the linux sed
   147  	if runtime.GOOS == "darwin" || strings.HasSuffix(runtime.GOOS, "bsd") {
   148  		cmd = exec.Command("sed", "-i", "", `s/github.com\/alecthomas\/assert/github.com\/stretchr\/testify\/assert/g`, fullpath)
   149  	} else {
   150  		cmd = exec.Command("sed", "-E", "-i", `s/github.com\/alecthomas\/assert/github.com\/stretchr\/testify\/assert/g`, fullpath)
   151  	}
   152  	if err = cmd.Run(); err != nil {
   153  		if err.Error() != "exit status 4" { // exit status 4 == not found
   154  			log.Fatalf("sed failed with %v for %q", err.Error(), fullpath)
   155  		}
   156  	}
   157  
   158  	cmd = exec.Command("gofmt", "-s", "-w", fullpath)
   159  	if err = cmd.Run(); err != nil {
   160  		log.Fatalf("Gofmt failed for %q", fullpath)
   161  	}
   162  }
   163  
   164  // pregenerate cleans up all files that were previously generated.
   165  func pregenerate() error {
   166  	if err := cleanup(storageLoc); err != nil {
   167  		return err
   168  	}
   169  	if err := cleanup(execLoc); err != nil {
   170  		return err
   171  	}
   172  	if err := cleanup(nativePkgLoc); err != nil {
   173  		return err
   174  	}
   175  	return cleanup(tensorPkgLoc)
   176  }
   177  
   178  func cleanup(loc string) error {
   179  	pattern := path.Join(loc, "*.go")
   180  	matches, err := filepath.Glob(pattern)
   181  	if err != nil {
   182  		return err
   183  	}
   184  	for _, m := range matches {
   185  		b, err := ioutil.ReadFile(m)
   186  		if err != nil {
   187  			return err
   188  		}
   189  		s := string(b)
   190  		if strings.Contains(s, genmsg) {
   191  			if err := os.Remove(m); err != nil {
   192  				return err
   193  			}
   194  		}
   195  	}
   196  	return nil
   197  }