github.com/pyroscope-io/pyroscope@v0.37.3-0.20230725203016-5f6947968bd0/pkg/sqlstore/sqlstore.go (about)

     1  package sqlstore
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"errors"
     7  	"fmt"
     8  	"path/filepath"
     9  
    10  	"gorm.io/driver/sqlite"
    11  	"gorm.io/gorm"
    12  	"gorm.io/gorm/logger"
    13  
    14  	"github.com/pyroscope-io/pyroscope/pkg/config"
    15  	"github.com/pyroscope-io/pyroscope/pkg/fs"
    16  	"github.com/pyroscope-io/pyroscope/pkg/sqlstore/migrations"
    17  )
    18  
    19  type SQLStore struct {
    20  	config *config.Server
    21  
    22  	db  *sql.DB
    23  	orm *gorm.DB
    24  }
    25  
    26  func Open(c *config.Server) (*SQLStore, error) {
    27  	s := SQLStore{config: c}
    28  	var err error
    29  	switch s.config.Database.Type {
    30  	case "sqlite3":
    31  		err = s.openSQLiteDB()
    32  	default:
    33  		return nil, errors.New("unknown db type")
    34  	}
    35  	if err != nil {
    36  		return nil, fmt.Errorf("failed to connect database: %w", err)
    37  	}
    38  	if err = s.Ping(context.Background()); err != nil {
    39  		return nil, err
    40  	}
    41  	if err = migrations.Migrate(s.orm, c); err != nil {
    42  		return nil, err
    43  	}
    44  	return &s, nil
    45  }
    46  
    47  func (s *SQLStore) DB() *gorm.DB { return s.orm }
    48  
    49  func (s *SQLStore) Close() error { return s.db.Close() }
    50  
    51  func (s *SQLStore) Ping(ctx context.Context) error {
    52  	return s.db.PingContext(ctx)
    53  }
    54  
    55  func (s *SQLStore) openSQLiteDB() (err error) {
    56  	err = fs.EnsureDirExists(s.config.StoragePath)
    57  	if err != nil {
    58  		return err
    59  	}
    60  	path := filepath.Join(s.config.StoragePath, "pyroscope.sqlite3")
    61  	if s.config.Database.URL != "" {
    62  		path = s.config.Database.URL
    63  	}
    64  	s.orm, err = gorm.Open(sqlite.Open(path), &gorm.Config{
    65  		Logger: logger.Discard,
    66  	})
    67  	s.db, err = s.orm.DB()
    68  	return err
    69  }