github.com/Ali-iotechsys/sqlboiler/v4@v4.0.0-20221208124957-6aec9a5f1f71/drivers/sqlboiler-mssql/driver/mssql.go (about) 1 package driver 2 3 import ( 4 "database/sql" 5 "embed" 6 "encoding/base64" 7 "fmt" 8 "io/fs" 9 "net/url" 10 "strings" 11 12 // Side effect import go-mssqldb 13 "github.com/friendsofgo/errors" 14 _ "github.com/microsoft/go-mssqldb" 15 "github.com/volatiletech/sqlboiler/v4/drivers" 16 "github.com/volatiletech/sqlboiler/v4/importers" 17 "github.com/volatiletech/strmangle" 18 ) 19 20 //go:embed override 21 var templates embed.FS 22 23 func init() { 24 drivers.RegisterFromInit("mssql", &MSSQLDriver{}) 25 } 26 27 // Assemble is more useful for calling into the library so you don't 28 // have to instantiate an empty type. 29 func Assemble(config drivers.Config) (dbinfo *drivers.DBInfo, err error) { 30 driver := MSSQLDriver{} 31 return driver.Assemble(config) 32 } 33 34 // MSSQLDriver holds the database connection string and a handle 35 // to the database connection. 36 type MSSQLDriver struct { 37 connStr string 38 conn *sql.DB 39 } 40 41 // Templates that should be added/overridden 42 func (MSSQLDriver) Templates() (map[string]string, error) { 43 tpls := make(map[string]string) 44 fs.WalkDir(templates, "override", func(path string, d fs.DirEntry, err error) error { 45 if err != nil { 46 return err 47 } 48 49 if d.IsDir() { 50 return nil 51 } 52 53 b, err := fs.ReadFile(templates, path) 54 if err != nil { 55 return err 56 } 57 tpls[strings.Replace(path, "override/", "", 1)] = base64.StdEncoding.EncodeToString(b) 58 59 return nil 60 }) 61 62 return tpls, nil 63 } 64 65 // Assemble all the information we need to provide back to the driver 66 func (m *MSSQLDriver) Assemble(config drivers.Config) (dbinfo *drivers.DBInfo, err error) { 67 defer func() { 68 if r := recover(); r != nil && err == nil { 69 dbinfo = nil 70 err = r.(error) 71 } 72 }() 73 74 user := config.MustString(drivers.ConfigUser) 75 pass, _ := config.String(drivers.ConfigPass) 76 dbname := config.MustString(drivers.ConfigDBName) 77 host := config.MustString(drivers.ConfigHost) 78 port := config.DefaultInt(drivers.ConfigPort, 1433) 79 sslmode := config.DefaultString(drivers.ConfigSSLMode, "true") 80 81 schema := config.DefaultString(drivers.ConfigSchema, "dbo") 82 whitelist, _ := config.StringSlice(drivers.ConfigWhitelist) 83 blacklist, _ := config.StringSlice(drivers.ConfigBlacklist) 84 concurrency := config.DefaultInt(drivers.ConfigConcurrency, drivers.DefaultConcurrency) 85 86 m.connStr = MSSQLBuildQueryString(user, pass, dbname, host, port, sslmode) 87 m.conn, err = sql.Open("mssql", m.connStr) 88 if err != nil { 89 return nil, errors.Wrap(err, "sqlboiler-mssql failed to connect to database") 90 } 91 92 defer func() { 93 if e := m.conn.Close(); e != nil { 94 dbinfo = nil 95 err = e 96 } 97 }() 98 99 dbinfo = &drivers.DBInfo{ 100 Schema: schema, 101 Dialect: drivers.Dialect{ 102 LQ: '[', 103 RQ: ']', 104 105 UseIndexPlaceholders: true, 106 UseSchema: true, 107 UseDefaultKeyword: true, 108 109 UseTopClause: true, 110 UseOutputClause: true, 111 UseCaseWhenExistsClause: true, 112 }, 113 } 114 dbinfo.Tables, err = drivers.TablesConcurrently(m, schema, whitelist, blacklist, concurrency) 115 if err != nil { 116 return nil, err 117 } 118 119 return dbinfo, err 120 } 121 122 // MSSQLBuildQueryString builds a query string for MSSQL. 123 func MSSQLBuildQueryString(user, pass, dbname, host string, port int, sslmode string) string { 124 query := url.Values{} 125 query.Add("database", dbname) 126 query.Add("encrypt", sslmode) 127 128 u := &url.URL{ 129 Scheme: "sqlserver", 130 User: url.UserPassword(user, pass), 131 Host: fmt.Sprintf("%s:%d", host, port), 132 RawQuery: query.Encode(), 133 } 134 135 // If the host is an "sqlserver instance" then we set the Path not the Host 136 // so the url package doesn't escape the / 137 if strings.Contains(host, "/") { 138 u.Path = host 139 u.Host = "" 140 } 141 142 return u.String() 143 } 144 145 // TableNames connects to the postgres database and 146 // retrieves all table names from the information_schema where the 147 // table schema is schema. It uses a whitelist and blacklist. 148 func (m *MSSQLDriver) TableNames(schema string, whitelist, blacklist []string) ([]string, error) { 149 var names []string 150 151 query := ` 152 SELECT table_name 153 FROM information_schema.tables 154 WHERE table_schema = ? AND table_type = 'BASE TABLE'` 155 156 args := []interface{}{schema} 157 if len(whitelist) > 0 { 158 tables := drivers.TablesFromList(whitelist) 159 if len(tables) > 0 { 160 query += fmt.Sprintf(" AND table_name IN (%s)", strings.Repeat(",?", len(tables))[1:]) 161 for _, w := range tables { 162 args = append(args, w) 163 } 164 } 165 } else if len(blacklist) > 0 { 166 tables := drivers.TablesFromList(blacklist) 167 if len(tables) > 0 { 168 query += fmt.Sprintf(" AND table_name not IN (%s)", strings.Repeat(",?", len(tables))[1:]) 169 for _, b := range tables { 170 args = append(args, b) 171 } 172 } 173 } 174 175 query += ` ORDER BY table_name;` 176 177 rows, err := m.conn.Query(query, args...) 178 179 if err != nil { 180 return nil, err 181 } 182 183 defer rows.Close() 184 for rows.Next() { 185 var name string 186 if err := rows.Scan(&name); err != nil { 187 return nil, err 188 } 189 names = append(names, name) 190 } 191 192 return names, nil 193 } 194 195 // ViewNames connects to the postgres database and 196 // retrieves all view names from the information_schema where the 197 // view schema is schema. It uses a whitelist and blacklist. 198 func (m *MSSQLDriver) ViewNames(schema string, whitelist, blacklist []string) ([]string, error) { 199 var names []string 200 201 query := `select table_name from information_schema.views where table_schema = ?` 202 args := []interface{}{schema} 203 if len(whitelist) > 0 { 204 tables := drivers.TablesFromList(whitelist) 205 if len(tables) > 0 { 206 query += fmt.Sprintf(" and table_name in (%s)", strings.Repeat(",?", len(tables))[1:]) 207 for _, w := range tables { 208 args = append(args, w) 209 } 210 } 211 } else if len(blacklist) > 0 { 212 tables := drivers.TablesFromList(blacklist) 213 if len(tables) > 0 { 214 query += fmt.Sprintf(" and table_name not in (%s)", strings.Repeat(",?", len(tables))[1:]) 215 for _, b := range tables { 216 args = append(args, b) 217 } 218 } 219 } 220 221 query += ` order by table_name;` 222 223 rows, err := m.conn.Query(query, args...) 224 225 if err != nil { 226 return nil, err 227 } 228 229 defer rows.Close() 230 for rows.Next() { 231 var name string 232 if err := rows.Scan(&name); err != nil { 233 return nil, err 234 } 235 236 names = append(names, name) 237 } 238 239 return names, nil 240 } 241 242 // ViewCapabilities return what actions are allowed for a view. 243 func (m *MSSQLDriver) ViewCapabilities(schema, name string) (drivers.ViewCapabilities, error) { 244 // This depends on the specific query and is not possible to ensure 245 // from just the schema 246 capabilities := drivers.ViewCapabilities{ 247 CanInsert: false, 248 CanUpsert: false, 249 } 250 251 return capabilities, nil 252 } 253 254 func (m *MSSQLDriver) ViewColumns(schema, tableName string, whitelist, blacklist []string) ([]drivers.Column, error) { 255 return m.Columns(schema, tableName, whitelist, blacklist) 256 } 257 258 // Columns takes a table name and attempts to retrieve the table information 259 // from the database information_schema.columns. It retrieves the column names 260 // and column types and returns those as a []Column after TranslateColumnType() 261 // converts the SQL types to Go types, for example: "varchar" to "string" 262 func (m *MSSQLDriver) Columns(schema, tableName string, whitelist, blacklist []string) ([]drivers.Column, error) { 263 var columns []drivers.Column 264 args := []interface{}{schema, tableName} 265 query := ` 266 SELECT column_name, 267 CASE 268 WHEN character_maximum_length IS NULL THEN data_type 269 ELSE data_type + '(' + CAST(character_maximum_length AS VARCHAR) + ')' 270 END AS full_type, 271 data_type, 272 column_default, 273 CASE 274 WHEN is_nullable = 'YES' THEN 1 275 ELSE 0 276 END AS is_nullable, 277 CASE 278 WHEN EXISTS (SELECT c.column_name 279 FROM information_schema.table_constraints tc 280 INNER JOIN information_schema.key_column_usage kcu 281 ON tc.constraint_name = kcu.constraint_name 282 AND tc.table_name = kcu.table_name 283 AND tc.table_schema = kcu.table_schema 284 WHERE c.column_name = kcu.column_name 285 AND tc.table_name = c.table_name 286 AND (tc.constraint_type = 'PRIMARY KEY' OR tc.constraint_type = 'UNIQUE') 287 AND (SELECT COUNT(*) 288 FROM information_schema.key_column_usage 289 WHERE table_schema = kcu.table_schema 290 AND table_name = tc.table_name 291 AND constraint_name = tc.constraint_name) = 1) THEN 1 292 ELSE 0 293 END AS is_unique, 294 COLUMNPROPERTY(object_id($1 + '.' + $2), c.column_name, 'IsIdentity') as is_identity, 295 COLUMNPROPERTY(object_id($1 + '.' + $2), c.column_name, 'IsComputed') as is_computed 296 FROM information_schema.columns c 297 WHERE table_schema = $1 AND table_name = $2` 298 299 if len(whitelist) > 0 { 300 cols := drivers.ColumnsFromList(whitelist, tableName) 301 if len(cols) > 0 { 302 query += fmt.Sprintf(" and c.column_name in (%s)", strmangle.Placeholders(true, len(cols), 3, 1)) 303 for _, w := range cols { 304 args = append(args, w) 305 } 306 } 307 } else if len(blacklist) > 0 { 308 cols := drivers.ColumnsFromList(blacklist, tableName) 309 if len(cols) > 0 { 310 query += fmt.Sprintf(" and c.column_name not in (%s)", strmangle.Placeholders(true, len(cols), 3, 1)) 311 for _, w := range cols { 312 args = append(args, w) 313 } 314 } 315 } 316 317 query += ` ORDER BY ordinal_position;` 318 319 rows, err := m.conn.Query(query, args...) 320 if err != nil { 321 return nil, err 322 } 323 defer rows.Close() 324 325 for rows.Next() { 326 var colName, colType, colFullType string 327 var nullable, unique, identity, computed bool 328 var defaultValue *string 329 if err := rows.Scan(&colName, &colFullType, &colType, &defaultValue, &nullable, &unique, &identity, &computed); err != nil { 330 return nil, errors.Wrapf(err, "unable to scan for table %s", tableName) 331 } 332 333 computed = computed || strings.EqualFold(colType, "timestamp") || strings.EqualFold(colType, "rowversion") 334 335 column := drivers.Column{ 336 Name: colName, 337 FullDBType: colFullType, 338 DBType: colType, 339 Nullable: nullable, 340 Unique: unique, 341 AutoGenerated: computed || identity, 342 } 343 344 if defaultValue != nil { 345 column.Default = *defaultValue 346 } 347 348 // A generated column technically has a default value 349 if column.Default == "" && column.AutoGenerated { 350 column.Default = "AUTO_GENERATED" 351 } 352 353 columns = append(columns, column) 354 } 355 356 return columns, nil 357 } 358 359 // PrimaryKeyInfo looks up the primary key for a table. 360 func (m *MSSQLDriver) PrimaryKeyInfo(schema, tableName string) (*drivers.PrimaryKey, error) { 361 pkey := &drivers.PrimaryKey{} 362 var err error 363 364 query := ` 365 SELECT constraint_name 366 FROM information_schema.table_constraints 367 WHERE table_name = ? AND constraint_type = 'PRIMARY KEY' AND table_schema = ?;` 368 369 row := m.conn.QueryRow(query, tableName, schema) 370 if err = row.Scan(&pkey.Name); err != nil { 371 if errors.Is(err, sql.ErrNoRows) { 372 return nil, nil 373 } 374 return nil, err 375 } 376 377 queryColumns := ` 378 SELECT column_name 379 FROM information_schema.key_column_usage 380 WHERE table_name = ? AND constraint_name = ? AND table_schema = ? 381 ORDER BY ordinal_position;` 382 383 var rows *sql.Rows 384 if rows, err = m.conn.Query(queryColumns, tableName, pkey.Name, schema); err != nil { 385 return nil, err 386 } 387 defer rows.Close() 388 389 var columns []string 390 for rows.Next() { 391 var column string 392 393 err = rows.Scan(&column) 394 if err != nil { 395 return nil, err 396 } 397 398 columns = append(columns, column) 399 } 400 401 if err = rows.Err(); err != nil { 402 return nil, err 403 } 404 405 pkey.Columns = columns 406 407 return pkey, nil 408 } 409 410 // ForeignKeyInfo retrieves the foreign keys for a given table name. 411 func (m *MSSQLDriver) ForeignKeyInfo(schema, tableName string) ([]drivers.ForeignKey, error) { 412 var fkeys []drivers.ForeignKey 413 414 query := ` 415 SELECT ccu.constraint_name , 416 ccu.table_name AS local_table , 417 ccu.column_name AS local_column , 418 kcu.table_name AS foreign_table , 419 kcu.column_name AS foreign_column 420 FROM information_schema.constraint_column_usage ccu 421 INNER JOIN information_schema.referential_constraints rc ON ccu.constraint_name = rc.constraint_name 422 INNER JOIN information_schema.key_column_usage kcu ON kcu.constraint_name = rc.unique_constraint_name 423 WHERE ccu.table_schema = ? 424 AND ccu.constraint_schema = ? 425 AND ccu.table_name = ? 426 ORDER BY ccu.constraint_name, local_table, local_column, foreign_table, foreign_column 427 ` 428 429 var rows *sql.Rows 430 var err error 431 if rows, err = m.conn.Query(query, schema, schema, tableName); err != nil { 432 return nil, err 433 } 434 435 for rows.Next() { 436 var fkey drivers.ForeignKey 437 var sourceTable string 438 439 fkey.Table = tableName 440 err = rows.Scan(&fkey.Name, &sourceTable, &fkey.Column, &fkey.ForeignTable, &fkey.ForeignColumn) 441 if err != nil { 442 return nil, err 443 } 444 445 fkeys = append(fkeys, fkey) 446 } 447 448 if err = rows.Err(); err != nil { 449 return nil, err 450 } 451 452 return fkeys, nil 453 } 454 455 // TranslateColumnType converts postgres database types to Go types, for example 456 // "varchar" to "string" and "bigint" to "int64". It returns this parsed data 457 // as a Column object. 458 func (m *MSSQLDriver) TranslateColumnType(c drivers.Column) drivers.Column { 459 if c.Nullable { 460 switch c.DBType { 461 case "tinyint": 462 c.Type = "null.Int8" 463 case "smallint": 464 c.Type = "null.Int16" 465 case "mediumint": 466 c.Type = "null.Int32" 467 case "int": 468 c.Type = "null.Int" 469 case "bigint": 470 c.Type = "null.Int64" 471 case "real": 472 c.Type = "null.Float32" 473 case "float": 474 c.Type = "null.Float64" 475 case "boolean", "bool", "bit": 476 c.Type = "null.Bool" 477 case "date", "datetime", "datetime2", "datetimeoffset", "smalldatetime", "time": 478 c.Type = "null.Time" 479 case "binary", "varbinary": 480 c.Type = "null.Bytes" 481 case "timestamp", "rowversion": 482 c.Type = "null.Bytes" 483 case "xml": 484 c.Type = "null.String" 485 case "uniqueidentifier": 486 c.Type = "mssql.UniqueIdentifier" 487 c.DBType = "uuid" 488 case "numeric", "decimal", "dec": 489 c.Type = "types.NullDecimal" 490 default: 491 c.Type = "null.String" 492 } 493 } else { 494 switch c.DBType { 495 case "tinyint": 496 c.Type = "int8" 497 case "smallint": 498 c.Type = "int16" 499 case "mediumint": 500 c.Type = "int32" 501 case "int": 502 c.Type = "int" 503 case "bigint": 504 c.Type = "int64" 505 case "real": 506 c.Type = "float32" 507 case "float": 508 c.Type = "float64" 509 case "boolean", "bool", "bit": 510 c.Type = "bool" 511 case "date", "datetime", "datetime2", "datetimeoffset", "smalldatetime", "time": 512 c.Type = "time.Time" 513 case "binary", "varbinary": 514 c.Type = "[]byte" 515 case "timestamp", "rowversion": 516 c.Type = "[]byte" 517 case "xml": 518 c.Type = "string" 519 case "uniqueidentifier": 520 c.Type = "mssql.UniqueIdentifier" 521 c.DBType = "uuid" 522 case "numeric", "decimal", "dec": 523 c.Type = "types.Decimal" 524 default: 525 c.Type = "string" 526 } 527 } 528 529 return c 530 } 531 532 // Imports returns important imports for the driver 533 func (MSSQLDriver) Imports() (col importers.Collection, err error) { 534 col.All = importers.Set{ 535 Standard: importers.List{ 536 `"strconv"`, 537 }, 538 } 539 col.Singleton = importers.Map{ 540 "mssql_upsert": { 541 Standard: importers.List{ 542 `"fmt"`, 543 `"strings"`, 544 }, 545 ThirdParty: importers.List{ 546 `"github.com/volatiletech/strmangle"`, 547 `"github.com/volatiletech/sqlboiler/v4/drivers"`, 548 }, 549 }, 550 } 551 col.TestSingleton = importers.Map{ 552 "mssql_suites_test": { 553 Standard: importers.List{ 554 `"testing"`, 555 }, 556 }, 557 "mssql_main_test": { 558 Standard: importers.List{ 559 `"bytes"`, 560 `"database/sql"`, 561 `"fmt"`, 562 `"os"`, 563 `"os/exec"`, 564 `"regexp"`, 565 `"strings"`, 566 }, 567 ThirdParty: importers.List{ 568 `"github.com/kat-co/vala"`, 569 `"github.com/friendsofgo/errors"`, 570 `"github.com/spf13/viper"`, 571 `"github.com/volatiletech/sqlboiler/v4/drivers/sqlboiler-mssql/driver"`, 572 `"github.com/volatiletech/randomize"`, 573 `_ "github.com/microsoft/go-mssqldb"`, 574 }, 575 }, 576 } 577 578 col.BasedOnType = importers.Map{ 579 "null.Float32": { 580 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 581 }, 582 "null.Float64": { 583 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 584 }, 585 "null.Int": { 586 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 587 }, 588 "null.Int8": { 589 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 590 }, 591 "null.Int16": { 592 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 593 }, 594 "null.Int32": { 595 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 596 }, 597 "null.Int64": { 598 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 599 }, 600 "null.Uint": { 601 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 602 }, 603 "null.Uint8": { 604 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 605 }, 606 "null.Uint16": { 607 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 608 }, 609 "null.Uint32": { 610 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 611 }, 612 "null.Uint64": { 613 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 614 }, 615 "null.String": { 616 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 617 }, 618 "null.Bool": { 619 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 620 }, 621 "null.Time": { 622 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 623 }, 624 "null.Bytes": { 625 ThirdParty: importers.List{`"github.com/volatiletech/null/v8"`}, 626 }, 627 "time.Time": { 628 Standard: importers.List{`"time"`}, 629 }, 630 "types.Decimal": { 631 Standard: importers.List{`"github.com/volatiletech/sqlboiler/v4/types"`}, 632 }, 633 "types.NullDecimal": { 634 Standard: importers.List{`"github.com/volatiletech/sqlboiler/v4/types"`}, 635 }, 636 "mssql.UniqueIdentifier": { 637 Standard: importers.List{`"github.com/microsoft/go-mssqldb"`}, 638 }, 639 } 640 return col, err 641 }