github.com/Ali-iotechsys/sqlboiler/v4@v4.0.0-20221208124957-6aec9a5f1f71/drivers/interface.go (about)

     1  // Package drivers talks to various database backends and retrieves table,
     2  // column, type, and foreign key information
     3  package drivers
     4  
     5  import (
     6  	"sort"
     7  	"sync"
     8  
     9  	"github.com/friendsofgo/errors"
    10  	"github.com/volatiletech/sqlboiler/v4/importers"
    11  	"github.com/volatiletech/strmangle"
    12  )
    13  
// These constants are the keys used in the config map passed into the driver.
const (
	// Keys for generation-related options.
	ConfigBlacklist      = "blacklist"
	ConfigWhitelist      = "whitelist"
	ConfigSchema         = "schema"
	ConfigAddEnumTypes   = "add-enum-types"
	ConfigEnumNullPrefix = "enum-null-prefix"
	ConfigConcurrency    = "concurrency"

	// Keys whose names correspond to common database connection parameters.
	ConfigUser    = "user"
	ConfigPass    = "pass"
	ConfigHost    = "host"
	ConfigPort    = "port"
	ConfigDBName  = "dbname"
	ConfigSSLMode = "sslmode"

	// DefaultConcurrency defines the default number of concurrent workers
	// to use when loading table info.
	DefaultConcurrency = 10
)
    33  
// Interface abstracts either a side-effect imported driver or a binary
// that is called in order to produce the data required for generation.
type Interface interface {
	// Assemble the database information into a nice struct.
	// It receives the driver configuration and returns the collected
	// DBInfo, or an error.
	Assemble(config Config) (*DBInfo, error)
	// Templates to add/replace for generation, returned as a
	// map[string]string.
	Templates() (map[string]string, error)
	// Imports to merge for generation.
	Imports() (importers.Collection, error)
}
    44  
// DBInfo is the database's table data and dialect.
// It is the value returned by Interface.Assemble and carries JSON tags for
// every field.
type DBInfo struct {
	// Schema is the schema name.
	Schema string `json:"schema"`
	// Tables holds the table metadata.
	Tables []Table `json:"tables"`
	// Dialect describes the SQL dialect of the database.
	Dialect Dialect `json:"dialect"`
}
    51  
// Dialect describes the databases requirements in terms of which features
// it speaks and what kind of quoting mechanisms it uses.
//
// WARNING: When updating this struct there is a copy of it inside
// the boil_queries template that is used for users to create queries
// without having to figure out what their dialect is.
type Dialect struct {
	// LQ and RQ are presumably the left and right identifier quote
	// characters of the dialect (TODO confirm against the drivers).
	LQ rune `json:"lq"`
	RQ rune `json:"rq"`

	// Feature flags describing which SQL constructs the dialect uses.
	UseIndexPlaceholders bool `json:"use_index_placeholders"`
	UseLastInsertID      bool `json:"use_last_insert_id"`
	UseSchema            bool `json:"use_schema"`
	UseDefaultKeyword    bool `json:"use_default_keyword"`

	// The following is mostly for T-SQL/MSSQL, what a show
	UseTopClause            bool `json:"use_top_clause"`
	UseOutputClause         bool `json:"use_output_clause"`
	UseCaseWhenExistsClause bool `json:"use_case_when_exists_clause"`

	// No longer used, left for backwards compatibility
	// should be removed in v5
	UseAutoColumns bool `json:"use_auto_columns"`
}
    76  
// Constructor breaks down the functionality required to implement a driver
// such that the drivers.Tables method can be used to reduce duplication in driver
// implementations.
type Constructor interface {
	// TableNames returns the names of the tables to load; it is handed the
	// schema together with the whitelist and blacklist.
	TableNames(schema string, whitelist, blacklist []string) ([]string, error)
	// Columns returns the columns of the given table; it is handed the
	// whitelist and blacklist as well.
	Columns(schema, tableName string, whitelist, blacklist []string) ([]Column, error)
	// PrimaryKeyInfo returns the primary key of the given table. The loader
	// in this package treats a nil *PrimaryKey as "no primary key".
	PrimaryKeyInfo(schema, tableName string) (*PrimaryKey, error)
	// ForeignKeyInfo returns the foreign keys of the given table.
	ForeignKeyInfo(schema, tableName string) ([]ForeignKey, error)

	// TranslateColumnType takes a Database column type and returns a go column type.
	TranslateColumnType(Column) Column
}
    89  
// ViewConstructor breaks down the functionality required to implement a driver
// such that the drivers.Views method can be used to reduce duplication in driver
// implementations.
//
// TablesConcurrently checks whether the Constructor it is given also
// implements ViewConstructor and, if so, appends the views to its result.
type ViewConstructor interface {
	// ViewNames returns the names of the views to load; it is handed the
	// schema together with the whitelist and blacklist.
	ViewNames(schema string, whitelist, blacklist []string) ([]string, error)
	// ViewCapabilities returns the capabilities of the given view.
	ViewCapabilities(schema, viewName string) (ViewCapabilities, error)
	// ViewColumns returns the columns of the given view; it is handed the
	// whitelist and blacklist as well.
	ViewColumns(schema, tableName string, whitelist, blacklist []string) ([]Column, error)

	// TranslateColumnType takes a Database column type and returns a go column type.
	TranslateColumnType(Column) Column
}
   101  
// TableColumnTypeTranslator is an optional interface a driver may implement
// in addition to Constructor or ViewConstructor. When it is implemented, the
// table and view loaders in this package call TranslateTableColumnType for
// every column instead of TranslateColumnType.
type TableColumnTypeTranslator interface {
	// TranslateTableColumnType takes a Database column type and table name and returns a go column type.
	TranslateTableColumnType(c Column, tableName string) Column
}
   106  
   107  // Tables returns the metadata for all tables, minus the tables
   108  // specified in the blacklist.
   109  func Tables(c Constructor, schema string, whitelist, blacklist []string) ([]Table, error) {
   110  	return TablesConcurrently(c, schema, whitelist, blacklist, 1)
   111  }
   112  
   113  // TablesConcurrently is a concurrent version of Tables. It returns the
   114  // metadata for all tables, minus the tables specified in the blacklist.
   115  func TablesConcurrently(c Constructor, schema string, whitelist, blacklist []string, concurrency int) ([]Table, error) {
   116  	var err error
   117  	var ret []Table
   118  
   119  	ret, err = tables(c, schema, whitelist, blacklist, concurrency)
   120  	if err != nil {
   121  		return nil, errors.Wrap(err, "unable to load tables")
   122  	}
   123  
   124  	if vc, ok := c.(ViewConstructor); ok {
   125  		v, err := views(vc, schema, whitelist, blacklist, concurrency)
   126  		if err != nil {
   127  			return nil, errors.Wrap(err, "unable to load views")
   128  		}
   129  		ret = append(ret, v...)
   130  	}
   131  
   132  	return ret, nil
   133  }
   134  
   135  func tables(c Constructor, schema string, whitelist, blacklist []string, concurrency int) ([]Table, error) {
   136  	var err error
   137  
   138  	names, err := c.TableNames(schema, whitelist, blacklist)
   139  	if err != nil {
   140  		return nil, errors.Wrap(err, "unable to get table names")
   141  	}
   142  
   143  	sort.Strings(names)
   144  
   145  	ret := make([]Table, len(names))
   146  
   147  	limiter := newConcurrencyLimiter(concurrency)
   148  	wg := sync.WaitGroup{}
   149  	errs := make(chan error, len(names))
   150  	for i, name := range names {
   151  		wg.Add(1)
   152  		limiter.get()
   153  		go func(i int, name string) {
   154  			defer wg.Done()
   155  			defer limiter.put()
   156  			t, err := table(c, schema, name, whitelist, blacklist)
   157  			if err != nil {
   158  				errs <- err
   159  				return
   160  			}
   161  			ret[i] = t
   162  		}(i, name)
   163  	}
   164  
   165  	wg.Wait()
   166  
   167  	// return first error occurred if any
   168  	if len(errs) > 0 {
   169  		return nil, <-errs
   170  	}
   171  
   172  	// Relationships have a dependency on foreign key nullability.
   173  	for i := range ret {
   174  		tbl := &ret[i]
   175  		setForeignKeyConstraints(tbl, ret)
   176  	}
   177  	for i := range ret {
   178  		tbl := &ret[i]
   179  		setRelationships(tbl, ret)
   180  	}
   181  
   182  	return ret, nil
   183  }
   184  
   185  // table returns columns info for a given table
   186  func table(c Constructor, schema string, name string, whitelist, blacklist []string) (Table, error) {
   187  	var err error
   188  	t := &Table{
   189  		Name: name,
   190  	}
   191  
   192  	if t.Columns, err = c.Columns(schema, name, whitelist, blacklist); err != nil {
   193  		return Table{}, errors.Wrapf(err, "unable to fetch table column info (%s)", name)
   194  	}
   195  
   196  	tr, ok := c.(TableColumnTypeTranslator)
   197  	if ok {
   198  		for i, col := range t.Columns {
   199  			t.Columns[i] = tr.TranslateTableColumnType(col, name)
   200  		}
   201  	} else {
   202  		for i, col := range t.Columns {
   203  			t.Columns[i] = c.TranslateColumnType(col)
   204  		}
   205  	}
   206  
   207  	if t.PKey, err = c.PrimaryKeyInfo(schema, name); err != nil {
   208  		return Table{}, errors.Wrapf(err, "unable to fetch table pkey info (%s)", name)
   209  	}
   210  
   211  	if t.FKeys, err = c.ForeignKeyInfo(schema, name); err != nil {
   212  		return Table{}, errors.Wrapf(err, "unable to fetch table fkey info (%s)", name)
   213  	}
   214  
   215  	filterPrimaryKey(t, whitelist, blacklist)
   216  	filterForeignKeys(t, whitelist, blacklist)
   217  
   218  	setIsJoinTable(t)
   219  
   220  	return *t, nil
   221  }
   222  
   223  // views returns the metadata for all views, minus the views
   224  // specified in the blacklist.
   225  func views(c ViewConstructor, schema string, whitelist, blacklist []string, concurrency int) ([]Table, error) {
   226  	var err error
   227  
   228  	names, err := c.ViewNames(schema, whitelist, blacklist)
   229  	if err != nil {
   230  		return nil, errors.Wrap(err, "unable to get view names")
   231  	}
   232  
   233  	sort.Strings(names)
   234  
   235  	ret := make([]Table, len(names))
   236  
   237  	limiter := newConcurrencyLimiter(concurrency)
   238  	wg := sync.WaitGroup{}
   239  	errs := make(chan error, len(names))
   240  	for i, name := range names {
   241  		wg.Add(1)
   242  		limiter.get()
   243  		go func(i int, name string) {
   244  			defer wg.Done()
   245  			defer limiter.put()
   246  			t, err := view(c, schema, name, whitelist, blacklist)
   247  			if err != nil {
   248  				errs <- err
   249  				return
   250  			}
   251  			ret[i] = t
   252  		}(i, name)
   253  	}
   254  
   255  	wg.Wait()
   256  
   257  	// return first error occurred if any
   258  	if len(errs) > 0 {
   259  		return nil, <-errs
   260  	}
   261  
   262  	return ret, nil
   263  }
   264  
   265  // view returns columns info for a given view
   266  func view(c ViewConstructor, schema string, name string, whitelist, blacklist []string) (Table, error) {
   267  	var err error
   268  	t := Table{
   269  		IsView: true,
   270  		Name:   name,
   271  	}
   272  
   273  	if t.ViewCapabilities, err = c.ViewCapabilities(schema, name); err != nil {
   274  		return Table{}, errors.Wrapf(err, "unable to fetch view capabilities info (%s)", name)
   275  	}
   276  
   277  	if t.Columns, err = c.ViewColumns(schema, name, whitelist, blacklist); err != nil {
   278  		return Table{}, errors.Wrapf(err, "unable to fetch view column info (%s)", name)
   279  	}
   280  
   281  	tr, ok := c.(TableColumnTypeTranslator)
   282  	if ok {
   283  		for i, col := range t.Columns {
   284  			t.Columns[i] = tr.TranslateTableColumnType(col, name)
   285  		}
   286  	} else {
   287  		for i, col := range t.Columns {
   288  			t.Columns[i] = c.TranslateColumnType(col)
   289  		}
   290  	}
   291  
   292  	return t, nil
   293  }
   294  
   295  func knownColumn(table string, column string, whitelist, blacklist []string) bool {
   296  	return (len(whitelist) == 0 ||
   297  		strmangle.SetInclude(table, whitelist) ||
   298  		strmangle.SetInclude(table+"."+column, whitelist) ||
   299  		strmangle.SetInclude("*."+column, whitelist)) &&
   300  		(len(blacklist) == 0 || (!strmangle.SetInclude(table, blacklist) &&
   301  			!strmangle.SetInclude(table+"."+column, blacklist) &&
   302  			!strmangle.SetInclude("*."+column, blacklist)))
   303  }
   304  
   305  // filterPrimaryKey filter columns from the primary key that are not in whitelist or in blacklist
   306  func filterPrimaryKey(t *Table, whitelist, blacklist []string) {
   307  	if t.PKey == nil {
   308  		return
   309  	}
   310  
   311  	pkeyColumns := make([]string, 0, len(t.PKey.Columns))
   312  	for _, c := range t.PKey.Columns {
   313  		if knownColumn(t.Name, c, whitelist, blacklist) {
   314  			pkeyColumns = append(pkeyColumns, c)
   315  		}
   316  	}
   317  	t.PKey.Columns = pkeyColumns
   318  }
   319  
   320  // filterForeignKeys filter FK whose ForeignTable is not in whitelist or in blacklist
   321  func filterForeignKeys(t *Table, whitelist, blacklist []string) {
   322  	var fkeys []ForeignKey
   323  
   324  	for _, fkey := range t.FKeys {
   325  		if knownColumn(fkey.ForeignTable, fkey.ForeignColumn, whitelist, blacklist) &&
   326  			knownColumn(fkey.Table, fkey.Column, whitelist, blacklist) {
   327  			fkeys = append(fkeys, fkey)
   328  		}
   329  	}
   330  	t.FKeys = fkeys
   331  }
   332  
   333  // setIsJoinTable if there are:
   334  // A composite primary key involving two columns
   335  // Both primary key columns are also foreign keys
   336  func setIsJoinTable(t *Table) {
   337  	if t.PKey == nil || len(t.PKey.Columns) != 2 || len(t.FKeys) < 2 || len(t.Columns) > 2 {
   338  		return
   339  	}
   340  
   341  	for _, c := range t.PKey.Columns {
   342  		found := false
   343  		for _, f := range t.FKeys {
   344  			if c == f.Column {
   345  				found = true
   346  				break
   347  			}
   348  		}
   349  		if !found {
   350  			return
   351  		}
   352  	}
   353  
   354  	t.IsJoinTable = true
   355  }
   356  
   357  func setForeignKeyConstraints(t *Table, tables []Table) {
   358  	for i, fkey := range t.FKeys {
   359  		localColumn := t.GetColumn(fkey.Column)
   360  		foreignTable := GetTable(tables, fkey.ForeignTable)
   361  		foreignColumn := foreignTable.GetColumn(fkey.ForeignColumn)
   362  
   363  		t.FKeys[i].Nullable = localColumn.Nullable
   364  		t.FKeys[i].Unique = localColumn.Unique
   365  		t.FKeys[i].ForeignColumnNullable = foreignColumn.Nullable
   366  		t.FKeys[i].ForeignColumnUnique = foreignColumn.Unique
   367  	}
   368  }
   369  
   370  func setRelationships(t *Table, tables []Table) {
   371  	t.ToOneRelationships = toOneRelationships(*t, tables)
   372  	t.ToManyRelationships = toManyRelationships(*t, tables)
   373  }
   374  
   375  // concurrencyCounter is a helper structure that can limit amount of concurrently processed requests
   376  type concurrencyLimiter chan struct{}
   377  
   378  func newConcurrencyLimiter(capacity int) concurrencyLimiter {
   379  	ret := make(concurrencyLimiter, capacity)
   380  	for i := 0; i < capacity; i++ {
   381  		ret <- struct{}{}
   382  	}
   383  
   384  	return ret
   385  }
   386  
   387  func (c concurrencyLimiter) get() {
   388  	<-c
   389  }
   390  
   391  func (c concurrencyLimiter) put() {
   392  	c <- struct{}{}
   393  }