github.com/golang/mock@v1.6.0/mockgen/reflect.go (about) 1 // Copyright 2012 Google Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package main 16 17 // This file contains the model construction by reflection. 18 19 import ( 20 "bytes" 21 "encoding/gob" 22 "flag" 23 "fmt" 24 "go/build" 25 "io" 26 "io/ioutil" 27 "log" 28 "os" 29 "os/exec" 30 "path/filepath" 31 "runtime" 32 "strings" 33 "text/template" 34 35 "github.com/golang/mock/mockgen/model" 36 ) 37 38 var ( 39 progOnly = flag.Bool("prog_only", false, "(reflect mode) Only generate the reflection program; write it to stdout and exit.") 40 execOnly = flag.String("exec_only", "", "(reflect mode) If set, execute this reflection program.") 41 buildFlags = flag.String("build_flags", "", "(reflect mode) Additional flags for go build.") 42 ) 43 44 // reflectMode generates mocks via reflection on an interface. 45 func reflectMode(importPath string, symbols []string) (*model.Package, error) { 46 if *execOnly != "" { 47 return run(*execOnly) 48 } 49 50 program, err := writeProgram(importPath, symbols) 51 if err != nil { 52 return nil, err 53 } 54 55 if *progOnly { 56 if _, err := os.Stdout.Write(program); err != nil { 57 return nil, err 58 } 59 os.Exit(0) 60 } 61 62 wd, _ := os.Getwd() 63 64 // Try to run the reflection program in the current working directory. 65 if p, err := runInDir(program, wd); err == nil { 66 return p, nil 67 } 68 69 // Try to run the program in the same directory as the input package. 70 if p, err := build.Import(importPath, wd, build.FindOnly); err == nil { 71 dir := p.Dir 72 if p, err := runInDir(program, dir); err == nil { 73 return p, nil 74 } 75 } 76 77 // Try to run it in a standard temp directory. 78 return runInDir(program, "") 79 } 80 81 func writeProgram(importPath string, symbols []string) ([]byte, error) { 82 var program bytes.Buffer 83 data := reflectData{ 84 ImportPath: importPath, 85 Symbols: symbols, 86 } 87 if err := reflectProgram.Execute(&program, &data); err != nil { 88 return nil, err 89 } 90 return program.Bytes(), nil 91 } 92 93 // run the given program and parse the output as a model.Package. 94 func run(program string) (*model.Package, error) { 95 f, err := ioutil.TempFile("", "") 96 if err != nil { 97 return nil, err 98 } 99 100 filename := f.Name() 101 defer os.Remove(filename) 102 if err := f.Close(); err != nil { 103 return nil, err 104 } 105 106 // Run the program. 107 cmd := exec.Command(program, "-output", filename) 108 cmd.Stdout = os.Stdout 109 cmd.Stderr = os.Stderr 110 if err := cmd.Run(); err != nil { 111 return nil, err 112 } 113 114 f, err = os.Open(filename) 115 if err != nil { 116 return nil, err 117 } 118 119 // Process output. 120 var pkg model.Package 121 if err := gob.NewDecoder(f).Decode(&pkg); err != nil { 122 return nil, err 123 } 124 125 if err := f.Close(); err != nil { 126 return nil, err 127 } 128 129 return &pkg, nil 130 } 131 132 // runInDir writes the given program into the given dir, runs it there, and 133 // parses the output as a model.Package. 134 func runInDir(program []byte, dir string) (*model.Package, error) { 135 // We use TempDir instead of TempFile so we can control the filename. 136 tmpDir, err := ioutil.TempDir(dir, "gomock_reflect_") 137 if err != nil { 138 return nil, err 139 } 140 defer func() { 141 if err := os.RemoveAll(tmpDir); err != nil { 142 log.Printf("failed to remove temp directory: %s", err) 143 } 144 }() 145 const progSource = "prog.go" 146 var progBinary = "prog.bin" 147 if runtime.GOOS == "windows" { 148 // Windows won't execute a program unless it has a ".exe" suffix. 149 progBinary += ".exe" 150 } 151 152 if err := ioutil.WriteFile(filepath.Join(tmpDir, progSource), program, 0600); err != nil { 153 return nil, err 154 } 155 156 cmdArgs := []string{} 157 cmdArgs = append(cmdArgs, "build") 158 if *buildFlags != "" { 159 cmdArgs = append(cmdArgs, strings.Split(*buildFlags, " ")...) 160 } 161 cmdArgs = append(cmdArgs, "-o", progBinary, progSource) 162 163 // Build the program. 164 buf := bytes.NewBuffer(nil) 165 cmd := exec.Command("go", cmdArgs...) 166 cmd.Dir = tmpDir 167 cmd.Stdout = os.Stdout 168 cmd.Stderr = io.MultiWriter(os.Stderr, buf) 169 if err := cmd.Run(); err != nil { 170 sErr := buf.String() 171 if strings.Contains(sErr, `cannot find package "."`) && 172 strings.Contains(sErr, "github.com/golang/mock/mockgen/model") { 173 fmt.Fprint(os.Stderr, "Please reference the steps in the README to fix this error:\n\thttps://github.com/golang/mock#reflect-vendoring-error.") 174 return nil, err 175 } 176 return nil, err 177 } 178 179 return run(filepath.Join(tmpDir, progBinary)) 180 } 181 182 type reflectData struct { 183 ImportPath string 184 Symbols []string 185 } 186 187 // This program reflects on an interface value, and prints the 188 // gob encoding of a model.Package to standard output. 189 // JSON doesn't work because of the model.Type interface. 190 var reflectProgram = template.Must(template.New("program").Parse(` 191 package main 192 193 import ( 194 "encoding/gob" 195 "flag" 196 "fmt" 197 "os" 198 "path" 199 "reflect" 200 201 "github.com/golang/mock/mockgen/model" 202 203 pkg_ {{printf "%q" .ImportPath}} 204 ) 205 206 var output = flag.String("output", "", "The output file name, or empty to use stdout.") 207 208 func main() { 209 flag.Parse() 210 211 its := []struct{ 212 sym string 213 typ reflect.Type 214 }{ 215 {{range .Symbols}} 216 { {{printf "%q" .}}, reflect.TypeOf((*pkg_.{{.}})(nil)).Elem()}, 217 {{end}} 218 } 219 pkg := &model.Package{ 220 // NOTE: This behaves contrary to documented behaviour if the 221 // package name is not the final component of the import path. 222 // The reflect package doesn't expose the package name, though. 223 Name: path.Base({{printf "%q" .ImportPath}}), 224 } 225 226 for _, it := range its { 227 intf, err := model.InterfaceFromInterfaceType(it.typ) 228 if err != nil { 229 fmt.Fprintf(os.Stderr, "Reflection: %v\n", err) 230 os.Exit(1) 231 } 232 intf.Name = it.sym 233 pkg.Interfaces = append(pkg.Interfaces, intf) 234 } 235 236 outfile := os.Stdout 237 if len(*output) != 0 { 238 var err error 239 outfile, err = os.Create(*output) 240 if err != nil { 241 fmt.Fprintf(os.Stderr, "failed to open output file %q", *output) 242 } 243 defer func() { 244 if err := outfile.Close(); err != nil { 245 fmt.Fprintf(os.Stderr, "failed to close output file %q", *output) 246 os.Exit(1) 247 } 248 }() 249 } 250 251 if err := gob.NewEncoder(outfile).Encode(pkg); err != nil { 252 fmt.Fprintf(os.Stderr, "gob encode: %v\n", err) 253 os.Exit(1) 254 } 255 } 256 `))