go-micro.dev/v5@v5.12.0/store/postgres/pgx/pgx.go (about)

     1  // Licensed under the Apache License, Version 2.0 (the "License");
     2  // you may not use this file except in compliance with the License.
     3  // You may obtain a copy of the License at
     4  //
     5  //     https://www.apache.org/licenses/LICENSE-2.0
     6  //
     7  // Unless required by applicable law or agreed to in writing, software
     8  // distributed under the License is distributed on an "AS IS" BASIS,
     9  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    10  // See the License for the specific language governing permissions and
    11  // limitations under the License.
    12  
    13  // Package pgx implements the postgres store with pgx driver
    14  package pgx
    15  
    16  import (
    17  	"database/sql"
    18  	"fmt"
    19  	"net/url"
    20  	"regexp"
    21  	"strings"
    22  	"sync"
    23  	"time"
    24  
    25  	"github.com/jackc/pgx/v4"
    26  	"github.com/jackc/pgx/v4/pgxpool"
    27  	"github.com/pkg/errors"
    28  
    29  	"go-micro.dev/v5/logger"
    30  	"go-micro.dev/v5/store"
    31  )
    32  
    33  const defaultDatabase = "micro"
    34  const defaultTable = "micro"
    35  
// sqlStore is the store implementation defined in this file. It keeps one DB
// entry (connection pool plus per-table Queries) for every database name it
// has initialised.
type sqlStore struct {
	// options holds the settings applied by NewStore and Init
	options store.Options
	// re matches the character runs that getDB replaces with "_" in
	// database and table names
	re      *regexp.Regexp
	// embedded mutex; db() holds it while it reads or populates databases
	sync.Mutex
	// known databases
	databases map[string]DB
}
    43  
    44  func (s *sqlStore) getDB(database, table string) (string, string) {
    45  	if len(database) == 0 {
    46  		if len(s.options.Database) > 0 {
    47  			database = s.options.Database
    48  		} else {
    49  			database = defaultDatabase
    50  		}
    51  	}
    52  
    53  	if len(table) == 0 {
    54  		if len(s.options.Table) > 0 {
    55  			table = s.options.Table
    56  		} else {
    57  			table = defaultTable
    58  		}
    59  	}
    60  
    61  	// store.namespace must only contain letters, numbers and underscores
    62  	database = s.re.ReplaceAllString(database, "_")
    63  	table = s.re.ReplaceAllString(table, "_")
    64  
    65  	return database, table
    66  }
    67  
    68  func (s *sqlStore) db(database, table string) (*pgxpool.Pool, Queries, error) {
    69  	s.Lock()
    70  	defer s.Unlock()
    71  
    72  	database, table = s.getDB(database, table)
    73  
    74  	if _, ok := s.databases[database]; !ok {
    75  		err := s.initDB(database)
    76  		if err != nil {
    77  			return nil, Queries{}, err
    78  		}
    79  	}
    80  	dbObj := s.databases[database]
    81  	if _, ok := dbObj.tables[table]; !ok {
    82  		err := s.initTable(database, table)
    83  		if err != nil {
    84  			return nil, Queries{}, err
    85  		}
    86  	}
    87  
    88  	return dbObj.conn, dbObj.tables[table], nil
    89  }
    90  
    91  func (s *sqlStore) initTable(database, table string) error {
    92  	db := s.databases[database].conn
    93  
    94  	_, err := db.Exec(s.options.Context, fmt.Sprintf(createTable, database, table))
    95  	if err != nil {
    96  		return errors.Wrap(err, "cannot create table")
    97  	}
    98  
    99  	_, err = db.Exec(s.options.Context, fmt.Sprintf(createMDIndex, table, database, table))
   100  	if err != nil {
   101  		return errors.Wrap(err, "cannot create metadata index")
   102  	}
   103  
   104  	_, err = db.Exec(s.options.Context, fmt.Sprintf(createExpiryIndex, table, database, table))
   105  	if err != nil {
   106  		return errors.Wrap(err, "cannot create expiry index")
   107  	}
   108  
   109  	s.databases[database].tables[table] = NewQueries(database, table)
   110  
   111  	return nil
   112  }
   113  
   114  func (s *sqlStore) initDB(database string) error {
   115  	if len(s.options.Nodes) == 0 {
   116  		s.options.Nodes = []string{"postgresql://root@localhost:26257?sslmode=disable"}
   117  	}
   118  
   119  	source := s.options.Nodes[0]
   120  	// check if it is a standard connection string eg: host=%s port=%d user=%s password=%s dbname=%s sslmode=disable
   121  	// if err is nil which means it would be a URL like postgre://xxxx?yy=zz
   122  	_, err := url.Parse(source)
   123  	if err != nil {
   124  		if !strings.Contains(source, " ") {
   125  			source = fmt.Sprintf("host=%s", source)
   126  		}
   127  	}
   128  
   129  	config, err := pgxpool.ParseConfig(source)
   130  	if err != nil {
   131  		return err
   132  	}
   133  
   134  	db, err := pgxpool.ConnectConfig(s.options.Context, config)
   135  	if err != nil {
   136  		return err
   137  	}
   138  
   139  	if err = db.Ping(s.options.Context); err != nil {
   140  		return err
   141  	}
   142  
   143  	_, err = db.Exec(s.options.Context, fmt.Sprintf(createSchema, database))
   144  	if err != nil {
   145  		return err
   146  	}
   147  
   148  	if len(database) == 0 {
   149  		if len(s.options.Database) > 0 {
   150  			database = s.options.Database
   151  		} else {
   152  			database = defaultDatabase
   153  		}
   154  	}
   155  
   156  	// save the values
   157  	s.databases[database] = DB{
   158  		conn:   db,
   159  		tables: make(map[string]Queries),
   160  	}
   161  
   162  	return nil
   163  }
   164  
   165  func (s *sqlStore) Close() error {
   166  	for _, obj := range s.databases {
   167  		obj.conn.Close()
   168  	}
   169  	return nil
   170  }
   171  
   172  func (s *sqlStore) Init(opts ...store.Option) error {
   173  	for _, o := range opts {
   174  		o(&s.options)
   175  	}
   176  	_, _, err := s.db(s.options.Database, s.options.Table)
   177  	return err
   178  }
   179  
   180  // List all the known records
   181  func (s *sqlStore) List(opts ...store.ListOption) ([]string, error) {
   182  	options := store.ListOptions{}
   183  
   184  	for _, o := range opts {
   185  		o(&options)
   186  	}
   187  	db, queries, err := s.db(options.Database, options.Table)
   188  	if err != nil {
   189  		return nil, err
   190  	}
   191  	pattern := "%"
   192  	if options.Prefix != "" {
   193  		pattern = options.Prefix + pattern
   194  	}
   195  	if options.Suffix != "" {
   196  		pattern = pattern + options.Suffix
   197  	}
   198  
   199  	var rows pgx.Rows
   200  	if options.Limit > 0 {
   201  		rows, err = db.Query(s.options.Context, queries.ListAscLimit, pattern, options.Limit, options.Offset)
   202  
   203  	} else {
   204  
   205  		rows, err = db.Query(s.options.Context, queries.ListAsc, pattern)
   206  
   207  	}
   208  	if err != nil {
   209  		if err == pgx.ErrNoRows {
   210  			return nil, nil
   211  		}
   212  		return nil, err
   213  	}
   214  	defer rows.Close()
   215  
   216  	keys := make([]string, 0, 10)
   217  	for rows.Next() {
   218  		var key string
   219  		err = rows.Scan(&key)
   220  		if err != nil {
   221  			return nil, err
   222  		}
   223  		keys = append(keys, key)
   224  	}
   225  
   226  	return keys, nil
   227  }
   228  
   229  // rowToRecord converts from pgx.Row to a store.Record
   230  func (s *sqlStore) rowToRecord(row pgx.Row) (*store.Record, error) {
   231  	var expiry *time.Time
   232  	record := &store.Record{}
   233  	metadata := make(Metadata)
   234  
   235  	if err := row.Scan(&record.Key, &record.Value, &metadata, &expiry); err != nil {
   236  		if err == sql.ErrNoRows {
   237  			return record, store.ErrNotFound
   238  		}
   239  		return nil, err
   240  	}
   241  
   242  	// set the metadata
   243  	record.Metadata = toMetadata(&metadata)
   244  	if expiry != nil {
   245  		record.Expiry = time.Until(*expiry)
   246  	}
   247  
   248  	return record, nil
   249  }
   250  
   251  // rowsToRecords converts from pgx.Rows to []*store.Record
   252  func (s *sqlStore) rowsToRecords(rows pgx.Rows) ([]*store.Record, error) {
   253  	var records []*store.Record
   254  
   255  	for rows.Next() {
   256  		var expiry *time.Time
   257  		record := &store.Record{}
   258  		metadata := make(Metadata)
   259  
   260  		if err := rows.Scan(&record.Key, &record.Value, &metadata, &expiry); err != nil {
   261  			return records, err
   262  		}
   263  
   264  		// set the metadata
   265  		record.Metadata = toMetadata(&metadata)
   266  		if expiry != nil {
   267  			record.Expiry = time.Until(*expiry)
   268  		}
   269  		records = append(records, record)
   270  	}
   271  	return records, nil
   272  }
   273  
   274  // Read a single key
   275  func (s *sqlStore) Read(key string, opts ...store.ReadOption) ([]*store.Record, error) {
   276  	options := store.ReadOptions{}
   277  	for _, o := range opts {
   278  		o(&options)
   279  	}
   280  
   281  	db, queries, err := s.db(options.Database, options.Table)
   282  	if err != nil {
   283  		return nil, err
   284  	}
   285  
   286  	// read one record
   287  	if !options.Prefix && !options.Suffix {
   288  		row := db.QueryRow(s.options.Context, queries.ReadOne, key)
   289  		record, err := s.rowToRecord(row)
   290  		if err != nil {
   291  			return nil, err
   292  		}
   293  		return []*store.Record{record}, nil
   294  	}
   295  
   296  	// read by pattern
   297  	pattern := "%"
   298  	if options.Prefix {
   299  		pattern = key + pattern
   300  	}
   301  	if options.Suffix {
   302  		pattern = pattern + key
   303  	}
   304  
   305  	var rows pgx.Rows
   306  	if options.Limit > 0 {
   307  
   308  		rows, err = db.Query(s.options.Context, queries.ListAscLimit, pattern, options.Limit, options.Offset)
   309  
   310  	} else {
   311  
   312  		rows, err = db.Query(s.options.Context, queries.ListAsc, pattern)
   313  
   314  	}
   315  	if err != nil {
   316  		if err == pgx.ErrNoRows {
   317  			return nil, nil
   318  		}
   319  		return nil, err
   320  	}
   321  	defer rows.Close()
   322  
   323  	return s.rowsToRecords(rows)
   324  }
   325  
   326  // Write records
   327  func (s *sqlStore) Write(r *store.Record, opts ...store.WriteOption) error {
   328  	var options store.WriteOptions
   329  	for _, o := range opts {
   330  		o(&options)
   331  	}
   332  
   333  	db, queries, err := s.db(options.Database, options.Table)
   334  	if err != nil {
   335  		return err
   336  	}
   337  
   338  	metadata := make(Metadata)
   339  	for k, v := range r.Metadata {
   340  		metadata[k] = v
   341  	}
   342  
   343  	if r.Expiry != 0 {
   344  		_, err = db.Exec(s.options.Context, queries.Write, r.Key, r.Value, metadata, time.Now().Add(r.Expiry))
   345  	} else {
   346  		_, err = db.Exec(s.options.Context, queries.Write, r.Key, r.Value, metadata, nil)
   347  	}
   348  	if err != nil {
   349  		return errors.Wrap(err, "cannot upsert record "+r.Key)
   350  	}
   351  
   352  	return nil
   353  }
   354  
   355  // Delete records with keys
   356  func (s *sqlStore) Delete(key string, opts ...store.DeleteOption) error {
   357  	var options store.DeleteOptions
   358  	for _, o := range opts {
   359  		o(&options)
   360  	}
   361  
   362  	db, queries, err := s.db(options.Database, options.Table)
   363  	if err != nil {
   364  		return err
   365  	}
   366  
   367  	_, err = db.Exec(s.options.Context, queries.Delete, key)
   368  	return err
   369  }
   370  
// Options returns the store's current options by value.
func (s *sqlStore) Options() store.Options {
	return s.options
}
   374  
// String returns the name of this store implementation, "pgx".
func (s *sqlStore) String() string {
	return "pgx"
}
   378  
   379  // NewStore returns a new micro Store backed by sql
   380  func NewStore(opts ...store.Option) store.Store {
   381  	options := store.Options{
   382  		Database: defaultDatabase,
   383  		Table:    defaultTable,
   384  	}
   385  
   386  	for _, o := range opts {
   387  		o(&options)
   388  	}
   389  
   390  	// new store
   391  	s := new(sqlStore)
   392  	s.options = options
   393  	s.databases = make(map[string]DB)
   394  	s.re = regexp.MustCompile("[^a-zA-Z0-9]+")
   395  
   396  	go s.expiryLoop()
   397  	// return store
   398  	return s
   399  }
   400  
   401  func (s *sqlStore) expiryLoop() {
   402  	for {
   403  		err := s.expireRows()
   404  		if err != nil {
   405  			logger.Errorf("error cleaning up %s", err)
   406  		}
   407  		time.Sleep(1 * time.Hour)
   408  	}
   409  }
   410  
   411  func (s *sqlStore) expireRows() error {
   412  	for database, dbObj := range s.databases {
   413  		db := dbObj.conn
   414  		for table, queries := range dbObj.tables {
   415  			res, err := db.Exec(s.options.Context, queries.DeleteExpired)
   416  			if err != nil {
   417  				logger.Errorf("Error cleaning up %s", err)
   418  				return err
   419  			}
   420  			logger.Infof("Cleaning up %s %s: %d rows deleted", database, table, res.RowsAffected())
   421  		}
   422  	}
   423  
   424  	return nil
   425  }