github.com/octohelm/storage@v0.0.0-20240516030302-1ac2cc1ea347/internal/sql/adapter/sqlite/catalog.go (about)

     1  package sqlite
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"io"
     7  	"strings"
     8  	textscanner "text/scanner"
     9  
    10  	"github.com/octohelm/storage/internal/sql/scanner"
    11  	"github.com/octohelm/storage/pkg/sqlbuilder"
    12  )
    13  
    14  func (a *sqliteAdapter) Catalog(ctx context.Context) (*sqlbuilder.Tables, error) {
    15  	cat := &sqlbuilder.Tables{}
    16  
    17  	tblSqlMaster := sqlbuilder.TableFromModel(&sqliteMaster{})
    18  
    19  	stmt := sqlbuilder.Select(tblSqlMaster.Cols()).From(tblSqlMaster)
    20  
    21  	rows, err := a.Query(ctx, stmt)
    22  	if err != nil {
    23  		return nil, err
    24  	}
    25  
    26  	schemaList := make([]sqliteMaster, 0)
    27  
    28  	if err := scanner.Scan(ctx, rows, &schemaList); err != nil {
    29  		return nil, err
    30  	}
    31  
    32  	for _, schema := range schemaList {
    33  		if schema.Type == "table" {
    34  			table := cat.Table(schema.Table)
    35  			if table == nil {
    36  				table = sqlbuilder.T(schema.Table)
    37  				cat.Add(table)
    38  			}
    39  
    40  			cols := extractCols(bytes.NewBufferString(schema.SQL))
    41  			for f, colSql := range cols {
    42  				if f == "PRIMARY" {
    43  					continue
    44  				}
    45  
    46  				def := sqlbuilder.ColumnDef{}
    47  
    48  				if pkSQL, ok := cols["PRIMARY"]; ok {
    49  					def.AutoIncrement = strings.Contains(pkSQL, f)
    50  				}
    51  
    52  				defaultValue := ""
    53  				parts := strings.Split(colSql, " DEFAULT ")
    54  				if len(parts) > 1 {
    55  					defaultValue = parts[1]
    56  					def.Default = &defaultValue
    57  				}
    58  
    59  				def.Null = !strings.Contains(parts[0], "NOT NULL")
    60  
    61  				if !def.Null {
    62  					def.DataType = strings.TrimSpace(strings.Split(parts[0], "NOT NULL")[0])
    63  				} else {
    64  					def.DataType = strings.TrimSpace(strings.Split(parts[0], "NULL")[0])
    65  				}
    66  
    67  				table.(sqlbuilder.ColumnCollectionManger).AddCol(sqlbuilder.Col(f, sqlbuilder.ColDef(def)))
    68  			}
    69  
    70  		}
    71  	}
    72  
    73  	for _, schema := range schemaList {
    74  		if schema.Type == "index" && schema.SQL != "" {
    75  			table := cat.Table(schema.Table)
    76  
    77  			indexName := strings.ToLower(schema.Name[len(table.TableName())+1:])
    78  			isUnique := strings.Contains(schema.SQL, "UNIQUE")
    79  			indexColNameAndOptions := strings.Split(
    80  				strings.TrimSpace(schema.SQL[strings.Index(schema.SQL, "(")+1:strings.Index(schema.SQL, ")")]),
    81  				",",
    82  			)
    83  
    84  			var key sqlbuilder.Key
    85  
    86  			if isUnique {
    87  				key = sqlbuilder.UniqueIndex(indexName, nil, sqlbuilder.IndexColNameAndOptions(indexColNameAndOptions...))
    88  			} else {
    89  				key = sqlbuilder.Index(indexName, nil, sqlbuilder.IndexColNameAndOptions(indexColNameAndOptions...))
    90  			}
    91  
    92  			table.(sqlbuilder.KeyCollectionManager).AddKey(key)
    93  		}
    94  	}
    95  
    96  	return cat, nil
    97  }
    98  
// sqliteMaster is the row model used to read the sqlite_master table.
// Each field is bound to the column named in its `db` tag. Catalog only
// interprets rows whose Type is "table" or "index", and ignores index rows
// whose SQL is empty.
type sqliteMaster struct {
	Type  string `db:"type"` // object kind; Catalog handles "table" and "index"
	Name  string `db:"name"`
	SQL   string `db:"sql"`
	Table string `db:"tbl_name"` // name of the table this object belongs to
}
   105  
// TableName returns the fixed table name "sqlite_master".
// NOTE(review): presumably this is what sqlbuilder.TableFromModel uses to
// name the table built from this model — confirm against sqlbuilder.
func (sqliteMaster) TableName() string {
	return "sqlite_master"
}
   109  
   110  func extractCols(r io.Reader) map[string]string {
   111  	s := &textscanner.Scanner{}
   112  	s.Init(r)
   113  	s.Error = func(s *textscanner.Scanner, msg string) {}
   114  
   115  	scope := 0
   116  	cols := make(map[string]string)
   117  	parts := make([]string, 0)
   118  
   119  	collect := func() {
   120  		if len(parts) == 0 || scope != 1 {
   121  			return
   122  		}
   123  		cols[parts[0]] = strings.Join(parts[1:], " ")
   124  		parts = make([]string, 0)
   125  	}
   126  
   127  	for tok := s.Scan(); tok != textscanner.EOF; tok = s.Scan() {
   128  		part := s.TokenText()
   129  
   130  		switch part {
   131  		case "(":
   132  			scope++
   133  			if scope == 1 {
   134  				continue
   135  			}
   136  		case ")":
   137  			collect()
   138  			scope--
   139  		case ",":
   140  			collect()
   141  			continue
   142  		}
   143  
   144  		if scope > 0 {
   145  			parts = append(parts, part)
   146  		}
   147  	}
   148  
   149  	return cols
   150  }