github.com/GuanceCloud/cliutils@v1.1.21/pipeline/ptinput/refertable/table_sqlite_other.go (about)

     1  // Unless explicitly stated otherwise all files in this repository are licensed
     2  // under the MIT License.
     3  // This product includes software developed at Guance Cloud (https://www.guance.com/).
     4  // Copyright 2021-present Guance, Inc.
     5  
     6  //go:build !(windows && 386)
     7  // +build !windows !386
     8  
     9  package refertable
    10  
    11  import (
    12  	"database/sql"
    13  	"errors"
    14  	"fmt"
    15  	"strings"
    16  
    17  	_ "modernc.org/sqlite"
    18  )
    19  
// PlReferTablesSqlite is a reference-table store backed by a SQL database
// handle (this file registers the modernc.org/sqlite driver).
type PlReferTablesSqlite struct {
	// tableNames lists the tables created by the most recent updateAll call;
	// Stats iterates it and the next updateAll drops these tables.
	tableNames []string
	// db is the database handle. A nil db means the store is not initialized:
	// Query and Stats return early and updateAll returns an error.
	db *sql.DB
}
    24  
    25  func (p *PlReferTablesSqlite) Query(tableName string, colName []string, colValue []any, kGet []string) (map[string]any, bool) {
    26  	if p.db == nil {
    27  		return nil, false
    28  	}
    29  	query := buildSelectStmt(tableName, colName, kGet)
    30  	l.Debugf("got SQL statement '%s' with params %v", query, colValue)
    31  
    32  	result, err := p.db.Query(query, colValue...)
    33  	if err != nil {
    34  		l.Errorf("Query returned: %v", err)
    35  		return nil, false
    36  	}
    37  	defer result.Close() //nolint:errcheck
    38  	var cols []string
    39  	if len(kGet) == 0 {
    40  		// Get all columns.
    41  		cols, err = result.Columns()
    42  		if err != nil {
    43  			l.Errorf("failed to get column names: %v", err)
    44  			return nil, false
    45  		}
    46  	} else {
    47  		cols = kGet
    48  	}
    49  	nCol := len(cols)
    50  
    51  	its := make([]interface{}, nCol)
    52  	itAddrs := make([]interface{}, nCol)
    53  	for i := 0; i < nCol; i++ {
    54  		itAddrs[i] = &its[i]
    55  	}
    56  	// Scan only one row.
    57  	if result.Next() {
    58  		if err := result.Scan(itAddrs...); err != nil {
    59  			l.Errorf("failed to scan query result: %v", err)
    60  			return nil, false
    61  		}
    62  	}
    63  
    64  	ret := make(map[string]any)
    65  	for i := 0; i < nCol; i++ {
    66  		ret[cols[i]] = its[i]
    67  	}
    68  	return ret, true
    69  }
    70  
    71  func (p *PlReferTablesSqlite) updateAll(tables []referTable) (retErr error) {
    72  	if p.db == nil {
    73  		return errors.New("PlReferTablesSqlite is not initialized")
    74  	}
    75  	for _, table := range tables {
    76  		if err := table.check(); err != nil {
    77  			return err
    78  		}
    79  	}
    80  
    81  	tx, err := p.db.Begin()
    82  	if err != nil {
    83  		return fmt.Errorf("failed to start a TX: %w", err)
    84  	}
    85  	defer func() {
    86  		if retErr != nil {
    87  			if err := tx.Rollback(); err != nil {
    88  				l.Errorf("failed to rollback TX: %v", err)
    89  			}
    90  		} else {
    91  			if err := tx.Commit(); err != nil {
    92  				l.Errorf("failed to commit TX: %v", err)
    93  			}
    94  		}
    95  	}()
    96  
    97  	// Drop deprecated tables.
    98  	for _, t := range p.tableNames {
    99  		dropStmt := fmt.Sprintf("DROP TABLE IF EXISTS %s", t)
   100  		if _, err := tx.Exec(dropStmt); err != nil {
   101  			return fmt.Errorf("failed to execute '%s': %w", dropStmt, err)
   102  		}
   103  	}
   104  	for _, t := range tables {
   105  		dropStmt := fmt.Sprintf("DROP TABLE IF EXISTS %s", t.TableName)
   106  		if _, err := tx.Exec(dropStmt); err != nil {
   107  			return fmt.Errorf("failed to execute '%s': %w", dropStmt, err)
   108  		}
   109  	}
   110  
   111  	// Create new tables.
   112  	for i := range tables {
   113  		createStmt := buildCreateTableStmt(&tables[i])
   114  		if _, err := tx.Exec(createStmt); err != nil {
   115  			return fmt.Errorf("failed to execute '%s': %w", createStmt, err)
   116  		}
   117  	}
   118  
   119  	// Insert tuples into these tables.
   120  	for i := range tables {
   121  		insertStmt := buildInsertIntoStmts(&tables[i])
   122  		for j := 0; j < len(tables[i].RowData); j++ {
   123  			if _, err := tx.Exec(insertStmt, tables[i].RowData[j]...); err != nil {
   124  				return fmt.Errorf("failed to execute '%s' with params %v: %w", insertStmt, tables[i].RowData[j], err)
   125  			}
   126  		}
   127  	}
   128  
   129  	// Update table list.
   130  	p.tableNames = []string{}
   131  	for _, t := range tables {
   132  		p.tableNames = append(p.tableNames, t.TableName)
   133  	}
   134  
   135  	return nil
   136  }
   137  
   138  func (p *PlReferTablesSqlite) Stats() *ReferTableStats {
   139  	if p.db == nil {
   140  		return nil
   141  	}
   142  	var (
   143  		res    ReferTableStats
   144  		numRow int
   145  	)
   146  	for _, tableName := range p.tableNames {
   147  		if err := p.db.QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s", tableName)).Scan(&numRow); err != nil {
   148  			l.Errorf("Query retuned: %v", err)
   149  			return nil
   150  		}
   151  		res.Row = append(res.Row, numRow)
   152  	}
   153  	return &res
   154  }
   155  
   156  func buildSelectStmt(tableName string, colName []string, kGet []string) string {
   157  	var query, items string
   158  
   159  	if len(kGet) == 0 {
   160  		items = "*"
   161  	} else {
   162  		items = strings.Join(kGet, ", ")
   163  	}
   164  
   165  	if len(colName) > 0 {
   166  		var conStmt strings.Builder
   167  		for i, c := range colName {
   168  			if i != 0 {
   169  				conStmt.WriteString(" AND ")
   170  			}
   171  			conStmt.WriteString(c + " = ?")
   172  		}
   173  		query = fmt.Sprintf("SELECT %s FROM %s WHERE %s", items, tableName, conStmt.String())
   174  	} else {
   175  		query = fmt.Sprintf("SELECT %s FROM %s", items, tableName)
   176  	}
   177  	return query
   178  }
   179  
   180  func buildCreateTableStmt(table *referTable) string {
   181  	var res strings.Builder
   182  	res.WriteString("CREATE TABLE " + table.TableName + " (")
   183  	for i, colName := range table.ColumnName {
   184  		if i != 0 {
   185  			res.WriteString(", ")
   186  		}
   187  		res.WriteString(colName)
   188  		res.WriteString(" " + ColType2SqliteType(table.ColumnType[i]))
   189  	}
   190  	res.WriteString(")")
   191  	return res.String()
   192  }
   193  
   194  func buildInsertIntoStmts(table *referTable) string {
   195  	var sb strings.Builder
   196  	sb.WriteString("INSERT INTO " + table.TableName + " (")
   197  	for i, colName := range table.ColumnName {
   198  		if i != 0 {
   199  			sb.WriteString(", ")
   200  		}
   201  		sb.WriteString(colName)
   202  	}
   203  	sb.WriteString(") VALUES (")
   204  	for i := range table.ColumnName {
   205  		if i != 0 {
   206  			sb.WriteString(", ")
   207  		}
   208  		sb.WriteByte('?')
   209  	}
   210  	sb.WriteByte(')')
   211  	return sb.String()
   212  }
   213  
   214  func ColType2SqliteType(typeName string) string {
   215  	switch typeName {
   216  	case columnTypeStr:
   217  		return "TEXT"
   218  	case columnTypeFloat:
   219  		return "REAL"
   220  	case columnTypeInt:
   221  		return "INTEGER"
   222  	case columnTypeBool:
   223  		return "NUMERIC"
   224  	default:
   225  		return ""
   226  	}
   227  }