github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/scripts/dqlite/cmd/infer_schema.go (about) 1 // Copyright 2022 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package main 5 6 import ( 7 "bytes" 8 "flag" 9 "fmt" 10 "go/ast" 11 "go/parser" 12 "go/printer" 13 "go/token" 14 "io" 15 "os" 16 "path/filepath" 17 "strings" 18 "unicode" 19 20 "github.com/juju/collections/set" 21 "github.com/juju/loggo" 22 ) 23 24 var ( 25 jujuStatePkg = flag.String("juju-pkg-name", "github.com/juju/juju/state", "the pkg name to scan for mongo doc structs") 26 erdOutputFile = flag.String("output", "-", "the file (or - for STDOUT) to write the generated ER diagram") 27 28 logger = loggo.GetLogger("infer_schema") 29 ) 30 31 // structAST represents a parsed struct representing a mongo document. 32 type structAST struct { 33 // The name of the struct identifier. 34 Name string 35 36 // The file where the struct was defined. 37 SrcFile string 38 39 // The AST for the struct. 40 Decl *ast.StructType 41 } 42 43 func main() { 44 flag.Parse() 45 46 if err := inferSchema(); err != nil { 47 fmt.Fprintf(os.Stderr, "error: %v\n", err) 48 os.Exit(1) 49 } 50 } 51 52 func inferSchema() error { 53 structASTs, fset, err := extractMongoDocStructASTs() 54 if err != nil { 55 return err 56 } 57 58 // Cluster structs by type 59 clusters := clusterASTS(structASTs) 60 61 // Render ERD 62 var w io.Writer 63 if *erdOutputFile == "-" { 64 w = os.Stdout 65 } else { 66 of, err := os.Open(*erdOutputFile) 67 if err != nil { 68 return fmt.Errorf("unable to open %q for writing: %v", *erdOutputFile, err) 69 } 70 w = of 71 defer func() { _ = of.Close() }() 72 } 73 renderERD(w, clusters, fset) 74 75 return nil 76 } 77 78 func extractMongoDocStructASTs() ([]structAST, *token.FileSet, error) { 79 logger.Infof("parsing files in package %q", *jujuStatePkg) 80 fset, pkgInfo, err := loadPkg(*jujuStatePkg) 81 if err != nil { 82 return nil, nil, fmt.Errorf("unable to resolve type information from package %q: %w", *jujuStatePkg, err) 83 } 84 85 logger.Infof("extracting mongo doc mapping structs") 86 structASTs, fset := extractStructASTs(fset, pkgInfo.Files, isMongoDocMapping) 87 logger.Infof("extracted %d mongo doc mapping structs", len(structASTs)) 88 89 return structASTs, fset, nil 90 } 91 92 // loadPkg uses go/loader to compile pkgName (including any of its 93 // direct and indirect dependencies) and returns back the obtained ASTs and type 94 // information. 95 func loadPkg(pkgName string) (*token.FileSet, *ast.Package, error) { 96 pathToPkg := filepath.Join(os.Getenv("GOPATH"), "src", pkgName) 97 98 fset := token.NewFileSet() 99 pkgs, err := parser.ParseDir(fset, pathToPkg, func(fi os.FileInfo) bool { 100 // Ignore test files 101 return !strings.Contains(fi.Name(), "_test") 102 }, parser.ParseComments) 103 if err != nil { 104 return nil, nil, err 105 } 106 107 pkg, found := pkgs[filepath.Base(pkgName)] 108 if !found { 109 for k := range pkgs { 110 fmt.Printf("pkg: %q\n", k) 111 } 112 return nil, nil, fmt.Errorf("unable to identify package %q contents", pkgName) 113 } 114 115 return fset, pkg, nil 116 } 117 118 // extractASTs returns a list of struct ASTs within a package that satisfy the 119 // provided selectFn func. 120 func extractStructASTs(fset *token.FileSet, fileASTs map[string]*ast.File, selectFn func(*ast.TypeSpec) bool) ([]structAST, *token.FileSet) { 121 stripPrefix := filepath.Join(os.Getenv("GOPATH"), "src") + string(filepath.Separator) 122 123 var structASTs []structAST 124 for _, fileAST := range fileASTs { 125 for _, decl := range fileAST.Decls { 126 genDecl, ok := decl.(*ast.GenDecl) 127 if !ok { 128 continue 129 } 130 131 for _, spec := range genDecl.Specs { 132 typeSpec, ok := spec.(*ast.TypeSpec) 133 if !ok { 134 continue 135 } 136 137 structType, ok := typeSpec.Type.(*ast.StructType) 138 if !ok { 139 continue 140 } 141 142 if !selectFn(typeSpec) { 143 continue 144 } 145 146 structPos := fset.Position(fileAST.Pos()) 147 structASTs = append(structASTs, 148 structAST{ 149 Name: normalizeName(typeSpec.Name.Name), 150 SrcFile: strings.TrimPrefix(structPos.Filename, stripPrefix), 151 Decl: structType, 152 }, 153 ) 154 } 155 } 156 } 157 158 return structASTs, fset 159 } 160 161 func isMongoDocMapping(tspec *ast.TypeSpec) bool { 162 name := tspec.Name.Name 163 lcIdent := name[0] >= 'a' && name[0] <= 'z' 164 return lcIdent && strings.HasSuffix(name, "Doc") 165 } 166 167 func clusterASTS(structASTs []structAST) map[string][]structAST { 168 // Construct a set of possible prefix names 169 prefixSet := set.NewStrings() 170 for _, str := range structASTs { 171 if strings.ContainsRune(str.Name, '_') { 172 continue // assume this can't be a prefix 173 } 174 175 prefixSet.Add(str.Name) 176 } 177 178 // Construct a set of possible foreign names 179 foreignSet := make(map[string]string) 180 for _, str := range structASTs { 181 for _, field := range str.Decl.Fields.List { 182 if ident, ok := field.Type.(*ast.Ident); ok { 183 if !strings.HasSuffix(ident.Name, "Status") { 184 continue 185 } 186 187 value := strings.ToLower(strings.TrimSuffix(ident.Name, "Status")) 188 if value == "" { 189 continue 190 } 191 foreignSet[str.Name] = value 192 } 193 } 194 } 195 196 // Group structs sharing each prefix 197 clusters := make(map[string][]structAST) 198 nextStruct: 199 for _, str := range structASTs { 200 var added bool 201 if name, ok := foreignSet[str.Name]; ok { 202 clusters[name] = append(clusters[name], str) 203 added = true 204 } 205 if prefixSet.Contains(str.Name) { 206 clusters[str.Name] = append(clusters[str.Name], str) 207 added = true 208 } 209 210 if added { 211 continue 212 } 213 214 // Can we cluster it with any of the prefixes? 215 for prefix := range prefixSet { 216 if strings.HasPrefix(str.Name, prefix) { 217 clusters[prefix] = append(clusters[prefix], str) 218 continue nextStruct 219 } 220 } 221 222 // Add as a standalone cluster 223 clusters[str.Name] = append(clusters[str.Name], str) 224 } 225 226 // Co-locate ASTs that don't have any other ASTs in their cluster 227 for key, asts := range clusters { 228 if len(asts) != 1 { 229 continue 230 } 231 232 clusters[""] = append(clusters[""], asts...) 233 delete(clusters, key) 234 } 235 236 return clusters 237 } 238 239 func renderERD(w io.Writer, clusters map[string][]structAST, fset *token.FileSet) { 240 braceEscaper := strings.NewReplacer( 241 "{", "\\{", 242 "}", "\\}", 243 ) 244 245 fmt.Fprintln(w, "graph {") 246 for clusterName, structASTs := range clusters { 247 if clusterName == "" { 248 fmt.Fprintln(w, " subgraph {") 249 } else { 250 fmt.Fprintf(w, " subgraph cluster_%s {\n", clusterName) 251 fmt.Fprintf(w, " color=blue;") 252 fmt.Fprintf(w, " label=\"%s group\";\n", clusterName) 253 } 254 255 prefix := strings.Repeat(" ", 4) 256 for _, str := range structASTs { 257 fmt.Fprintf(w, "%s# %s\n", prefix, str.SrcFile) 258 fmt.Fprintf(w, "%s%s [shape=record, label=<{<b>%s</b><br/>%s", prefix, str.Name, str.Name, str.SrcFile) 259 for _, field := range str.Decl.Fields.List { 260 if ignoreField(field.Names[0].Name) { 261 continue // not needed 262 } 263 fieldName := normalizeName(field.Names[0].Name) 264 fmt.Fprintf(w, ` | %s (%s)`, fieldName, braceEscaper.Replace(fmtType(field.Type, fset))) 265 } 266 fmt.Fprintf(w, "}>];\n") 267 } 268 269 fmt.Fprintln(w, " }") 270 } 271 fmt.Fprintln(w, "}") 272 } 273 274 func fmtType(typ interface{}, fset *token.FileSet) string { 275 var buf bytes.Buffer 276 _ = printer.Fprint(&buf, fset, typ) 277 return buf.String() 278 } 279 280 var skipFieldList = []string{ 281 "modeluuid", 282 "revno", 283 } 284 285 func ignoreField(name string) bool { 286 name = strings.ToLower(name) 287 288 for _, skipField := range skipFieldList { 289 if strings.Contains(name, strings.ToLower(skipField)) { 290 return true 291 } 292 } 293 294 return false 295 } 296 297 func normalizeName(name string) string { 298 var buf bytes.Buffer 299 300 in := strings.TrimSuffix(name, "Doc") 301 for i, r := range in { 302 if i > 0 && unicode.IsUpper(r) && !unicode.IsUpper(rune(in[i-1])) { 303 buf.WriteRune('_') 304 } 305 buf.WriteRune(unicode.ToLower(r)) 306 } 307 308 return buf.String() 309 }