github.com/easysoft/zendata@v0.0.0-20240513203326-705bd5a7fd67/internal/pkg/service/output-sql.go (about)

     1  package service
     2  
     3  import (
     4  	"fmt"
     5  	"strings"
     6  
     7  	consts "github.com/easysoft/zendata/internal/pkg/const"
     8  	"github.com/easysoft/zendata/internal/pkg/helper"
     9  	logUtils "github.com/easysoft/zendata/pkg/utils/log"
    10  	"github.com/easysoft/zendata/pkg/utils/vari"
    11  )
    12  
    13  func (s *OutputService) GenSql() {
    14  	records := s.GenRecords()
    15  
    16  	lines := make([]interface{}, 0)
    17  
    18  	sqlHeader := s.getInsertSqlHeader()
    19  	if vari.GlobalVars.DBDsn != "" {
    20  		lines = append(lines, sqlHeader)
    21  	}
    22  
    23  	logUtils.PrintLine(sqlHeader)
    24  
    25  	for index, record := range records {
    26  		valuesForSql := make([]string, 0)
    27  
    28  		for j, colName := range vari.GlobalVars.ExportFields {
    29  			colVal := fmt.Sprintf("%v", record[colName])
    30  
    31  			if !vari.GlobalVars.ColIsNumArr[j] {
    32  				switch vari.GlobalVars.DBType {
    33  				case consts.DBTypeMysql:
    34  					colVal = "'" + helper.EscapeValueOfMysql(colVal) + "'"
    35  				case consts.DBTypeOracle:
    36  					colVal = "'" + helper.EscapeValueOfOracle(colVal) + "'"
    37  				case consts.DBTypeSqlServer:
    38  					colVal = "'" + helper.EscapeValueOfSqlServer(colVal) + "'"
    39  				default:
    40  				}
    41  			}
    42  
    43  			valuesForSql = append(valuesForSql, colVal)
    44  		}
    45  
    46  		if vari.GlobalVars.DBDsn != "" { // add to return array for sql exec
    47  			sql := strings.Join(valuesForSql, ", ")
    48  			lines = append(lines, sql)
    49  		} else {
    50  			sql := s.genSqlLine(valuesForSql)
    51  			sql = strings.Repeat(" ", len("INSERT")+1) + sql
    52  
    53  			if index < len(records)-1 {
    54  				sql += ","
    55  				logUtils.PrintLine(sql + "\n")
    56  
    57  			} else {
    58  				if vari.GlobalVars.DBType == consts.DBTypeSqlServer {
    59  					logUtils.PrintLine(sql + "; GO")
    60  				} else {
    61  					logUtils.PrintLine(sql + ";")
    62  				}
    63  			}
    64  		}
    65  	}
    66  
    67  	logUtils.PrintLine("\n")
    68  
    69  	return
    70  }
    71  
    72  // return Table (column1, column2, ...)
    73  func (s *OutputService) getInsertSqlHeader() string {
    74  	fieldNames := make([]string, 0)
    75  
    76  	for _, f := range vari.GlobalVars.ExportFields {
    77  		if vari.GlobalVars.DBType == consts.DBTypeMysql {
    78  			f = "`" + helper.EscapeColumnOfMysql(f) + "`"
    79  		} else if vari.GlobalVars.DBType == consts.DBTypeOracle {
    80  			f = `"` + f + `"`
    81  		} else if vari.GlobalVars.DBType == consts.DBTypeSqlServer {
    82  			f = "[" + helper.EscapeColumnOfSqlServer(f) + "]"
    83  		}
    84  
    85  		fieldNames = append(fieldNames, f)
    86  	}
    87  
    88  	var ret string
    89  	switch vari.GlobalVars.DBType {
    90  	case consts.DBTypeMysql:
    91  		ret = fmt.Sprintf("`%s` (%s)", vari.GlobalVars.Table, strings.Join(fieldNames, ", "))
    92  	case consts.DBTypeOracle:
    93  		ret = fmt.Sprintf(`"%s" (%s)`, vari.GlobalVars.Table, strings.Join(fieldNames, ", "))
    94  	case consts.DBTypeSqlServer:
    95  		ret = fmt.Sprintf("[%s] (%s)", vari.GlobalVars.Table, strings.Join(fieldNames, ", "))
    96  	default:
    97  	}
    98  
    99  	ret = "INSERT INTO " + ret + " VALUES\n"
   100  
   101  	return ret
   102  }
   103  
   104  func (s *OutputService) genSqlLine(values []string) string {
   105  	return "(" + strings.Join(values, ",") + ")"
   106  }