github.com/gogf/gf@v1.16.9/database/gdb/gdb_driver_mssql.go (about)

     1  // Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
     2  //
     3  // This Source Code Form is subject to the terms of the MIT License.
     4  // If a copy of the MIT was not distributed with this file,
     5  // You can obtain one at https://github.com/gogf/gf.
     6  //
     7  // Note:
// 1. The driver package must be imported manually: _ "github.com/denisenkom/go-mssqldb"
     9  // 2. It does not support Save/Replace features.
    10  // 3. It does not support LastInsertId.
    11  
    12  package gdb
    13  
    14  import (
    15  	"context"
    16  	"database/sql"
    17  	"fmt"
    18  	"github.com/gogf/gf/errors/gcode"
    19  	"strconv"
    20  	"strings"
    21  
    22  	"github.com/gogf/gf/errors/gerror"
    23  
    24  	"github.com/gogf/gf/internal/intlog"
    25  	"github.com/gogf/gf/text/gstr"
    26  
    27  	"github.com/gogf/gf/text/gregex"
    28  )
    29  
    30  // DriverMssql is the driver for SQL server database.
    31  type DriverMssql struct {
    32  	*Core
    33  }
    34  
    35  // New creates and returns a database object for SQL server.
    36  // It implements the interface of gdb.Driver for extra database driver installation.
    37  func (d *DriverMssql) New(core *Core, node *ConfigNode) (DB, error) {
    38  	return &DriverMssql{
    39  		Core: core,
    40  	}, nil
    41  }
    42  
    43  // Open creates and returns a underlying sql.DB object for mssql.
    44  func (d *DriverMssql) Open(config *ConfigNode) (*sql.DB, error) {
    45  	source := ""
    46  	if config.Link != "" {
    47  		source = config.Link
    48  	} else {
    49  		source = fmt.Sprintf(
    50  			"user id=%s;password=%s;server=%s;port=%s;database=%s;encrypt=disable",
    51  			config.User, config.Pass, config.Host, config.Port, config.Name,
    52  		)
    53  	}
    54  	intlog.Printf(d.GetCtx(), "Open: %s", source)
    55  	if db, err := sql.Open("sqlserver", source); err == nil {
    56  		return db, nil
    57  	} else {
    58  		return nil, err
    59  	}
    60  }
    61  
    62  // FilteredLink retrieves and returns filtered `linkInfo` that can be using for
    63  // logging or tracing purpose.
    64  func (d *DriverMssql) FilteredLink() string {
    65  	linkInfo := d.GetConfig().Link
    66  	if linkInfo == "" {
    67  		return ""
    68  	}
    69  	s, _ := gregex.ReplaceString(
    70  		`(.+);\s*password=(.+);\s*server=(.+)`,
    71  		`$1;password=xxx;server=$3`,
    72  		d.GetConfig().Link,
    73  	)
    74  	return s
    75  }
    76  
    77  // GetChars returns the security char for this type of database.
    78  func (d *DriverMssql) GetChars() (charLeft string, charRight string) {
    79  	return "\"", "\""
    80  }
    81  
    82  // DoCommit deals with the sql string before commits it to underlying sql driver.
    83  func (d *DriverMssql) DoCommit(ctx context.Context, link Link, sql string, args []interface{}) (newSql string, newArgs []interface{}, err error) {
    84  	defer func() {
    85  		newSql, newArgs, err = d.Core.DoCommit(ctx, link, newSql, newArgs)
    86  	}()
    87  	var index int
    88  	// Convert place holder char '?' to string "@px".
    89  	str, _ := gregex.ReplaceStringFunc("\\?", sql, func(s string) string {
    90  		index++
    91  		return fmt.Sprintf("@p%d", index)
    92  	})
    93  	str, _ = gregex.ReplaceString("\"", "", str)
    94  	return d.parseSql(str), args, nil
    95  }
    96  
    97  // parseSql does some replacement of the sql before commits it to underlying driver,
    98  // for support of microsoft sql server.
    99  func (d *DriverMssql) parseSql(sql string) string {
   100  	// SELECT * FROM USER WHERE ID=1 LIMIT 1
   101  	if m, _ := gregex.MatchString(`^SELECT(.+)LIMIT 1$`, sql); len(m) > 1 {
   102  		return fmt.Sprintf(`SELECT TOP 1 %s`, m[1])
   103  	}
   104  	// SELECT * FROM USER WHERE AGE>18 ORDER BY ID DESC LIMIT 100, 200
   105  	patten := `^\s*(?i)(SELECT)|(LIMIT\s*(\d+)\s*,\s*(\d+))`
   106  	if gregex.IsMatchString(patten, sql) == false {
   107  		return sql
   108  	}
   109  	res, err := gregex.MatchAllString(patten, sql)
   110  	if err != nil {
   111  		return ""
   112  	}
   113  	index := 0
   114  	keyword := strings.TrimSpace(res[index][0])
   115  	keyword = strings.ToUpper(keyword)
   116  	index++
   117  	switch keyword {
   118  	case "SELECT":
   119  		// LIMIT statement checks.
   120  		if len(res) < 2 ||
   121  			(strings.HasPrefix(res[index][0], "LIMIT") == false &&
   122  				strings.HasPrefix(res[index][0], "limit") == false) {
   123  			break
   124  		}
   125  		if gregex.IsMatchString("((?i)SELECT)(.+)((?i)LIMIT)", sql) == false {
   126  			break
   127  		}
   128  		// ORDER BY statement checks.
   129  		selectStr := ""
   130  		orderStr := ""
   131  		haveOrder := gregex.IsMatchString("((?i)SELECT)(.+)((?i)ORDER BY)", sql)
   132  		if haveOrder {
   133  			queryExpr, _ := gregex.MatchString("((?i)SELECT)(.+)((?i)ORDER BY)", sql)
   134  			if len(queryExpr) != 4 ||
   135  				strings.EqualFold(queryExpr[1], "SELECT") == false ||
   136  				strings.EqualFold(queryExpr[3], "ORDER BY") == false {
   137  				break
   138  			}
   139  			selectStr = queryExpr[2]
   140  			orderExpr, _ := gregex.MatchString("((?i)ORDER BY)(.+)((?i)LIMIT)", sql)
   141  			if len(orderExpr) != 4 ||
   142  				strings.EqualFold(orderExpr[1], "ORDER BY") == false ||
   143  				strings.EqualFold(orderExpr[3], "LIMIT") == false {
   144  				break
   145  			}
   146  			orderStr = orderExpr[2]
   147  		} else {
   148  			queryExpr, _ := gregex.MatchString("((?i)SELECT)(.+)((?i)LIMIT)", sql)
   149  			if len(queryExpr) != 4 ||
   150  				strings.EqualFold(queryExpr[1], "SELECT") == false ||
   151  				strings.EqualFold(queryExpr[3], "LIMIT") == false {
   152  				break
   153  			}
   154  			selectStr = queryExpr[2]
   155  		}
   156  		first, limit := 0, 0
   157  		for i := 1; i < len(res[index]); i++ {
   158  			if len(strings.TrimSpace(res[index][i])) == 0 {
   159  				continue
   160  			}
   161  
   162  			if strings.HasPrefix(res[index][i], "LIMIT") ||
   163  				strings.HasPrefix(res[index][i], "limit") {
   164  				first, _ = strconv.Atoi(res[index][i+1])
   165  				limit, _ = strconv.Atoi(res[index][i+2])
   166  				break
   167  			}
   168  		}
   169  		if haveOrder {
   170  			sql = fmt.Sprintf(
   171  				"SELECT * FROM "+
   172  					"(SELECT ROW_NUMBER() OVER (ORDER BY %s) as ROWNUMBER_, %s ) as TMP_ "+
   173  					"WHERE TMP_.ROWNUMBER_ > %d AND TMP_.ROWNUMBER_ <= %d",
   174  				orderStr, selectStr, first, first+limit,
   175  			)
   176  		} else {
   177  			if first == 0 {
   178  				first = limit
   179  			}
   180  			sql = fmt.Sprintf(
   181  				"SELECT * FROM (SELECT TOP %d * FROM (SELECT TOP %d %s) as TMP1_ ) as TMP2_ ",
   182  				limit, first+limit, selectStr,
   183  			)
   184  		}
   185  	default:
   186  	}
   187  	return sql
   188  }
   189  
   190  // Tables retrieves and returns the tables of current schema.
   191  // It's mainly used in cli tool chain for automatically generating the models.
   192  func (d *DriverMssql) Tables(ctx context.Context, schema ...string) (tables []string, err error) {
   193  	var result Result
   194  	link, err := d.SlaveLink(schema...)
   195  	if err != nil {
   196  		return nil, err
   197  	}
   198  
   199  	result, err = d.DoGetAll(ctx, link, `SELECT NAME FROM SYSOBJECTS WHERE XTYPE='U' AND STATUS >= 0 ORDER BY NAME`)
   200  	if err != nil {
   201  		return
   202  	}
   203  	for _, m := range result {
   204  		for _, v := range m {
   205  			tables = append(tables, v.String())
   206  		}
   207  	}
   208  	return
   209  }
   210  
   211  // TableFields retrieves and returns the fields information of specified table of current schema.
   212  //
   213  // Also see DriverMysql.TableFields.
   214  func (d *DriverMssql) TableFields(ctx context.Context, table string, schema ...string) (fields map[string]*TableField, err error) {
   215  	charL, charR := d.GetChars()
   216  	table = gstr.Trim(table, charL+charR)
   217  	if gstr.Contains(table, " ") {
   218  		return nil, gerror.NewCode(gcode.CodeInvalidParameter, "function TableFields supports only single table operations")
   219  	}
   220  	useSchema := d.db.GetSchema()
   221  	if len(schema) > 0 && schema[0] != "" {
   222  		useSchema = schema[0]
   223  	}
   224  	tableFieldsCacheKey := fmt.Sprintf(
   225  		`mssql_table_fields_%s_%s@group:%s`,
   226  		table, useSchema, d.GetGroup(),
   227  	)
   228  	v := tableFieldsMap.GetOrSetFuncLock(tableFieldsCacheKey, func() interface{} {
   229  		var (
   230  			result    Result
   231  			link, err = d.SlaveLink(useSchema)
   232  		)
   233  		if err != nil {
   234  			return nil
   235  		}
   236  		structureSql := fmt.Sprintf(`
   237  SELECT 
   238  	a.name Field,
   239  	CASE b.name 
   240  		WHEN 'datetime' THEN 'datetime'
   241  		WHEN 'numeric' THEN b.name + '(' + convert(varchar(20), a.xprec) + ',' + convert(varchar(20), a.xscale) + ')' 
   242  		WHEN 'char' THEN b.name + '(' + convert(varchar(20), a.length)+ ')'
   243  		WHEN 'varchar' THEN b.name + '(' + convert(varchar(20), a.length)+ ')'
   244  		ELSE b.name + '(' + convert(varchar(20),a.length)+ ')' END AS Type,
   245  	CASE WHEN a.isnullable=1 THEN 'YES' ELSE 'NO' end AS [Null],
   246  	CASE WHEN exists (
   247  		SELECT 1 FROM sysobjects WHERE xtype='PK' AND name IN (
   248  			SELECT name FROM sysindexes WHERE indid IN (
   249  				SELECT indid FROM sysindexkeys WHERE id = a.id AND colid=a.colid
   250  			)
   251  		)
   252  	) THEN 'PRI' ELSE '' END AS [Key],
   253  	CASE WHEN COLUMNPROPERTY(a.id,a.name,'IsIdentity')=1 THEN 'auto_increment' ELSE '' END Extra,
   254  	isnull(e.text,'') AS [Default],
   255  	isnull(g.[value],'') AS [Comment]
   256  FROM syscolumns a
   257  LEFT JOIN systypes b ON a.xtype=b.xtype AND a.xusertype=b.xusertype
   258  INNER JOIN sysobjects d ON a.id=d.id AND d.xtype='U' AND d.name<>'dtproperties'
   259  LEFT JOIN syscomments e ON a.cdefault=e.id
   260  LEFT JOIN sys.extended_properties g ON a.id=g.major_id AND a.colid=g.minor_id
   261  LEFT JOIN sys.extended_properties f ON d.id=f.major_id AND f.minor_id =0
   262  WHERE d.name='%s'
   263  ORDER BY a.id,a.colorder`,
   264  			strings.ToUpper(table),
   265  		)
   266  		structureSql, _ = gregex.ReplaceString(`[\n\r\s]+`, " ", gstr.Trim(structureSql))
   267  		result, err = d.DoGetAll(ctx, link, structureSql)
   268  		if err != nil {
   269  			return nil
   270  		}
   271  		fields = make(map[string]*TableField)
   272  		for i, m := range result {
   273  			fields[strings.ToLower(m["Field"].String())] = &TableField{
   274  				Index:   i,
   275  				Name:    strings.ToLower(m["Field"].String()),
   276  				Type:    strings.ToLower(m["Type"].String()),
   277  				Null:    m["Null"].Bool(),
   278  				Key:     m["Key"].String(),
   279  				Default: m["Default"].Val(),
   280  				Extra:   m["Extra"].String(),
   281  				Comment: m["Comment"].String(),
   282  			}
   283  		}
   284  		return fields
   285  	})
   286  	if v != nil {
   287  		fields = v.(map[string]*TableField)
   288  	}
   289  	return
   290  }
   291  
   292  // DoInsert is not supported in mssql.
   293  func (d *DriverMssql) DoInsert(ctx context.Context, link Link, table string, list List, option DoInsertOption) (result sql.Result, err error) {
   294  	switch option.InsertOption {
   295  	case insertOptionSave:
   296  		return nil, gerror.NewCode(gcode.CodeNotSupported, `Save operation is not supported by mssql driver`)
   297  
   298  	case insertOptionReplace:
   299  		return nil, gerror.NewCode(gcode.CodeNotSupported, `Replace operation is not supported by mssql driver`)
   300  
   301  	default:
   302  		return d.Core.DoInsert(ctx, link, table, list, option)
   303  	}
   304  }