github.com/GuanceCloud/cliutils@v1.1.21/pipeline/ptinput/refertable/table_sqlite_other.go (about) 1 // Unless explicitly stated otherwise all files in this repository are licensed 2 // under the MIT License. 3 // This product includes software developed at Guance Cloud (https://www.guance.com/). 4 // Copyright 2021-present Guance, Inc. 5 6 //go:build !(windows && 386) 7 // +build !windows !386 8 9 package refertable 10 11 import ( 12 "database/sql" 13 "errors" 14 "fmt" 15 "strings" 16 17 _ "modernc.org/sqlite" 18 ) 19 20 type PlReferTablesSqlite struct { 21 tableNames []string 22 db *sql.DB 23 } 24 25 func (p *PlReferTablesSqlite) Query(tableName string, colName []string, colValue []any, kGet []string) (map[string]any, bool) { 26 if p.db == nil { 27 return nil, false 28 } 29 query := buildSelectStmt(tableName, colName, kGet) 30 l.Debugf("got SQL statement '%s' with params %v", query, colValue) 31 32 result, err := p.db.Query(query, colValue...) 33 if err != nil { 34 l.Errorf("Query returned: %v", err) 35 return nil, false 36 } 37 defer result.Close() //nolint:errcheck 38 var cols []string 39 if len(kGet) == 0 { 40 // Get all columns. 41 cols, err = result.Columns() 42 if err != nil { 43 l.Errorf("failed to get column names: %v", err) 44 return nil, false 45 } 46 } else { 47 cols = kGet 48 } 49 nCol := len(cols) 50 51 its := make([]interface{}, nCol) 52 itAddrs := make([]interface{}, nCol) 53 for i := 0; i < nCol; i++ { 54 itAddrs[i] = &its[i] 55 } 56 // Scan only one row. 57 if result.Next() { 58 if err := result.Scan(itAddrs...); err != nil { 59 l.Errorf("failed to scan query result: %v", err) 60 return nil, false 61 } 62 } 63 64 ret := make(map[string]any) 65 for i := 0; i < nCol; i++ { 66 ret[cols[i]] = its[i] 67 } 68 return ret, true 69 } 70 71 func (p *PlReferTablesSqlite) updateAll(tables []referTable) (retErr error) { 72 if p.db == nil { 73 return errors.New("PlReferTablesSqlite is not initialized") 74 } 75 for _, table := range tables { 76 if err := table.check(); err != nil { 77 return err 78 } 79 } 80 81 tx, err := p.db.Begin() 82 if err != nil { 83 return fmt.Errorf("failed to start a TX: %w", err) 84 } 85 defer func() { 86 if retErr != nil { 87 if err := tx.Rollback(); err != nil { 88 l.Errorf("failed to rollback TX: %v", err) 89 } 90 } else { 91 if err := tx.Commit(); err != nil { 92 l.Errorf("failed to commit TX: %v", err) 93 } 94 } 95 }() 96 97 // Drop deprecated tables. 98 for _, t := range p.tableNames { 99 dropStmt := fmt.Sprintf("DROP TABLE IF EXISTS %s", t) 100 if _, err := tx.Exec(dropStmt); err != nil { 101 return fmt.Errorf("failed to execute '%s': %w", dropStmt, err) 102 } 103 } 104 for _, t := range tables { 105 dropStmt := fmt.Sprintf("DROP TABLE IF EXISTS %s", t.TableName) 106 if _, err := tx.Exec(dropStmt); err != nil { 107 return fmt.Errorf("failed to execute '%s': %w", dropStmt, err) 108 } 109 } 110 111 // Create new tables. 112 for i := range tables { 113 createStmt := buildCreateTableStmt(&tables[i]) 114 if _, err := tx.Exec(createStmt); err != nil { 115 return fmt.Errorf("failed to execute '%s': %w", createStmt, err) 116 } 117 } 118 119 // Insert tuples into these tables. 120 for i := range tables { 121 insertStmt := buildInsertIntoStmts(&tables[i]) 122 for j := 0; j < len(tables[i].RowData); j++ { 123 if _, err := tx.Exec(insertStmt, tables[i].RowData[j]...); err != nil { 124 return fmt.Errorf("failed to execute '%s' with params %v: %w", insertStmt, tables[i].RowData[j], err) 125 } 126 } 127 } 128 129 // Update table list. 130 p.tableNames = []string{} 131 for _, t := range tables { 132 p.tableNames = append(p.tableNames, t.TableName) 133 } 134 135 return nil 136 } 137 138 func (p *PlReferTablesSqlite) Stats() *ReferTableStats { 139 if p.db == nil { 140 return nil 141 } 142 var ( 143 res ReferTableStats 144 numRow int 145 ) 146 for _, tableName := range p.tableNames { 147 if err := p.db.QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s", tableName)).Scan(&numRow); err != nil { 148 l.Errorf("Query retuned: %v", err) 149 return nil 150 } 151 res.Row = append(res.Row, numRow) 152 } 153 return &res 154 } 155 156 func buildSelectStmt(tableName string, colName []string, kGet []string) string { 157 var query, items string 158 159 if len(kGet) == 0 { 160 items = "*" 161 } else { 162 items = strings.Join(kGet, ", ") 163 } 164 165 if len(colName) > 0 { 166 var conStmt strings.Builder 167 for i, c := range colName { 168 if i != 0 { 169 conStmt.WriteString(" AND ") 170 } 171 conStmt.WriteString(c + " = ?") 172 } 173 query = fmt.Sprintf("SELECT %s FROM %s WHERE %s", items, tableName, conStmt.String()) 174 } else { 175 query = fmt.Sprintf("SELECT %s FROM %s", items, tableName) 176 } 177 return query 178 } 179 180 func buildCreateTableStmt(table *referTable) string { 181 var res strings.Builder 182 res.WriteString("CREATE TABLE " + table.TableName + " (") 183 for i, colName := range table.ColumnName { 184 if i != 0 { 185 res.WriteString(", ") 186 } 187 res.WriteString(colName) 188 res.WriteString(" " + ColType2SqliteType(table.ColumnType[i])) 189 } 190 res.WriteString(")") 191 return res.String() 192 } 193 194 func buildInsertIntoStmts(table *referTable) string { 195 var sb strings.Builder 196 sb.WriteString("INSERT INTO " + table.TableName + " (") 197 for i, colName := range table.ColumnName { 198 if i != 0 { 199 sb.WriteString(", ") 200 } 201 sb.WriteString(colName) 202 } 203 sb.WriteString(") VALUES (") 204 for i := range table.ColumnName { 205 if i != 0 { 206 sb.WriteString(", ") 207 } 208 sb.WriteByte('?') 209 } 210 sb.WriteByte(')') 211 return sb.String() 212 } 213 214 func ColType2SqliteType(typeName string) string { 215 switch typeName { 216 case columnTypeStr: 217 return "TEXT" 218 case columnTypeFloat: 219 return "REAL" 220 case columnTypeInt: 221 return "INTEGER" 222 case columnTypeBool: 223 return "NUMERIC" 224 default: 225 return "" 226 } 227 }