github.com/Ali-iotechsys/sqlboiler/v4@v4.0.0-20221208124957-6aec9a5f1f71/drivers/sqlboiler-mssql/driver/mssql.go (about)

     1  package driver
     2  
     3  import (
     4  	"database/sql"
     5  	"embed"
     6  	"encoding/base64"
     7  	"fmt"
     8  	"io/fs"
     9  	"net/url"
    10  	"strings"
    11  
    12  	// Side effect import go-mssqldb
    13  	"github.com/friendsofgo/errors"
    14  	_ "github.com/microsoft/go-mssqldb"
    15  	"github.com/volatiletech/sqlboiler/v4/drivers"
    16  	"github.com/volatiletech/sqlboiler/v4/importers"
    17  	"github.com/volatiletech/strmangle"
    18  )
    19  
// templates holds the files under the driver's "override" directory,
// embedded at build time; Templates walks it to produce the driver's
// template overrides.
//go:embed override
var templates embed.FS
    22  
    23  func init() {
    24  	drivers.RegisterFromInit("mssql", &MSSQLDriver{})
    25  }
    26  
    27  // Assemble is more useful for calling into the library so you don't
    28  // have to instantiate an empty type.
    29  func Assemble(config drivers.Config) (dbinfo *drivers.DBInfo, err error) {
    30  	driver := MSSQLDriver{}
    31  	return driver.Assemble(config)
    32  }
    33  
// MSSQLDriver holds the database connection string and a handle
// to the database connection.
type MSSQLDriver struct {
	connStr string  // URL connection string built by Assemble via MSSQLBuildQueryString
	conn    *sql.DB // handle opened in Assemble and used by the metadata query methods
}
    40  
    41  // Templates that should be added/overridden
    42  func (MSSQLDriver) Templates() (map[string]string, error) {
    43  	tpls := make(map[string]string)
    44  	fs.WalkDir(templates, "override", func(path string, d fs.DirEntry, err error) error {
    45  		if err != nil {
    46  			return err
    47  		}
    48  
    49  		if d.IsDir() {
    50  			return nil
    51  		}
    52  
    53  		b, err := fs.ReadFile(templates, path)
    54  		if err != nil {
    55  			return err
    56  		}
    57  		tpls[strings.Replace(path, "override/", "", 1)] = base64.StdEncoding.EncodeToString(b)
    58  
    59  		return nil
    60  	})
    61  
    62  	return tpls, nil
    63  }
    64  
    65  // Assemble all the information we need to provide back to the driver
    66  func (m *MSSQLDriver) Assemble(config drivers.Config) (dbinfo *drivers.DBInfo, err error) {
    67  	defer func() {
    68  		if r := recover(); r != nil && err == nil {
    69  			dbinfo = nil
    70  			err = r.(error)
    71  		}
    72  	}()
    73  
    74  	user := config.MustString(drivers.ConfigUser)
    75  	pass, _ := config.String(drivers.ConfigPass)
    76  	dbname := config.MustString(drivers.ConfigDBName)
    77  	host := config.MustString(drivers.ConfigHost)
    78  	port := config.DefaultInt(drivers.ConfigPort, 1433)
    79  	sslmode := config.DefaultString(drivers.ConfigSSLMode, "true")
    80  
    81  	schema := config.DefaultString(drivers.ConfigSchema, "dbo")
    82  	whitelist, _ := config.StringSlice(drivers.ConfigWhitelist)
    83  	blacklist, _ := config.StringSlice(drivers.ConfigBlacklist)
    84  	concurrency := config.DefaultInt(drivers.ConfigConcurrency, drivers.DefaultConcurrency)
    85  
    86  	m.connStr = MSSQLBuildQueryString(user, pass, dbname, host, port, sslmode)
    87  	m.conn, err = sql.Open("mssql", m.connStr)
    88  	if err != nil {
    89  		return nil, errors.Wrap(err, "sqlboiler-mssql failed to connect to database")
    90  	}
    91  
    92  	defer func() {
    93  		if e := m.conn.Close(); e != nil {
    94  			dbinfo = nil
    95  			err = e
    96  		}
    97  	}()
    98  
    99  	dbinfo = &drivers.DBInfo{
   100  		Schema: schema,
   101  		Dialect: drivers.Dialect{
   102  			LQ: '[',
   103  			RQ: ']',
   104  
   105  			UseIndexPlaceholders: true,
   106  			UseSchema:            true,
   107  			UseDefaultKeyword:    true,
   108  
   109  			UseTopClause:            true,
   110  			UseOutputClause:         true,
   111  			UseCaseWhenExistsClause: true,
   112  		},
   113  	}
   114  	dbinfo.Tables, err = drivers.TablesConcurrently(m, schema, whitelist, blacklist, concurrency)
   115  	if err != nil {
   116  		return nil, err
   117  	}
   118  
   119  	return dbinfo, err
   120  }
   121  
   122  // MSSQLBuildQueryString builds a query string for MSSQL.
   123  func MSSQLBuildQueryString(user, pass, dbname, host string, port int, sslmode string) string {
   124  	query := url.Values{}
   125  	query.Add("database", dbname)
   126  	query.Add("encrypt", sslmode)
   127  
   128  	u := &url.URL{
   129  		Scheme:   "sqlserver",
   130  		User:     url.UserPassword(user, pass),
   131  		Host:     fmt.Sprintf("%s:%d", host, port),
   132  		RawQuery: query.Encode(),
   133  	}
   134  
   135  	// If the host is an "sqlserver instance" then we set the Path not the Host
   136  	// so the url package doesn't escape the /
   137  	if strings.Contains(host, "/") {
   138  		u.Path = host
   139  		u.Host = ""
   140  	}
   141  
   142  	return u.String()
   143  }
   144  
   145  // TableNames connects to the postgres database and
   146  // retrieves all table names from the information_schema where the
   147  // table schema is schema. It uses a whitelist and blacklist.
   148  func (m *MSSQLDriver) TableNames(schema string, whitelist, blacklist []string) ([]string, error) {
   149  	var names []string
   150  
   151  	query := `
   152  		SELECT table_name
   153  		FROM   information_schema.tables
   154  		WHERE  table_schema = ? AND table_type = 'BASE TABLE'`
   155  
   156  	args := []interface{}{schema}
   157  	if len(whitelist) > 0 {
   158  		tables := drivers.TablesFromList(whitelist)
   159  		if len(tables) > 0 {
   160  			query += fmt.Sprintf(" AND table_name IN (%s)", strings.Repeat(",?", len(tables))[1:])
   161  			for _, w := range tables {
   162  				args = append(args, w)
   163  			}
   164  		}
   165  	} else if len(blacklist) > 0 {
   166  		tables := drivers.TablesFromList(blacklist)
   167  		if len(tables) > 0 {
   168  			query += fmt.Sprintf(" AND table_name not IN (%s)", strings.Repeat(",?", len(tables))[1:])
   169  			for _, b := range tables {
   170  				args = append(args, b)
   171  			}
   172  		}
   173  	}
   174  
   175  	query += ` ORDER BY table_name;`
   176  
   177  	rows, err := m.conn.Query(query, args...)
   178  
   179  	if err != nil {
   180  		return nil, err
   181  	}
   182  
   183  	defer rows.Close()
   184  	for rows.Next() {
   185  		var name string
   186  		if err := rows.Scan(&name); err != nil {
   187  			return nil, err
   188  		}
   189  		names = append(names, name)
   190  	}
   191  
   192  	return names, nil
   193  }
   194  
   195  // ViewNames connects to the postgres database and
   196  // retrieves all view names from the information_schema where the
   197  // view schema is schema. It uses a whitelist and blacklist.
   198  func (m *MSSQLDriver) ViewNames(schema string, whitelist, blacklist []string) ([]string, error) {
   199  	var names []string
   200  
   201  	query := `select table_name from information_schema.views where table_schema = ?`
   202  	args := []interface{}{schema}
   203  	if len(whitelist) > 0 {
   204  		tables := drivers.TablesFromList(whitelist)
   205  		if len(tables) > 0 {
   206  			query += fmt.Sprintf(" and table_name in (%s)", strings.Repeat(",?", len(tables))[1:])
   207  			for _, w := range tables {
   208  				args = append(args, w)
   209  			}
   210  		}
   211  	} else if len(blacklist) > 0 {
   212  		tables := drivers.TablesFromList(blacklist)
   213  		if len(tables) > 0 {
   214  			query += fmt.Sprintf(" and table_name not in (%s)", strings.Repeat(",?", len(tables))[1:])
   215  			for _, b := range tables {
   216  				args = append(args, b)
   217  			}
   218  		}
   219  	}
   220  
   221  	query += ` order by table_name;`
   222  
   223  	rows, err := m.conn.Query(query, args...)
   224  
   225  	if err != nil {
   226  		return nil, err
   227  	}
   228  
   229  	defer rows.Close()
   230  	for rows.Next() {
   231  		var name string
   232  		if err := rows.Scan(&name); err != nil {
   233  			return nil, err
   234  		}
   235  
   236  		names = append(names, name)
   237  	}
   238  
   239  	return names, nil
   240  }
   241  
   242  // ViewCapabilities return what actions are allowed for a view.
   243  func (m *MSSQLDriver) ViewCapabilities(schema, name string) (drivers.ViewCapabilities, error) {
   244  	// This depends on the specific query and is not possible to ensure
   245  	// from just the schema
   246  	capabilities := drivers.ViewCapabilities{
   247  		CanInsert: false,
   248  		CanUpsert: false,
   249  	}
   250  
   251  	return capabilities, nil
   252  }
   253  
   254  func (m *MSSQLDriver) ViewColumns(schema, tableName string, whitelist, blacklist []string) ([]drivers.Column, error) {
   255  	return m.Columns(schema, tableName, whitelist, blacklist)
   256  }
   257  
   258  // Columns takes a table name and attempts to retrieve the table information
   259  // from the database information_schema.columns. It retrieves the column names
   260  // and column types and returns those as a []Column after TranslateColumnType()
   261  // converts the SQL types to Go types, for example: "varchar" to "string"
   262  func (m *MSSQLDriver) Columns(schema, tableName string, whitelist, blacklist []string) ([]drivers.Column, error) {
   263  	var columns []drivers.Column
   264  	args := []interface{}{schema, tableName}
   265  	query := `
   266  	SELECT column_name,
   267         CASE
   268           WHEN character_maximum_length IS NULL THEN data_type
   269           ELSE data_type + '(' + CAST(character_maximum_length AS VARCHAR) + ')'
   270         END AS full_type,
   271         data_type,
   272  	   column_default,
   273         CASE
   274           WHEN is_nullable = 'YES' THEN 1
   275           ELSE 0
   276         END AS is_nullable,
   277         CASE
   278           WHEN EXISTS (SELECT c.column_name
   279                        FROM information_schema.table_constraints tc
   280                          INNER JOIN information_schema.key_column_usage kcu
   281                                  ON tc.constraint_name = kcu.constraint_name
   282                                 AND tc.table_name = kcu.table_name
   283                                 AND tc.table_schema = kcu.table_schema
   284                        WHERE c.column_name = kcu.column_name
   285                        AND   tc.table_name = c.table_name
   286                        AND   (tc.constraint_type = 'PRIMARY KEY' OR tc.constraint_type = 'UNIQUE')
   287                        AND   (SELECT COUNT(*)
   288                               FROM information_schema.key_column_usage
   289                               WHERE table_schema = kcu.table_schema
   290                               AND   table_name = tc.table_name
   291                               AND   constraint_name = tc.constraint_name) = 1) THEN 1
   292           ELSE 0
   293         END AS is_unique,
   294  	   COLUMNPROPERTY(object_id($1 + '.' + $2), c.column_name, 'IsIdentity') as is_identity,
   295  	   COLUMNPROPERTY(object_id($1 + '.' + $2), c.column_name, 'IsComputed') as is_computed
   296  	FROM information_schema.columns c
   297  	WHERE table_schema = $1 AND table_name = $2`
   298  
   299  	if len(whitelist) > 0 {
   300  		cols := drivers.ColumnsFromList(whitelist, tableName)
   301  		if len(cols) > 0 {
   302  			query += fmt.Sprintf(" and c.column_name in (%s)", strmangle.Placeholders(true, len(cols), 3, 1))
   303  			for _, w := range cols {
   304  				args = append(args, w)
   305  			}
   306  		}
   307  	} else if len(blacklist) > 0 {
   308  		cols := drivers.ColumnsFromList(blacklist, tableName)
   309  		if len(cols) > 0 {
   310  			query += fmt.Sprintf(" and c.column_name not in (%s)", strmangle.Placeholders(true, len(cols), 3, 1))
   311  			for _, w := range cols {
   312  				args = append(args, w)
   313  			}
   314  		}
   315  	}
   316  
   317  	query += ` ORDER BY ordinal_position;`
   318  
   319  	rows, err := m.conn.Query(query, args...)
   320  	if err != nil {
   321  		return nil, err
   322  	}
   323  	defer rows.Close()
   324  
   325  	for rows.Next() {
   326  		var colName, colType, colFullType string
   327  		var nullable, unique, identity, computed bool
   328  		var defaultValue *string
   329  		if err := rows.Scan(&colName, &colFullType, &colType, &defaultValue, &nullable, &unique, &identity, &computed); err != nil {
   330  			return nil, errors.Wrapf(err, "unable to scan for table %s", tableName)
   331  		}
   332  
   333  		computed = computed || strings.EqualFold(colType, "timestamp") || strings.EqualFold(colType, "rowversion")
   334  
   335  		column := drivers.Column{
   336  			Name:          colName,
   337  			FullDBType:    colFullType,
   338  			DBType:        colType,
   339  			Nullable:      nullable,
   340  			Unique:        unique,
   341  			AutoGenerated: computed || identity,
   342  		}
   343  
   344  		if defaultValue != nil {
   345  			column.Default = *defaultValue
   346  		}
   347  
   348  		// A generated column technically has a default value
   349  		if column.Default == "" && column.AutoGenerated {
   350  			column.Default = "AUTO_GENERATED"
   351  		}
   352  
   353  		columns = append(columns, column)
   354  	}
   355  
   356  	return columns, nil
   357  }
   358  
   359  // PrimaryKeyInfo looks up the primary key for a table.
   360  func (m *MSSQLDriver) PrimaryKeyInfo(schema, tableName string) (*drivers.PrimaryKey, error) {
   361  	pkey := &drivers.PrimaryKey{}
   362  	var err error
   363  
   364  	query := `
   365  	SELECT constraint_name
   366  	FROM   information_schema.table_constraints
   367  	WHERE  table_name = ? AND constraint_type = 'PRIMARY KEY' AND table_schema = ?;`
   368  
   369  	row := m.conn.QueryRow(query, tableName, schema)
   370  	if err = row.Scan(&pkey.Name); err != nil {
   371  		if errors.Is(err, sql.ErrNoRows) {
   372  			return nil, nil
   373  		}
   374  		return nil, err
   375  	}
   376  
   377  	queryColumns := `
   378  	SELECT column_name
   379  	FROM   information_schema.key_column_usage
   380  	WHERE  table_name = ? AND constraint_name = ? AND table_schema = ?
   381  	ORDER BY ordinal_position;`
   382  
   383  	var rows *sql.Rows
   384  	if rows, err = m.conn.Query(queryColumns, tableName, pkey.Name, schema); err != nil {
   385  		return nil, err
   386  	}
   387  	defer rows.Close()
   388  
   389  	var columns []string
   390  	for rows.Next() {
   391  		var column string
   392  
   393  		err = rows.Scan(&column)
   394  		if err != nil {
   395  			return nil, err
   396  		}
   397  
   398  		columns = append(columns, column)
   399  	}
   400  
   401  	if err = rows.Err(); err != nil {
   402  		return nil, err
   403  	}
   404  
   405  	pkey.Columns = columns
   406  
   407  	return pkey, nil
   408  }
   409  
   410  // ForeignKeyInfo retrieves the foreign keys for a given table name.
   411  func (m *MSSQLDriver) ForeignKeyInfo(schema, tableName string) ([]drivers.ForeignKey, error) {
   412  	var fkeys []drivers.ForeignKey
   413  
   414  	query := `
   415  	SELECT ccu.constraint_name ,
   416  		ccu.table_name AS local_table ,
   417  		ccu.column_name AS local_column ,
   418  		kcu.table_name AS foreign_table ,
   419  		kcu.column_name AS foreign_column
   420  	FROM information_schema.constraint_column_usage ccu
   421  	INNER JOIN information_schema.referential_constraints rc ON ccu.constraint_name = rc.constraint_name
   422  	INNER JOIN information_schema.key_column_usage kcu ON kcu.constraint_name = rc.unique_constraint_name
   423  	WHERE ccu.table_schema = ?
   424  	  AND ccu.constraint_schema = ?
   425  	  AND ccu.table_name = ?
   426  	ORDER BY ccu.constraint_name, local_table, local_column, foreign_table, foreign_column
   427  	`
   428  
   429  	var rows *sql.Rows
   430  	var err error
   431  	if rows, err = m.conn.Query(query, schema, schema, tableName); err != nil {
   432  		return nil, err
   433  	}
   434  
   435  	for rows.Next() {
   436  		var fkey drivers.ForeignKey
   437  		var sourceTable string
   438  
   439  		fkey.Table = tableName
   440  		err = rows.Scan(&fkey.Name, &sourceTable, &fkey.Column, &fkey.ForeignTable, &fkey.ForeignColumn)
   441  		if err != nil {
   442  			return nil, err
   443  		}
   444  
   445  		fkeys = append(fkeys, fkey)
   446  	}
   447  
   448  	if err = rows.Err(); err != nil {
   449  		return nil, err
   450  	}
   451  
   452  	return fkeys, nil
   453  }
   454  
   455  // TranslateColumnType converts postgres database types to Go types, for example
   456  // "varchar" to "string" and "bigint" to "int64". It returns this parsed data
   457  // as a Column object.
   458  func (m *MSSQLDriver) TranslateColumnType(c drivers.Column) drivers.Column {
   459  	if c.Nullable {
   460  		switch c.DBType {
   461  		case "tinyint":
   462  			c.Type = "null.Int8"
   463  		case "smallint":
   464  			c.Type = "null.Int16"
   465  		case "mediumint":
   466  			c.Type = "null.Int32"
   467  		case "int":
   468  			c.Type = "null.Int"
   469  		case "bigint":
   470  			c.Type = "null.Int64"
   471  		case "real":
   472  			c.Type = "null.Float32"
   473  		case "float":
   474  			c.Type = "null.Float64"
   475  		case "boolean", "bool", "bit":
   476  			c.Type = "null.Bool"
   477  		case "date", "datetime", "datetime2", "datetimeoffset", "smalldatetime", "time":
   478  			c.Type = "null.Time"
   479  		case "binary", "varbinary":
   480  			c.Type = "null.Bytes"
   481  		case "timestamp", "rowversion":
   482  			c.Type = "null.Bytes"
   483  		case "xml":
   484  			c.Type = "null.String"
   485  		case "uniqueidentifier":
   486  			c.Type = "mssql.UniqueIdentifier"
   487  			c.DBType = "uuid"
   488  		case "numeric", "decimal", "dec":
   489  			c.Type = "types.NullDecimal"
   490  		default:
   491  			c.Type = "null.String"
   492  		}
   493  	} else {
   494  		switch c.DBType {
   495  		case "tinyint":
   496  			c.Type = "int8"
   497  		case "smallint":
   498  			c.Type = "int16"
   499  		case "mediumint":
   500  			c.Type = "int32"
   501  		case "int":
   502  			c.Type = "int"
   503  		case "bigint":
   504  			c.Type = "int64"
   505  		case "real":
   506  			c.Type = "float32"
   507  		case "float":
   508  			c.Type = "float64"
   509  		case "boolean", "bool", "bit":
   510  			c.Type = "bool"
   511  		case "date", "datetime", "datetime2", "datetimeoffset", "smalldatetime", "time":
   512  			c.Type = "time.Time"
   513  		case "binary", "varbinary":
   514  			c.Type = "[]byte"
   515  		case "timestamp", "rowversion":
   516  			c.Type = "[]byte"
   517  		case "xml":
   518  			c.Type = "string"
   519  		case "uniqueidentifier":
   520  			c.Type = "mssql.UniqueIdentifier"
   521  			c.DBType = "uuid"
   522  		case "numeric", "decimal", "dec":
   523  			c.Type = "types.Decimal"
   524  		default:
   525  			c.Type = "string"
   526  		}
   527  	}
   528  
   529  	return c
   530  }
   531  
   532  // Imports returns important imports for the driver
   533  func (MSSQLDriver) Imports() (col importers.Collection, err error) {
   534  	col.All = importers.Set{
   535  		Standard: importers.List{
   536  			`"strconv"`,
   537  		},
   538  	}
   539  	col.Singleton = importers.Map{
   540  		"mssql_upsert": {
   541  			Standard: importers.List{
   542  				`"fmt"`,
   543  				`"strings"`,
   544  			},
   545  			ThirdParty: importers.List{
   546  				`"github.com/volatiletech/strmangle"`,
   547  				`"github.com/volatiletech/sqlboiler/v4/drivers"`,
   548  			},
   549  		},
   550  	}
   551  	col.TestSingleton = importers.Map{
   552  		"mssql_suites_test": {
   553  			Standard: importers.List{
   554  				`"testing"`,
   555  			},
   556  		},
   557  		"mssql_main_test": {
   558  			Standard: importers.List{
   559  				`"bytes"`,
   560  				`"database/sql"`,
   561  				`"fmt"`,
   562  				`"os"`,
   563  				`"os/exec"`,
   564  				`"regexp"`,
   565  				`"strings"`,
   566  			},
   567  			ThirdParty: importers.List{
   568  				`"github.com/kat-co/vala"`,
   569  				`"github.com/friendsofgo/errors"`,
   570  				`"github.com/spf13/viper"`,
   571  				`"github.com/volatiletech/sqlboiler/v4/drivers/sqlboiler-mssql/driver"`,
   572  				`"github.com/volatiletech/randomize"`,
   573  				`_ "github.com/microsoft/go-mssqldb"`,
   574  			},
   575  		},
   576  	}
   577  
   578  	col.BasedOnType = importers.Map{
   579  		"null.Float32": {
   580  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
   581  		},
   582  		"null.Float64": {
   583  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
   584  		},
   585  		"null.Int": {
   586  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
   587  		},
   588  		"null.Int8": {
   589  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
   590  		},
   591  		"null.Int16": {
   592  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
   593  		},
   594  		"null.Int32": {
   595  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
   596  		},
   597  		"null.Int64": {
   598  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
   599  		},
   600  		"null.Uint": {
   601  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
   602  		},
   603  		"null.Uint8": {
   604  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
   605  		},
   606  		"null.Uint16": {
   607  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
   608  		},
   609  		"null.Uint32": {
   610  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
   611  		},
   612  		"null.Uint64": {
   613  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
   614  		},
   615  		"null.String": {
   616  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
   617  		},
   618  		"null.Bool": {
   619  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
   620  		},
   621  		"null.Time": {
   622  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
   623  		},
   624  		"null.Bytes": {
   625  			ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`},
   626  		},
   627  		"time.Time": {
   628  			Standard: importers.List{`"time"`},
   629  		},
   630  		"types.Decimal": {
   631  			Standard: importers.List{`"github.com/volatiletech/sqlboiler/v4/types"`},
   632  		},
   633  		"types.NullDecimal": {
   634  			Standard: importers.List{`"github.com/volatiletech/sqlboiler/v4/types"`},
   635  		},
   636  		"mssql.UniqueIdentifier": {
   637  			Standard: importers.List{`"github.com/microsoft/go-mssqldb"`},
   638  		},
   639  	}
   640  	return col, err
   641  }