github.com/wzzhu/tensor@v0.9.24/genlib2/main.go (about) 1 package main 2 3 import ( 4 "io" 5 "io/ioutil" 6 "log" 7 "os" 8 "os/exec" 9 "os/user" 10 "path" 11 "path/filepath" 12 "reflect" 13 "runtime" 14 "strings" 15 ) 16 17 const genmsg = "Code generated by genlib2. DO NOT EDIT." 18 19 var ( 20 gopath, tensorPkgLoc, nativePkgLoc, execLoc, storageLoc string 21 ) 22 23 type Kinds struct { 24 Kinds []reflect.Kind 25 } 26 27 func init() { 28 gopath = os.Getenv("GOPATH") 29 30 // now that go can have a default gopath, this checks that path 31 if gopath == "" { 32 usr, err := user.Current() 33 if err != nil { 34 log.Fatal(err) 35 } 36 gopath = path.Join(usr.HomeDir, "go") 37 stat, err := os.Stat(gopath) 38 if err != nil { 39 log.Fatal(err) 40 } 41 if !stat.IsDir() { 42 log.Fatal("You need to define a $GOPATH") 43 } 44 } 45 tensorPkgLoc = path.Join(gopath, "src/github.com/wzzhu/tensor") 46 nativePkgLoc = path.Join(gopath, "src/github.com/wzzhu/tensor/native") 47 execLoc = path.Join(gopath, "src/github.com/wzzhu/tensor/internal/execution") 48 storageLoc = path.Join(gopath, "src/github.com/wzzhu/tensor/internal/storage") 49 } 50 51 func main() { 52 pregenerate() 53 54 // storage 55 pipeline(storageLoc, "consts.go", Kinds{allKinds}, generateReflectTypes) 56 pipeline(storageLoc, "getset.go", Kinds{allKinds}, generateHeaderGetSet) 57 pipeline(tensorPkgLoc, "array_getset.go", Kinds{allKinds}, generateArrayMethods) 58 59 // execution 60 pipeline(execLoc, "generic_arith_vv.go", Kinds{allKinds}, generateGenericVecVecArith) 61 pipeline(execLoc, "generic_arith_mixed.go", Kinds{allKinds}, generateGenericMixedArith) 62 // pipeline(execLoc, "generic_arith.go", Kinds{allKinds}, generateGenericScalarScalarArith) // generate once and manually edit later 63 pipeline(execLoc, "generic_cmp_vv.go", Kinds{allKinds}, generateGenericVecVecCmp) 64 pipeline(execLoc, "generic_cmp_mixed.go", Kinds{allKinds}, generateGenericMixedCmp) 65 pipeline(execLoc, "generic_minmax.go", Kinds{allKinds}, generateMinMax) 66 pipeline(execLoc, "generic_map.go", Kinds{allKinds}, generateGenericMap) 67 pipeline(execLoc, "generic_unary.go", Kinds{allKinds}, generateGenericUncondUnary, generateGenericCondUnary, generateSpecialGenericUnaries) 68 pipeline(execLoc, "generic_reduce.go", Kinds{allKinds}, generateGenericReduce) 69 pipeline(execLoc, "generic_argmethods.go", Kinds{allKinds}, generateGenericArgMethods) 70 pipeline(tensorPkgLoc, "generic_utils.go", Kinds{allKinds}, generateUtils) 71 72 // level 1 aggregation 73 pipeline(execLoc, "eng_arith.go", Kinds{allKinds}, generateEArith) 74 pipeline(execLoc, "eng_map.go", Kinds{allKinds}, generateEMap) 75 pipeline(execLoc, "eng_cmp.go", Kinds{allKinds}, generateECmp) 76 pipeline(execLoc, "eng_minmaxbetween.go", Kinds{allKinds}, generateEMinMaxBetween) 77 pipeline(execLoc, "eng_reduce.go", Kinds{allKinds}, generateEReduce) 78 pipeline(execLoc, "eng_unary.go", Kinds{allKinds}, generateUncondEUnary, generateCondEUnary, generateSpecialEUnaries) 79 pipeline(execLoc, "reduction_specialization.go", Kinds{allKinds}, generateReductionSpecialization) 80 pipeline(execLoc, "eng_argmethods.go", Kinds{allKinds}, generateInternalEngArgmethods) 81 82 // level 2 aggregation 83 pipeline(tensorPkgLoc, "defaultengine_arith.go", Kinds{allKinds}, generateStdEngArith) 84 pipeline(tensorPkgLoc, "defaultengine_cmp.go", Kinds{allKinds}, generateStdEngCmp) 85 pipeline(tensorPkgLoc, "defaultengine_unary.go", Kinds{allKinds}, generateStdEngUncondUnary, generateStdEngCondUnary) 86 pipeline(tensorPkgLoc, "defaultengine_minmax.go", Kinds{allKinds}, generateStdEngMinMax) 87 88 // level 3 aggregation 89 pipeline(tensorPkgLoc, "dense_arith.go", Kinds{allKinds}, generateDenseArith) 90 pipeline(tensorPkgLoc, "dense_cmp.go", Kinds{allKinds}, generateDenseCmp) // generate once, manually edit later 91 92 // level 4 aggregation 93 pipeline(tensorPkgLoc, "api_unary.go", Kinds{allKinds}, generateUncondUnaryAPI, generateCondUnaryAPI, generateSpecialUnaryAPI) 94 95 // dense methods (old genlib style) 96 pipeline(tensorPkgLoc, "dense_generated.go", Kinds{allKinds}, generateDenseConstructionFns) 97 pipeline(tensorPkgLoc, "dense_io.go", Kinds{allKinds}, generateDenseIO) 98 pipeline(tensorPkgLoc, "dense_compat.go", Kinds{allKinds}, generateDenseCompat) 99 pipeline(tensorPkgLoc, "dense_maskcmp_methods.go", Kinds{allKinds}, generateDenseMaskedMethods) 100 101 // tests 102 pipeline(tensorPkgLoc, "test_test.go", Kinds{allKinds}, generateTestUtils) 103 pipeline(tensorPkgLoc, "dense_argmethods_test.go", Kinds{allKinds}, generateArgmethodsTests) 104 pipeline(tensorPkgLoc, "dense_getset_test.go", Kinds{allKinds}, generateDenseGetSetTests) 105 106 // old-genlib style tests 107 pipeline(tensorPkgLoc, "dense_reduction_test.go", Kinds{allKinds}, generateDenseReductionTests, generateDenseReductionMethodsTests) 108 pipeline(tensorPkgLoc, "dense_compat_test.go", Kinds{allKinds}, generateDenseCompatTests) 109 pipeline(tensorPkgLoc, "dense_generated_test.go", Kinds{allKinds}, generateDenseConsTests) 110 pipeline(tensorPkgLoc, "dense_maskcmp_methods_test.go", Kinds{allKinds}, generateMaskCmpMethodsTests) 111 112 // qc-style tests 113 pipeline(tensorPkgLoc, "api_arith_generated_test.go", Kinds{allKinds}, generateAPIArithTests, generateAPIArithScalarTests) 114 pipeline(tensorPkgLoc, "dense_arith_test.go", Kinds{allKinds}, generateDenseMethodArithTests, generateDenseMethodScalarTests) 115 pipeline(tensorPkgLoc, "api_unary_generated_test.go", Kinds{allKinds}, generateAPIUnaryTests) 116 pipeline(tensorPkgLoc, "api_cmp_generated_test.go", Kinds{allKinds}, generateAPICmpTests, generateAPICmpMixedTests) 117 pipeline(tensorPkgLoc, "dense_cmp_test.go", Kinds{allKinds}, generateDenseMethodCmpTests, generateDenseMethodCmpMixedTests) 118 119 // native iterators 120 pipeline(nativePkgLoc, "iterator_native.go", Kinds{allKinds}, generateNativeIterators) 121 pipeline(nativePkgLoc, "iterator_native_test.go", Kinds{allKinds}, generateNativeIteratorTests) 122 pipeline(nativePkgLoc, "iterator_native2.go", Kinds{allKinds}, generateNativeSelect) 123 pipeline(nativePkgLoc, "iterator_native2_test.go", Kinds{allKinds}, generateNativeSelectTests) 124 } 125 126 func pipeline(pkg, filename string, kinds Kinds, fns ...func(io.Writer, Kinds)) { 127 fullpath := path.Join(pkg, filename) 128 f, err := os.Create(fullpath) 129 if err != nil { 130 log.Printf("fullpath %q", fullpath) 131 log.Fatal(err) 132 } 133 defer f.Close() 134 writePkgName(f, pkg) 135 136 for _, fn := range fns { 137 fn(f, kinds) 138 } 139 140 // gofmt and goimports this stuff 141 cmd := exec.Command("goimports", "-w", fullpath) 142 if err = cmd.Run(); err != nil { 143 log.Fatalf("Go imports failed with %v for %q", err, fullpath) 144 } 145 146 // account for differences in the postix from the linux sed 147 if runtime.GOOS == "darwin" || strings.HasSuffix(runtime.GOOS, "bsd") { 148 cmd = exec.Command("sed", "-i", "", `s/github.com\/alecthomas\/assert/github.com\/stretchr\/testify\/assert/g`, fullpath) 149 } else { 150 cmd = exec.Command("sed", "-E", "-i", `s/github.com\/alecthomas\/assert/github.com\/stretchr\/testify\/assert/g`, fullpath) 151 } 152 if err = cmd.Run(); err != nil { 153 if err.Error() != "exit status 4" { // exit status 4 == not found 154 log.Fatalf("sed failed with %v for %q", err.Error(), fullpath) 155 } 156 } 157 158 cmd = exec.Command("gofmt", "-s", "-w", fullpath) 159 if err = cmd.Run(); err != nil { 160 log.Fatalf("Gofmt failed for %q", fullpath) 161 } 162 } 163 164 // pregenerate cleans up all files that were previously generated. 165 func pregenerate() error { 166 if err := cleanup(storageLoc); err != nil { 167 return err 168 } 169 if err := cleanup(execLoc); err != nil { 170 return err 171 } 172 if err := cleanup(nativePkgLoc); err != nil { 173 return err 174 } 175 return cleanup(tensorPkgLoc) 176 } 177 178 func cleanup(loc string) error { 179 pattern := path.Join(loc, "*.go") 180 matches, err := filepath.Glob(pattern) 181 if err != nil { 182 return err 183 } 184 for _, m := range matches { 185 b, err := ioutil.ReadFile(m) 186 if err != nil { 187 return err 188 } 189 s := string(b) 190 if strings.Contains(s, genmsg) { 191 if err := os.Remove(m); err != nil { 192 return err 193 } 194 } 195 } 196 return nil 197 }