github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/codegen/importer.go (about)

     1  package codegen
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"go/build"
     7  	"io"
     8  	"reflect"
     9  	"sort"
    10  	"strconv"
    11  	"strings"
    12  )
    13  
    14  type ImportPkg struct {
    15  	*build.Package
    16  	Alias string
    17  }
    18  
    19  func (importPkg *ImportPkg) GetID() string {
    20  	if importPkg.Alias != "" {
    21  		return importPkg.Alias
    22  	}
    23  	return importPkg.Name
    24  }
    25  
    26  func (importPkg *ImportPkg) String() string {
    27  	if importPkg.Alias != "" {
    28  		return importPkg.Alias + " " + strconv.Quote(importPkg.ImportPath) + "\n"
    29  	}
    30  	return strconv.Quote(importPkg.ImportPath)
    31  }
    32  
    33  type Importer struct {
    34  	Local string
    35  	pkgs  map[string]*ImportPkg
    36  }
    37  
    38  func getPkgImportPathAndExpose(s string) (pkgImportPath string, expose string) {
    39  	idxSlash := strings.LastIndex(s, "/")
    40  	idxDot := strings.LastIndex(s, ".")
    41  	if idxDot > idxSlash {
    42  		return s[0:idxDot], s[idxDot+1:]
    43  	}
    44  	return s, ""
    45  }
    46  
    47  func (importer *Importer) ExposeVar(name string) string {
    48  	return ToUpperCamelCase(name)
    49  }
    50  
    51  func (importer *Importer) Var(name string) string {
    52  	return ToLowerCamelCase(name)
    53  }
    54  
    55  func (importer *Importer) PureUse(importPath string, subPkgs ...string) string {
    56  	pkgPath, expose := getPkgImportPathAndExpose(strings.Join(append([]string{importPath}, subPkgs...), "/"))
    57  
    58  	importPkg := importer.Import(pkgPath, false)
    59  
    60  	if expose != "" {
    61  		if pkgPath == importer.Local {
    62  			return expose
    63  		}
    64  		return fmt.Sprintf("%s.%s", importPkg.GetID(), expose)
    65  	}
    66  
    67  	return importPkg.GetID()
    68  }
    69  
    70  // use and alias
    71  func (importer *Importer) Use(importPath string, subPkgs ...string) string {
    72  	pkgPath, expose := getPkgImportPathAndExpose(strings.Join(append([]string{importPath}, subPkgs...), "/"))
    73  
    74  	importPkg := importer.Import(pkgPath, true)
    75  
    76  	if expose != "" {
    77  		if pkgPath == importer.Local {
    78  			return expose
    79  		}
    80  		return fmt.Sprintf("%s.%s", importPkg.GetID(), expose)
    81  	}
    82  
    83  	return importPkg.GetID()
    84  }
    85  
    86  func (importer *Importer) Import(importPath string, alias bool) *ImportPkg {
    87  	importPath = DeVendor(importPath)
    88  	if importer.pkgs == nil {
    89  		importer.pkgs = map[string]*ImportPkg{}
    90  	}
    91  
    92  	importPkg, exists := importer.pkgs[importPath]
    93  	if !exists {
    94  		pkg, err := build.Import(importPath, "", build.ImportComment)
    95  		if err != nil {
    96  			panic(err)
    97  		}
    98  		importPkg = &ImportPkg{
    99  			Package: pkg,
   100  		}
   101  		if alias {
   102  			importPkg.Alias = ToLowerSnakeCase(importPath)
   103  		}
   104  		importer.pkgs[importPath] = importPkg
   105  	}
   106  
   107  	return importPkg
   108  }
   109  
   110  func DeVendor(importPath string) string {
   111  	parts := strings.Split(importPath, "/vendor/")
   112  	return parts[len(parts)-1]
   113  }
   114  
   115  func (importer *Importer) WriteToImports(w io.Writer) {
   116  	if len(importer.pkgs) > 0 {
   117  		for _, importPkg := range importer.pkgs {
   118  			io.WriteString(w, importPkg.String()+"\n")
   119  		}
   120  	}
   121  }
   122  
   123  func (importer *Importer) String() string {
   124  	buf := &bytes.Buffer{}
   125  	if len(importer.pkgs) > 0 {
   126  		buf.WriteString("import (\n")
   127  		importer.WriteToImports(buf)
   128  		buf.WriteString(")")
   129  	}
   130  	return buf.String()
   131  }
   132  
   133  func (importer *Importer) Type(tpe reflect.Type) string {
   134  	if tpe.PkgPath() != "" {
   135  		return importer.Use(fmt.Sprintf("%s.%s", tpe.PkgPath(), tpe.Name()))
   136  	}
   137  
   138  	switch tpe.Kind() {
   139  	case reflect.Slice:
   140  		return fmt.Sprintf("[]%s", importer.Type(tpe.Elem()))
   141  	case reflect.Map:
   142  		return fmt.Sprintf("map[%s]%s", importer.Type(tpe.Key()), importer.Type(tpe.Elem()))
   143  	default:
   144  		return tpe.String()
   145  	}
   146  }
   147  
   148  func (importer *Importer) Sdump(v interface{}) string {
   149  	tpe := reflect.TypeOf(v)
   150  	rv := reflect.ValueOf(v)
   151  
   152  	switch rv.Kind() {
   153  	case reflect.Map:
   154  		parts := make([]string, 0)
   155  		isMulti := rv.Len() > 1
   156  		for _, key := range rv.MapKeys() {
   157  			s := importer.Sdump(key.Interface()) + ": " + importer.Sdump(rv.MapIndex(key).Interface())
   158  			if isMulti {
   159  				parts = append(parts, s+",\n")
   160  			} else {
   161  				parts = append(parts, s)
   162  			}
   163  		}
   164  		sort.Strings(parts)
   165  
   166  		if isMulti {
   167  			f := `%s{
   168  				%s
   169  			}`
   170  			return fmt.Sprintf(f, importer.Type(tpe), strings.Join(parts, ""))
   171  		}
   172  		f := "%s{%s}"
   173  		return fmt.Sprintf(f, importer.Type(tpe), strings.Join(parts, ", "))
   174  	case reflect.Slice:
   175  		buf := new(bytes.Buffer)
   176  		for i := 0; i < rv.Len(); i++ {
   177  			s := importer.Sdump(rv.Index(i).Interface())
   178  			if i == 0 {
   179  				buf.WriteString(s)
   180  			} else {
   181  				buf.WriteString(", " + s)
   182  			}
   183  		}
   184  		return fmt.Sprintf("%s{%s}", importer.Type(tpe), buf.String())
   185  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint8:
   186  		return fmt.Sprintf("%d", v)
   187  	case reflect.Bool:
   188  		return strconv.FormatBool(v.(bool))
   189  	case reflect.Float32:
   190  		return strconv.FormatFloat(float64(v.(float32)), 'f', -1, 32)
   191  	case reflect.Float64:
   192  		return strconv.FormatFloat(v.(float64), 'f', -1, 64)
   193  	case reflect.Invalid:
   194  		return "nil"
   195  	case reflect.String:
   196  		return strconv.Quote(v.(string))
   197  	}
   198  	return ""
   199  }