github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/scripts/dqlite/cmd/infer_schema.go (about)

     1  // Copyright 2022 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package main
     5  
     6  import (
     7  	"bytes"
     8  	"flag"
     9  	"fmt"
    10  	"go/ast"
    11  	"go/parser"
    12  	"go/printer"
    13  	"go/token"
    14  	"io"
    15  	"os"
    16  	"path/filepath"
    17  	"strings"
    18  	"unicode"
    19  
    20  	"github.com/juju/collections/set"
    21  	"github.com/juju/loggo"
    22  )
    23  
// Command-line flags and the package-level logger.
//
// jujuStatePkg is the import path of the package to scan; loadPkg resolves it
// relative to $GOPATH/src. erdOutputFile is the destination of the generated
// diagram, where the special value "-" selects STDOUT.
var (
	jujuStatePkg  = flag.String("juju-pkg-name", "github.com/juju/juju/state", "the pkg name to scan for mongo doc structs")
	erdOutputFile = flag.String("output", "-", "the file (or - for STDOUT) to write the generated ER diagram")

	// logger reports progress while the package is parsed and scanned.
	logger = loggo.GetLogger("infer_schema")
)
    30  
// structAST describes a struct declaration that was selected while scanning
// the parsed package (in this tool: structs that map to mongo documents).
type structAST struct {
	// The name of the struct identifier, as stored by extractStructASTs
	// after passing it through normalizeName (e.g. without the "Doc"
	// suffix).
	Name string

	// The file where the struct was defined, with any leading
	// "$GOPATH/src/" prefix removed.
	SrcFile string

	// The AST for the struct.
	Decl *ast.StructType
}
    42  
    43  func main() {
    44  	flag.Parse()
    45  
    46  	if err := inferSchema(); err != nil {
    47  		fmt.Fprintf(os.Stderr, "error: %v\n", err)
    48  		os.Exit(1)
    49  	}
    50  }
    51  
    52  func inferSchema() error {
    53  	structASTs, fset, err := extractMongoDocStructASTs()
    54  	if err != nil {
    55  		return err
    56  	}
    57  
    58  	// Cluster structs by type
    59  	clusters := clusterASTS(structASTs)
    60  
    61  	// Render ERD
    62  	var w io.Writer
    63  	if *erdOutputFile == "-" {
    64  		w = os.Stdout
    65  	} else {
    66  		of, err := os.Open(*erdOutputFile)
    67  		if err != nil {
    68  			return fmt.Errorf("unable to open %q for writing: %v", *erdOutputFile, err)
    69  		}
    70  		w = of
    71  		defer func() { _ = of.Close() }()
    72  	}
    73  	renderERD(w, clusters, fset)
    74  
    75  	return nil
    76  }
    77  
    78  func extractMongoDocStructASTs() ([]structAST, *token.FileSet, error) {
    79  	logger.Infof("parsing files in package %q", *jujuStatePkg)
    80  	fset, pkgInfo, err := loadPkg(*jujuStatePkg)
    81  	if err != nil {
    82  		return nil, nil, fmt.Errorf("unable to resolve type information from package %q: %w", *jujuStatePkg, err)
    83  	}
    84  
    85  	logger.Infof("extracting mongo doc mapping structs")
    86  	structASTs, fset := extractStructASTs(fset, pkgInfo.Files, isMongoDocMapping)
    87  	logger.Infof("extracted %d mongo doc mapping structs", len(structASTs))
    88  
    89  	return structASTs, fset, nil
    90  }
    91  
    92  // loadPkg uses go/loader to compile pkgName (including any of its
    93  // direct and indirect dependencies) and returns back the obtained ASTs and type
    94  // information.
    95  func loadPkg(pkgName string) (*token.FileSet, *ast.Package, error) {
    96  	pathToPkg := filepath.Join(os.Getenv("GOPATH"), "src", pkgName)
    97  
    98  	fset := token.NewFileSet()
    99  	pkgs, err := parser.ParseDir(fset, pathToPkg, func(fi os.FileInfo) bool {
   100  		// Ignore test files
   101  		return !strings.Contains(fi.Name(), "_test")
   102  	}, parser.ParseComments)
   103  	if err != nil {
   104  		return nil, nil, err
   105  	}
   106  
   107  	pkg, found := pkgs[filepath.Base(pkgName)]
   108  	if !found {
   109  		for k := range pkgs {
   110  			fmt.Printf("pkg: %q\n", k)
   111  		}
   112  		return nil, nil, fmt.Errorf("unable to identify package %q contents", pkgName)
   113  	}
   114  
   115  	return fset, pkg, nil
   116  }
   117  
   118  // extractASTs returns a list of struct ASTs within a package that satisfy the
   119  // provided selectFn func.
   120  func extractStructASTs(fset *token.FileSet, fileASTs map[string]*ast.File, selectFn func(*ast.TypeSpec) bool) ([]structAST, *token.FileSet) {
   121  	stripPrefix := filepath.Join(os.Getenv("GOPATH"), "src") + string(filepath.Separator)
   122  
   123  	var structASTs []structAST
   124  	for _, fileAST := range fileASTs {
   125  		for _, decl := range fileAST.Decls {
   126  			genDecl, ok := decl.(*ast.GenDecl)
   127  			if !ok {
   128  				continue
   129  			}
   130  
   131  			for _, spec := range genDecl.Specs {
   132  				typeSpec, ok := spec.(*ast.TypeSpec)
   133  				if !ok {
   134  					continue
   135  				}
   136  
   137  				structType, ok := typeSpec.Type.(*ast.StructType)
   138  				if !ok {
   139  					continue
   140  				}
   141  
   142  				if !selectFn(typeSpec) {
   143  					continue
   144  				}
   145  
   146  				structPos := fset.Position(fileAST.Pos())
   147  				structASTs = append(structASTs,
   148  					structAST{
   149  						Name:    normalizeName(typeSpec.Name.Name),
   150  						SrcFile: strings.TrimPrefix(structPos.Filename, stripPrefix),
   151  						Decl:    structType,
   152  					},
   153  				)
   154  			}
   155  		}
   156  	}
   157  
   158  	return structASTs, fset
   159  }
   160  
   161  func isMongoDocMapping(tspec *ast.TypeSpec) bool {
   162  	name := tspec.Name.Name
   163  	lcIdent := name[0] >= 'a' && name[0] <= 'z'
   164  	return lcIdent && strings.HasSuffix(name, "Doc")
   165  }
   166  
   167  func clusterASTS(structASTs []structAST) map[string][]structAST {
   168  	// Construct a set of possible prefix names
   169  	prefixSet := set.NewStrings()
   170  	for _, str := range structASTs {
   171  		if strings.ContainsRune(str.Name, '_') {
   172  			continue // assume this can't be a prefix
   173  		}
   174  
   175  		prefixSet.Add(str.Name)
   176  	}
   177  
   178  	// Construct a set of possible foreign names
   179  	foreignSet := make(map[string]string)
   180  	for _, str := range structASTs {
   181  		for _, field := range str.Decl.Fields.List {
   182  			if ident, ok := field.Type.(*ast.Ident); ok {
   183  				if !strings.HasSuffix(ident.Name, "Status") {
   184  					continue
   185  				}
   186  
   187  				value := strings.ToLower(strings.TrimSuffix(ident.Name, "Status"))
   188  				if value == "" {
   189  					continue
   190  				}
   191  				foreignSet[str.Name] = value
   192  			}
   193  		}
   194  	}
   195  
   196  	// Group structs sharing each prefix
   197  	clusters := make(map[string][]structAST)
   198  nextStruct:
   199  	for _, str := range structASTs {
   200  		var added bool
   201  		if name, ok := foreignSet[str.Name]; ok {
   202  			clusters[name] = append(clusters[name], str)
   203  			added = true
   204  		}
   205  		if prefixSet.Contains(str.Name) {
   206  			clusters[str.Name] = append(clusters[str.Name], str)
   207  			added = true
   208  		}
   209  
   210  		if added {
   211  			continue
   212  		}
   213  
   214  		// Can we cluster it with any of the prefixes?
   215  		for prefix := range prefixSet {
   216  			if strings.HasPrefix(str.Name, prefix) {
   217  				clusters[prefix] = append(clusters[prefix], str)
   218  				continue nextStruct
   219  			}
   220  		}
   221  
   222  		// Add as a standalone cluster
   223  		clusters[str.Name] = append(clusters[str.Name], str)
   224  	}
   225  
   226  	// Co-locate ASTs that don't have any other ASTs in their cluster
   227  	for key, asts := range clusters {
   228  		if len(asts) != 1 {
   229  			continue
   230  		}
   231  
   232  		clusters[""] = append(clusters[""], asts...)
   233  		delete(clusters, key)
   234  	}
   235  
   236  	return clusters
   237  }
   238  
   239  func renderERD(w io.Writer, clusters map[string][]structAST, fset *token.FileSet) {
   240  	braceEscaper := strings.NewReplacer(
   241  		"{", "\\{",
   242  		"}", "\\}",
   243  	)
   244  
   245  	fmt.Fprintln(w, "graph {")
   246  	for clusterName, structASTs := range clusters {
   247  		if clusterName == "" {
   248  			fmt.Fprintln(w, "  subgraph {")
   249  		} else {
   250  			fmt.Fprintf(w, "  subgraph cluster_%s {\n", clusterName)
   251  			fmt.Fprintf(w, "    color=blue;")
   252  			fmt.Fprintf(w, "    label=\"%s group\";\n", clusterName)
   253  		}
   254  
   255  		prefix := strings.Repeat(" ", 4)
   256  		for _, str := range structASTs {
   257  			fmt.Fprintf(w, "%s# %s\n", prefix, str.SrcFile)
   258  			fmt.Fprintf(w, "%s%s [shape=record, label=<{<b>%s</b><br/>%s", prefix, str.Name, str.Name, str.SrcFile)
   259  			for _, field := range str.Decl.Fields.List {
   260  				if ignoreField(field.Names[0].Name) {
   261  					continue // not needed
   262  				}
   263  				fieldName := normalizeName(field.Names[0].Name)
   264  				fmt.Fprintf(w, ` | %s (%s)`, fieldName, braceEscaper.Replace(fmtType(field.Type, fset)))
   265  			}
   266  			fmt.Fprintf(w, "}>];\n")
   267  		}
   268  
   269  		fmt.Fprintln(w, "  }")
   270  	}
   271  	fmt.Fprintln(w, "}")
   272  }
   273  
   274  func fmtType(typ interface{}, fset *token.FileSet) string {
   275  	var buf bytes.Buffer
   276  	_ = printer.Fprint(&buf, fset, typ)
   277  	return buf.String()
   278  }
   279  
   280  var skipFieldList = []string{
   281  	"modeluuid",
   282  	"revno",
   283  }
   284  
   285  func ignoreField(name string) bool {
   286  	name = strings.ToLower(name)
   287  
   288  	for _, skipField := range skipFieldList {
   289  		if strings.Contains(name, strings.ToLower(skipField)) {
   290  			return true
   291  		}
   292  	}
   293  
   294  	return false
   295  }
   296  
   297  func normalizeName(name string) string {
   298  	var buf bytes.Buffer
   299  
   300  	in := strings.TrimSuffix(name, "Doc")
   301  	for i, r := range in {
   302  		if i > 0 && unicode.IsUpper(r) && !unicode.IsUpper(rune(in[i-1])) {
   303  			buf.WriteRune('_')
   304  		}
   305  		buf.WriteRune(unicode.ToLower(r))
   306  	}
   307  
   308  	return buf.String()
   309  }