github.com/golang/mock@v1.6.0/mockgen/reflect.go (about)

     1  // Copyright 2012 Google Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package main
    16  
    17  // This file contains the model construction by reflection.
    18  
    19  import (
    20  	"bytes"
    21  	"encoding/gob"
    22  	"flag"
    23  	"fmt"
    24  	"go/build"
    25  	"io"
    26  	"io/ioutil"
    27  	"log"
    28  	"os"
    29  	"os/exec"
    30  	"path/filepath"
    31  	"runtime"
    32  	"strings"
    33  	"text/template"
    34  
    35  	"github.com/golang/mock/mockgen/model"
    36  )
    37  
// Command-line flags that only affect reflect mode.
//
// prog_only makes reflectMode write the generated reflection program to
// stdout and exit instead of building and running it. exec_only names an
// already-built reflection program to execute, skipping generation and
// compilation. build_flags is split into arguments and inserted into the
// "go build" invocation used to compile the generated program.
var (
	progOnly   = flag.Bool("prog_only", false, "(reflect mode) Only generate the reflection program; write it to stdout and exit.")
	execOnly   = flag.String("exec_only", "", "(reflect mode) If set, execute this reflection program.")
	buildFlags = flag.String("build_flags", "", "(reflect mode) Additional flags for go build.")
)
    43  
    44  // reflectMode generates mocks via reflection on an interface.
    45  func reflectMode(importPath string, symbols []string) (*model.Package, error) {
    46  	if *execOnly != "" {
    47  		return run(*execOnly)
    48  	}
    49  
    50  	program, err := writeProgram(importPath, symbols)
    51  	if err != nil {
    52  		return nil, err
    53  	}
    54  
    55  	if *progOnly {
    56  		if _, err := os.Stdout.Write(program); err != nil {
    57  			return nil, err
    58  		}
    59  		os.Exit(0)
    60  	}
    61  
    62  	wd, _ := os.Getwd()
    63  
    64  	// Try to run the reflection program  in the current working directory.
    65  	if p, err := runInDir(program, wd); err == nil {
    66  		return p, nil
    67  	}
    68  
    69  	// Try to run the program in the same directory as the input package.
    70  	if p, err := build.Import(importPath, wd, build.FindOnly); err == nil {
    71  		dir := p.Dir
    72  		if p, err := runInDir(program, dir); err == nil {
    73  			return p, nil
    74  		}
    75  	}
    76  
    77  	// Try to run it in a standard temp directory.
    78  	return runInDir(program, "")
    79  }
    80  
    81  func writeProgram(importPath string, symbols []string) ([]byte, error) {
    82  	var program bytes.Buffer
    83  	data := reflectData{
    84  		ImportPath: importPath,
    85  		Symbols:    symbols,
    86  	}
    87  	if err := reflectProgram.Execute(&program, &data); err != nil {
    88  		return nil, err
    89  	}
    90  	return program.Bytes(), nil
    91  }
    92  
    93  // run the given program and parse the output as a model.Package.
    94  func run(program string) (*model.Package, error) {
    95  	f, err := ioutil.TempFile("", "")
    96  	if err != nil {
    97  		return nil, err
    98  	}
    99  
   100  	filename := f.Name()
   101  	defer os.Remove(filename)
   102  	if err := f.Close(); err != nil {
   103  		return nil, err
   104  	}
   105  
   106  	// Run the program.
   107  	cmd := exec.Command(program, "-output", filename)
   108  	cmd.Stdout = os.Stdout
   109  	cmd.Stderr = os.Stderr
   110  	if err := cmd.Run(); err != nil {
   111  		return nil, err
   112  	}
   113  
   114  	f, err = os.Open(filename)
   115  	if err != nil {
   116  		return nil, err
   117  	}
   118  
   119  	// Process output.
   120  	var pkg model.Package
   121  	if err := gob.NewDecoder(f).Decode(&pkg); err != nil {
   122  		return nil, err
   123  	}
   124  
   125  	if err := f.Close(); err != nil {
   126  		return nil, err
   127  	}
   128  
   129  	return &pkg, nil
   130  }
   131  
   132  // runInDir writes the given program into the given dir, runs it there, and
   133  // parses the output as a model.Package.
   134  func runInDir(program []byte, dir string) (*model.Package, error) {
   135  	// We use TempDir instead of TempFile so we can control the filename.
   136  	tmpDir, err := ioutil.TempDir(dir, "gomock_reflect_")
   137  	if err != nil {
   138  		return nil, err
   139  	}
   140  	defer func() {
   141  		if err := os.RemoveAll(tmpDir); err != nil {
   142  			log.Printf("failed to remove temp directory: %s", err)
   143  		}
   144  	}()
   145  	const progSource = "prog.go"
   146  	var progBinary = "prog.bin"
   147  	if runtime.GOOS == "windows" {
   148  		// Windows won't execute a program unless it has a ".exe" suffix.
   149  		progBinary += ".exe"
   150  	}
   151  
   152  	if err := ioutil.WriteFile(filepath.Join(tmpDir, progSource), program, 0600); err != nil {
   153  		return nil, err
   154  	}
   155  
   156  	cmdArgs := []string{}
   157  	cmdArgs = append(cmdArgs, "build")
   158  	if *buildFlags != "" {
   159  		cmdArgs = append(cmdArgs, strings.Split(*buildFlags, " ")...)
   160  	}
   161  	cmdArgs = append(cmdArgs, "-o", progBinary, progSource)
   162  
   163  	// Build the program.
   164  	buf := bytes.NewBuffer(nil)
   165  	cmd := exec.Command("go", cmdArgs...)
   166  	cmd.Dir = tmpDir
   167  	cmd.Stdout = os.Stdout
   168  	cmd.Stderr = io.MultiWriter(os.Stderr, buf)
   169  	if err := cmd.Run(); err != nil {
   170  		sErr := buf.String()
   171  		if strings.Contains(sErr, `cannot find package "."`) &&
   172  			strings.Contains(sErr, "github.com/golang/mock/mockgen/model") {
   173  			fmt.Fprint(os.Stderr, "Please reference the steps in the README to fix this error:\n\thttps://github.com/golang/mock#reflect-vendoring-error.")
   174  			return nil, err
   175  		}
   176  		return nil, err
   177  	}
   178  
   179  	return run(filepath.Join(tmpDir, progBinary))
   180  }
   181  
// reflectData is the data passed to the reflectProgram template.
//
// ImportPath is the import path of the package to reflect on; the generated
// program imports it under the alias pkg_ and derives the model package name
// from its final path element. Symbols lists the names in that package to
// reflect on; each one is referenced as pkg_.<symbol> in the generated
// program.
type reflectData struct {
	ImportPath string
	Symbols    []string
}
   186  
   187  // This program reflects on an interface value, and prints the
   188  // gob encoding of a model.Package to standard output.
   189  // JSON doesn't work because of the model.Type interface.
   190  var reflectProgram = template.Must(template.New("program").Parse(`
   191  package main
   192  
   193  import (
   194  	"encoding/gob"
   195  	"flag"
   196  	"fmt"
   197  	"os"
   198  	"path"
   199  	"reflect"
   200  
   201  	"github.com/golang/mock/mockgen/model"
   202  
   203  	pkg_ {{printf "%q" .ImportPath}}
   204  )
   205  
   206  var output = flag.String("output", "", "The output file name, or empty to use stdout.")
   207  
   208  func main() {
   209  	flag.Parse()
   210  
   211  	its := []struct{
   212  		sym string
   213  		typ reflect.Type
   214  	}{
   215  		{{range .Symbols}}
   216  		{ {{printf "%q" .}}, reflect.TypeOf((*pkg_.{{.}})(nil)).Elem()},
   217  		{{end}}
   218  	}
   219  	pkg := &model.Package{
   220  		// NOTE: This behaves contrary to documented behaviour if the
   221  		// package name is not the final component of the import path.
   222  		// The reflect package doesn't expose the package name, though.
   223  		Name: path.Base({{printf "%q" .ImportPath}}),
   224  	}
   225  
   226  	for _, it := range its {
   227  		intf, err := model.InterfaceFromInterfaceType(it.typ)
   228  		if err != nil {
   229  			fmt.Fprintf(os.Stderr, "Reflection: %v\n", err)
   230  			os.Exit(1)
   231  		}
   232  		intf.Name = it.sym
   233  		pkg.Interfaces = append(pkg.Interfaces, intf)
   234  	}
   235  
   236  	outfile := os.Stdout
   237  	if len(*output) != 0 {
   238  		var err error
   239  		outfile, err = os.Create(*output)
   240  		if err != nil {
   241  			fmt.Fprintf(os.Stderr, "failed to open output file %q", *output)
   242  		}
   243  		defer func() {
   244  			if err := outfile.Close(); err != nil {
   245  				fmt.Fprintf(os.Stderr, "failed to close output file %q", *output)
   246  				os.Exit(1)
   247  			}
   248  		}()
   249  	}
   250  
   251  	if err := gob.NewEncoder(outfile).Encode(pkg); err != nil {
   252  		fmt.Fprintf(os.Stderr, "gob encode: %v\n", err)
   253  		os.Exit(1)
   254  	}
   255  }
   256  `))