github.com/Ali-iotechsys/sqlboiler/v4@v4.0.0-20221208124957-6aec9a5f1f71/drivers/interface.go (about) 1 // Package drivers talks to various database backends and retrieves table, 2 // column, type, and foreign key information 3 package drivers 4 5 import ( 6 "sort" 7 "sync" 8 9 "github.com/friendsofgo/errors" 10 "github.com/volatiletech/sqlboiler/v4/importers" 11 "github.com/volatiletech/strmangle" 12 ) 13 14 // These constants are used in the config map passed into the driver 15 const ( 16 ConfigBlacklist = "blacklist" 17 ConfigWhitelist = "whitelist" 18 ConfigSchema = "schema" 19 ConfigAddEnumTypes = "add-enum-types" 20 ConfigEnumNullPrefix = "enum-null-prefix" 21 ConfigConcurrency = "concurrency" 22 23 ConfigUser = "user" 24 ConfigPass = "pass" 25 ConfigHost = "host" 26 ConfigPort = "port" 27 ConfigDBName = "dbname" 28 ConfigSSLMode = "sslmode" 29 30 // DefaultConcurrency defines the default amount of threads to use when loading tables info 31 DefaultConcurrency = 10 32 ) 33 34 // Interface abstracts either a side-effect imported driver or a binary 35 // that is called in order to produce the data required for generation. 36 type Interface interface { 37 // Assemble the database information into a nice struct 38 Assemble(config Config) (*DBInfo, error) 39 // Templates to add/replace for generation 40 Templates() (map[string]string, error) 41 // Imports to merge for generation 42 Imports() (importers.Collection, error) 43 } 44 45 // DBInfo is the database's table data and dialect. 46 type DBInfo struct { 47 Schema string `json:"schema"` 48 Tables []Table `json:"tables"` 49 Dialect Dialect `json:"dialect"` 50 } 51 52 // Dialect describes the databases requirements in terms of which features 53 // it speaks and what kind of quoting mechanisms it uses. 54 // 55 // WARNING: When updating this struct there is a copy of it inside 56 // the boil_queries template that is used for users to create queries 57 // without having to figure out what their dialect is. 58 type Dialect struct { 59 LQ rune `json:"lq"` 60 RQ rune `json:"rq"` 61 62 UseIndexPlaceholders bool `json:"use_index_placeholders"` 63 UseLastInsertID bool `json:"use_last_insert_id"` 64 UseSchema bool `json:"use_schema"` 65 UseDefaultKeyword bool `json:"use_default_keyword"` 66 67 // The following is mostly for T-SQL/MSSQL, what a show 68 UseTopClause bool `json:"use_top_clause"` 69 UseOutputClause bool `json:"use_output_clause"` 70 UseCaseWhenExistsClause bool `json:"use_case_when_exists_clause"` 71 72 // No longer used, left for backwards compatibility 73 // should be removed in v5 74 UseAutoColumns bool `json:"use_auto_columns"` 75 } 76 77 // Constructor breaks down the functionality required to implement a driver 78 // such that the drivers.Tables method can be used to reduce duplication in driver 79 // implementations. 80 type Constructor interface { 81 TableNames(schema string, whitelist, blacklist []string) ([]string, error) 82 Columns(schema, tableName string, whitelist, blacklist []string) ([]Column, error) 83 PrimaryKeyInfo(schema, tableName string) (*PrimaryKey, error) 84 ForeignKeyInfo(schema, tableName string) ([]ForeignKey, error) 85 86 // TranslateColumnType takes a Database column type and returns a go column type. 87 TranslateColumnType(Column) Column 88 } 89 90 // Constructor breaks down the functionality required to implement a driver 91 // such that the drivers.Views method can be used to reduce duplication in driver 92 // implementations. 93 type ViewConstructor interface { 94 ViewNames(schema string, whitelist, blacklist []string) ([]string, error) 95 ViewCapabilities(schema, viewName string) (ViewCapabilities, error) 96 ViewColumns(schema, tableName string, whitelist, blacklist []string) ([]Column, error) 97 98 // TranslateColumnType takes a Database column type and returns a go column type. 99 TranslateColumnType(Column) Column 100 } 101 102 type TableColumnTypeTranslator interface { 103 // TranslateTableColumnType takes a Database column type and table name and returns a go column type. 104 TranslateTableColumnType(c Column, tableName string) Column 105 } 106 107 // Tables returns the metadata for all tables, minus the tables 108 // specified in the blacklist. 109 func Tables(c Constructor, schema string, whitelist, blacklist []string) ([]Table, error) { 110 return TablesConcurrently(c, schema, whitelist, blacklist, 1) 111 } 112 113 // TablesConcurrently is a concurrent version of Tables. It returns the 114 // metadata for all tables, minus the tables specified in the blacklist. 115 func TablesConcurrently(c Constructor, schema string, whitelist, blacklist []string, concurrency int) ([]Table, error) { 116 var err error 117 var ret []Table 118 119 ret, err = tables(c, schema, whitelist, blacklist, concurrency) 120 if err != nil { 121 return nil, errors.Wrap(err, "unable to load tables") 122 } 123 124 if vc, ok := c.(ViewConstructor); ok { 125 v, err := views(vc, schema, whitelist, blacklist, concurrency) 126 if err != nil { 127 return nil, errors.Wrap(err, "unable to load views") 128 } 129 ret = append(ret, v...) 130 } 131 132 return ret, nil 133 } 134 135 func tables(c Constructor, schema string, whitelist, blacklist []string, concurrency int) ([]Table, error) { 136 var err error 137 138 names, err := c.TableNames(schema, whitelist, blacklist) 139 if err != nil { 140 return nil, errors.Wrap(err, "unable to get table names") 141 } 142 143 sort.Strings(names) 144 145 ret := make([]Table, len(names)) 146 147 limiter := newConcurrencyLimiter(concurrency) 148 wg := sync.WaitGroup{} 149 errs := make(chan error, len(names)) 150 for i, name := range names { 151 wg.Add(1) 152 limiter.get() 153 go func(i int, name string) { 154 defer wg.Done() 155 defer limiter.put() 156 t, err := table(c, schema, name, whitelist, blacklist) 157 if err != nil { 158 errs <- err 159 return 160 } 161 ret[i] = t 162 }(i, name) 163 } 164 165 wg.Wait() 166 167 // return first error occurred if any 168 if len(errs) > 0 { 169 return nil, <-errs 170 } 171 172 // Relationships have a dependency on foreign key nullability. 173 for i := range ret { 174 tbl := &ret[i] 175 setForeignKeyConstraints(tbl, ret) 176 } 177 for i := range ret { 178 tbl := &ret[i] 179 setRelationships(tbl, ret) 180 } 181 182 return ret, nil 183 } 184 185 // table returns columns info for a given table 186 func table(c Constructor, schema string, name string, whitelist, blacklist []string) (Table, error) { 187 var err error 188 t := &Table{ 189 Name: name, 190 } 191 192 if t.Columns, err = c.Columns(schema, name, whitelist, blacklist); err != nil { 193 return Table{}, errors.Wrapf(err, "unable to fetch table column info (%s)", name) 194 } 195 196 tr, ok := c.(TableColumnTypeTranslator) 197 if ok { 198 for i, col := range t.Columns { 199 t.Columns[i] = tr.TranslateTableColumnType(col, name) 200 } 201 } else { 202 for i, col := range t.Columns { 203 t.Columns[i] = c.TranslateColumnType(col) 204 } 205 } 206 207 if t.PKey, err = c.PrimaryKeyInfo(schema, name); err != nil { 208 return Table{}, errors.Wrapf(err, "unable to fetch table pkey info (%s)", name) 209 } 210 211 if t.FKeys, err = c.ForeignKeyInfo(schema, name); err != nil { 212 return Table{}, errors.Wrapf(err, "unable to fetch table fkey info (%s)", name) 213 } 214 215 filterPrimaryKey(t, whitelist, blacklist) 216 filterForeignKeys(t, whitelist, blacklist) 217 218 setIsJoinTable(t) 219 220 return *t, nil 221 } 222 223 // views returns the metadata for all views, minus the views 224 // specified in the blacklist. 225 func views(c ViewConstructor, schema string, whitelist, blacklist []string, concurrency int) ([]Table, error) { 226 var err error 227 228 names, err := c.ViewNames(schema, whitelist, blacklist) 229 if err != nil { 230 return nil, errors.Wrap(err, "unable to get view names") 231 } 232 233 sort.Strings(names) 234 235 ret := make([]Table, len(names)) 236 237 limiter := newConcurrencyLimiter(concurrency) 238 wg := sync.WaitGroup{} 239 errs := make(chan error, len(names)) 240 for i, name := range names { 241 wg.Add(1) 242 limiter.get() 243 go func(i int, name string) { 244 defer wg.Done() 245 defer limiter.put() 246 t, err := view(c, schema, name, whitelist, blacklist) 247 if err != nil { 248 errs <- err 249 return 250 } 251 ret[i] = t 252 }(i, name) 253 } 254 255 wg.Wait() 256 257 // return first error occurred if any 258 if len(errs) > 0 { 259 return nil, <-errs 260 } 261 262 return ret, nil 263 } 264 265 // view returns columns info for a given view 266 func view(c ViewConstructor, schema string, name string, whitelist, blacklist []string) (Table, error) { 267 var err error 268 t := Table{ 269 IsView: true, 270 Name: name, 271 } 272 273 if t.ViewCapabilities, err = c.ViewCapabilities(schema, name); err != nil { 274 return Table{}, errors.Wrapf(err, "unable to fetch view capabilities info (%s)", name) 275 } 276 277 if t.Columns, err = c.ViewColumns(schema, name, whitelist, blacklist); err != nil { 278 return Table{}, errors.Wrapf(err, "unable to fetch view column info (%s)", name) 279 } 280 281 tr, ok := c.(TableColumnTypeTranslator) 282 if ok { 283 for i, col := range t.Columns { 284 t.Columns[i] = tr.TranslateTableColumnType(col, name) 285 } 286 } else { 287 for i, col := range t.Columns { 288 t.Columns[i] = c.TranslateColumnType(col) 289 } 290 } 291 292 return t, nil 293 } 294 295 func knownColumn(table string, column string, whitelist, blacklist []string) bool { 296 return (len(whitelist) == 0 || 297 strmangle.SetInclude(table, whitelist) || 298 strmangle.SetInclude(table+"."+column, whitelist) || 299 strmangle.SetInclude("*."+column, whitelist)) && 300 (len(blacklist) == 0 || (!strmangle.SetInclude(table, blacklist) && 301 !strmangle.SetInclude(table+"."+column, blacklist) && 302 !strmangle.SetInclude("*."+column, blacklist))) 303 } 304 305 // filterPrimaryKey filter columns from the primary key that are not in whitelist or in blacklist 306 func filterPrimaryKey(t *Table, whitelist, blacklist []string) { 307 if t.PKey == nil { 308 return 309 } 310 311 pkeyColumns := make([]string, 0, len(t.PKey.Columns)) 312 for _, c := range t.PKey.Columns { 313 if knownColumn(t.Name, c, whitelist, blacklist) { 314 pkeyColumns = append(pkeyColumns, c) 315 } 316 } 317 t.PKey.Columns = pkeyColumns 318 } 319 320 // filterForeignKeys filter FK whose ForeignTable is not in whitelist or in blacklist 321 func filterForeignKeys(t *Table, whitelist, blacklist []string) { 322 var fkeys []ForeignKey 323 324 for _, fkey := range t.FKeys { 325 if knownColumn(fkey.ForeignTable, fkey.ForeignColumn, whitelist, blacklist) && 326 knownColumn(fkey.Table, fkey.Column, whitelist, blacklist) { 327 fkeys = append(fkeys, fkey) 328 } 329 } 330 t.FKeys = fkeys 331 } 332 333 // setIsJoinTable if there are: 334 // A composite primary key involving two columns 335 // Both primary key columns are also foreign keys 336 func setIsJoinTable(t *Table) { 337 if t.PKey == nil || len(t.PKey.Columns) != 2 || len(t.FKeys) < 2 || len(t.Columns) > 2 { 338 return 339 } 340 341 for _, c := range t.PKey.Columns { 342 found := false 343 for _, f := range t.FKeys { 344 if c == f.Column { 345 found = true 346 break 347 } 348 } 349 if !found { 350 return 351 } 352 } 353 354 t.IsJoinTable = true 355 } 356 357 func setForeignKeyConstraints(t *Table, tables []Table) { 358 for i, fkey := range t.FKeys { 359 localColumn := t.GetColumn(fkey.Column) 360 foreignTable := GetTable(tables, fkey.ForeignTable) 361 foreignColumn := foreignTable.GetColumn(fkey.ForeignColumn) 362 363 t.FKeys[i].Nullable = localColumn.Nullable 364 t.FKeys[i].Unique = localColumn.Unique 365 t.FKeys[i].ForeignColumnNullable = foreignColumn.Nullable 366 t.FKeys[i].ForeignColumnUnique = foreignColumn.Unique 367 } 368 } 369 370 func setRelationships(t *Table, tables []Table) { 371 t.ToOneRelationships = toOneRelationships(*t, tables) 372 t.ToManyRelationships = toManyRelationships(*t, tables) 373 } 374 375 // concurrencyCounter is a helper structure that can limit amount of concurrently processed requests 376 type concurrencyLimiter chan struct{} 377 378 func newConcurrencyLimiter(capacity int) concurrencyLimiter { 379 ret := make(concurrencyLimiter, capacity) 380 for i := 0; i < capacity; i++ { 381 ret <- struct{}{} 382 } 383 384 return ret 385 } 386 387 func (c concurrencyLimiter) get() { 388 <-c 389 } 390 391 func (c concurrencyLimiter) put() { 392 c <- struct{}{} 393 }