go-micro.dev/v5@v5.12.0/store/postgres/postgres.go (about)

     1  // Copyright 2020 Asim Aslam
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  //
    15  // Original source: github.com/micro/go-plugins/v3/store/cockroach/cockroach.go
    16  
    17  // Package postgres implements the postgres store
    18  package postgres
    19  
    20  import (
    21  	"database/sql"
    22  	"database/sql/driver"
    23  	"fmt"
    24  	"net"
    25  	"net/url"
    26  	"regexp"
    27  	"strings"
    28  	"sync"
    29  	"syscall"
    30  	"time"
    31  
    32  	"github.com/lib/pq"
    33  	"github.com/pkg/errors"
    34  	"go-micro.dev/v5/logger"
    35  	"go-micro.dev/v5/store"
    36  )
    37  
    38  // DefaultDatabase is the namespace that the sql store
    39  // will use if no namespace is provided.
    40  var (
    41  	DefaultDatabase = "micro"
    42  	DefaultTable    = "micro"
    43  	ErrNoConnection = errors.New("Database connection not initialised")
    44  )
    45  
    46  var (
    47  	re = regexp.MustCompile("[^a-zA-Z0-9]+")
    48  
    49  	// alternative ordering
    50  	orderAsc  = "ORDER BY key ASC"
    51  	orderDesc = "ORDER BY key DESC"
    52  
    53  	// the sql statements we prepare and use
    54  	statements = map[string]string{
    55  		"list":          "SELECT key, value, metadata, expiry FROM %s.%s WHERE key LIKE $1 ORDER BY key ASC LIMIT $2 OFFSET $3;",
    56  		"read":          "SELECT key, value, metadata, expiry FROM %s.%s WHERE key = $1;",
    57  		"readMany":      "SELECT key, value, metadata, expiry FROM %s.%s WHERE key LIKE $1 ORDER BY key ASC;",
    58  		"readOffset":    "SELECT key, value, metadata, expiry FROM %s.%s WHERE key LIKE $1 ORDER BY key ASC LIMIT $2 OFFSET $3;",
    59  		"write":         "INSERT INTO %s.%s(key, value, metadata, expiry) VALUES ($1, $2::bytea, $3, $4) ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value, metadata = EXCLUDED.metadata, expiry = EXCLUDED.expiry;",
    60  		"delete":        "DELETE FROM %s.%s WHERE key = $1;",
    61  		"deleteExpired": "DELETE FROM %s.%s WHERE expiry < now();",
    62  		"showTables":    "SELECT schemaname, tablename FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema';",
    63  	}
    64  )
    65  
// sqlStore implements store.Store on top of database/sql using the
// "postgres" driver (see configure).
type sqlStore struct {
	// options is the store configuration; Nodes[0] is the connection source
	// and Database/Table are the defaults used by getDB.
	options store.Options
	// dbConn is the current connection pool. It stays nil until configure
	// gets a successful Ping on a new pool, and is swapped each time
	// configure does so again.
	dbConn  *sql.DB

	// RWMutex guards databases; createDB holds the write lock while it
	// checks and updates the map.
	sync.RWMutex
	// known databases: "database:table" keys that createDB has already
	// initialised successfully (values are always true).
	databases map[string]bool
}
    74  
    75  func (s *sqlStore) getDB(database, table string) (string, string) {
    76  	if len(database) == 0 {
    77  		if len(s.options.Database) > 0 {
    78  			database = s.options.Database
    79  		} else {
    80  			database = DefaultDatabase
    81  		}
    82  	}
    83  
    84  	if len(table) == 0 {
    85  		if len(s.options.Table) > 0 {
    86  			table = s.options.Table
    87  		} else {
    88  			table = DefaultTable
    89  		}
    90  	}
    91  
    92  	// store.namespace must only contain letters, numbers and underscores
    93  	database = re.ReplaceAllString(database, "_")
    94  	table = re.ReplaceAllString(table, "_")
    95  
    96  	return database, table
    97  }
    98  
    99  // createDB ensures that the DB and table have been created. It's used for lazy initialisation
   100  // and will record which tables have been created to reduce calls to the DB
   101  func (s *sqlStore) createDB(database, table string) error {
   102  	database, table = s.getDB(database, table)
   103  
   104  	s.Lock()
   105  	defer s.Unlock()
   106  
   107  	if _, ok := s.databases[database+":"+table]; ok {
   108  		return nil
   109  	}
   110  
   111  	if err := s.initDB(database, table); err != nil {
   112  		return err
   113  	}
   114  
   115  	s.databases[database+":"+table] = true
   116  	return nil
   117  }
   118  
   119  // db returns a valid connection to the DB
   120  func (s *sqlStore) db() (*sql.DB, error) {
   121  	if s.dbConn == nil {
   122  		return nil, ErrNoConnection
   123  	}
   124  
   125  	if err := s.dbConn.Ping(); err != nil {
   126  		if !isBadConnError(err) {
   127  			return nil, err
   128  		}
   129  		logger.Errorf("Error with DB connection, will reconfigure: %s", err)
   130  		if err := s.configure(); err != nil {
   131  			logger.Errorf("Error while reconfiguring client: %s", err)
   132  			return nil, err
   133  		}
   134  	}
   135  
   136  	return s.dbConn, nil
   137  }
   138  
   139  // isBadConnError returns true if the error is related to having a bad connection such that you need to reconnect
   140  func isBadConnError(err error) bool {
   141  	if err == nil {
   142  		return false
   143  	}
   144  	if err == driver.ErrBadConn {
   145  		return true
   146  	}
   147  
   148  	// heavy handed crude check for "connection reset by peer"
   149  	if strings.Contains(err.Error(), syscall.ECONNRESET.Error()) {
   150  		return true
   151  	}
   152  
   153  	// otherwise iterate through the error types
   154  	switch t := err.(type) {
   155  	case syscall.Errno:
   156  		return t == syscall.ECONNRESET || t == syscall.ECONNABORTED || t == syscall.ECONNREFUSED
   157  	case *net.OpError:
   158  		return !t.Temporary()
   159  	case net.Error:
   160  		return !t.Temporary()
   161  	}
   162  
   163  	return false
   164  }
   165  
   166  func (s *sqlStore) initDB(database, table string) error {
   167  	db, err := s.db()
   168  	if err != nil {
   169  		return err
   170  	}
   171  	// Create the namespace's database
   172  	_, err = db.Exec(fmt.Sprintf("CREATE DATABASE %s;", database))
   173  	if err != nil && !strings.Contains(err.Error(), "already exists") {
   174  		return err
   175  	}
   176  
   177  	var version string
   178  	if err = db.QueryRow("select version()").Scan(&version); err == nil {
   179  		if strings.Contains(version, "PostgreSQL") {
   180  			_, err = db.Exec(fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s;", database))
   181  			if err != nil {
   182  				return err
   183  			}
   184  		}
   185  	}
   186  
   187  	// Create a table for the namespace's prefix
   188  	_, err = db.Exec(fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s.%s
   189  	(
   190  		key text NOT NULL,
   191  		value bytea,
   192  		metadata JSONB,
   193  		expiry timestamp with time zone,
   194  		CONSTRAINT %s_pkey PRIMARY KEY (key)
   195  	);`, database, table, table))
   196  	if err != nil {
   197  		return errors.Wrap(err, "Couldn't create table")
   198  	}
   199  
   200  	// Create Index
   201  	_, err = db.Exec(fmt.Sprintf(`CREATE INDEX IF NOT EXISTS "%s" ON %s.%s USING btree ("key");`, "key_index_"+table, database, table))
   202  	if err != nil {
   203  		return err
   204  	}
   205  
   206  	// Create Metadata Index
   207  	_, err = db.Exec(fmt.Sprintf(`CREATE INDEX IF NOT EXISTS "%s" ON %s.%s USING GIN ("metadata");`, "metadata_index_"+table, database, table))
   208  	if err != nil {
   209  		return err
   210  	}
   211  
   212  	return nil
   213  }
   214  
   215  func (s *sqlStore) configure() error {
   216  	if len(s.options.Nodes) == 0 {
   217  		s.options.Nodes = []string{"postgresql://root@localhost:26257?sslmode=disable"}
   218  	}
   219  
   220  	source := s.options.Nodes[0]
   221  	// check if it is a standard connection string eg: host=%s port=%d user=%s password=%s dbname=%s sslmode=disable
   222  	// if err is nil which means it would be a URL like postgre://xxxx?yy=zz
   223  	_, err := url.Parse(source)
   224  	if err != nil {
   225  		if !strings.Contains(source, " ") {
   226  			source = fmt.Sprintf("host=%s", source)
   227  		}
   228  	}
   229  
   230  	// create source from first node
   231  	db, err := sql.Open("postgres", source)
   232  	if err != nil {
   233  		return err
   234  	}
   235  
   236  	if err := db.Ping(); err != nil {
   237  		return err
   238  	}
   239  
   240  	if s.dbConn != nil {
   241  		s.dbConn.Close()
   242  	}
   243  
   244  	// save the values
   245  	s.dbConn = db
   246  
   247  	// get DB
   248  	database, table := s.getDB(s.options.Database, s.options.Table)
   249  
   250  	// initialise the database
   251  	return s.initDB(database, table)
   252  }
   253  
   254  func (s *sqlStore) prepare(database, table, query string) (*sql.Stmt, error) {
   255  	st, ok := statements[query]
   256  	if !ok {
   257  		return nil, errors.New("unsupported statement")
   258  	}
   259  
   260  	// get DB
   261  	database, table = s.getDB(database, table)
   262  
   263  	q := fmt.Sprintf(st, database, table)
   264  
   265  	db, err := s.db()
   266  	if err != nil {
   267  		return nil, err
   268  	}
   269  	stmt, err := db.Prepare(q)
   270  	if err != nil {
   271  		return nil, err
   272  	}
   273  	return stmt, nil
   274  }
   275  
   276  func (s *sqlStore) Close() error {
   277  	if s.dbConn != nil {
   278  		return s.dbConn.Close()
   279  	}
   280  	return nil
   281  }
   282  
   283  func (s *sqlStore) Init(opts ...store.Option) error {
   284  	for _, o := range opts {
   285  		o(&s.options)
   286  	}
   287  	// reconfigure
   288  	return s.configure()
   289  }
   290  
   291  // List all the known records
   292  func (s *sqlStore) List(opts ...store.ListOption) ([]string, error) {
   293  	options := store.ListOptions{}
   294  
   295  	for _, o := range opts {
   296  		o(&options)
   297  	}
   298  
   299  	// create the db if not exists
   300  	if err := s.createDB(options.Database, options.Table); err != nil {
   301  		return nil, err
   302  	}
   303  	limit := sql.NullInt32{}
   304  	offset := 0
   305  	pattern := "%"
   306  	if options.Prefix != "" || options.Suffix != "" {
   307  		if options.Prefix != "" {
   308  			pattern = options.Prefix + pattern
   309  		}
   310  		if options.Suffix != "" {
   311  			pattern = pattern + options.Suffix
   312  		}
   313  	}
   314  	if options.Offset > 0 {
   315  		offset = int(options.Offset)
   316  	}
   317  	if options.Limit > 0 {
   318  		limit = sql.NullInt32{Int32: int32(options.Limit), Valid: true}
   319  	}
   320  
   321  	st, err := s.prepare(options.Database, options.Table, "list")
   322  	if err != nil {
   323  		return nil, err
   324  	}
   325  	defer st.Close()
   326  
   327  	rows, err := st.Query(pattern, limit, offset)
   328  	if err != nil {
   329  
   330  		if err == sql.ErrNoRows {
   331  			return nil, nil
   332  		}
   333  		return nil, err
   334  	}
   335  	defer rows.Close()
   336  	var keys []string
   337  	records, err := s.rowsToRecords(rows)
   338  	if err != nil {
   339  		return nil, err
   340  	}
   341  	for _, k := range records {
   342  		keys = append(keys, k.Key)
   343  	}
   344  	rowErr := rows.Close()
   345  	if rowErr != nil {
   346  		// transaction rollback or something
   347  		return keys, rowErr
   348  	}
   349  	if err := rows.Err(); err != nil {
   350  		return keys, err
   351  	}
   352  	return keys, nil
   353  }
   354  
   355  // rowToRecord converts from sql.Row to a store.Record. If the record has expired it will issue a delete in a separate goroutine
   356  func (s *sqlStore) rowToRecord(row *sql.Row) (*store.Record, error) {
   357  	var timehelper pq.NullTime
   358  	record := &store.Record{}
   359  	metadata := make(Metadata)
   360  
   361  	if err := row.Scan(&record.Key, &record.Value, &metadata, &timehelper); err != nil {
   362  		if err == sql.ErrNoRows {
   363  			return record, store.ErrNotFound
   364  		}
   365  		return nil, err
   366  	}
   367  
   368  	// set the metadata
   369  	record.Metadata = toMetadata(&metadata)
   370  	if timehelper.Valid {
   371  		if timehelper.Time.Before(time.Now()) {
   372  			// record has expired
   373  			go s.Delete(record.Key)
   374  			return nil, store.ErrNotFound
   375  		}
   376  		record.Expiry = time.Until(timehelper.Time)
   377  
   378  	}
   379  	return record, nil
   380  }
   381  
   382  // rowsToRecords converts from sql.Rows to  []*store.Record. If a record has expired it will issue a delete in a separate goroutine
   383  func (s *sqlStore) rowsToRecords(rows *sql.Rows) ([]*store.Record, error) {
   384  	var records []*store.Record
   385  	var timehelper pq.NullTime
   386  
   387  	for rows.Next() {
   388  		record := &store.Record{}
   389  		metadata := make(Metadata)
   390  
   391  		if err := rows.Scan(&record.Key, &record.Value, &metadata, &timehelper); err != nil {
   392  			return records, err
   393  		}
   394  
   395  		// set the metadata
   396  		record.Metadata = toMetadata(&metadata)
   397  
   398  		if timehelper.Valid {
   399  			if timehelper.Time.Before(time.Now()) {
   400  				// record has expired
   401  				go s.Delete(record.Key)
   402  			} else {
   403  				record.Expiry = time.Until(timehelper.Time)
   404  				records = append(records, record)
   405  			}
   406  		} else {
   407  			records = append(records, record)
   408  		}
   409  	}
   410  	return records, nil
   411  }
   412  
   413  // Read a single key
   414  func (s *sqlStore) Read(key string, opts ...store.ReadOption) ([]*store.Record, error) {
   415  	options := store.ReadOptions{}
   416  	for _, o := range opts {
   417  		o(&options)
   418  	}
   419  
   420  	// create the db if not exists
   421  	if err := s.createDB(options.Database, options.Table); err != nil {
   422  		return nil, err
   423  	}
   424  
   425  	if options.Prefix || options.Suffix {
   426  		return s.read(key, options)
   427  	}
   428  
   429  	st, err := s.prepare(options.Database, options.Table, "read")
   430  	if err != nil {
   431  		return nil, err
   432  	}
   433  	defer st.Close()
   434  
   435  	row := st.QueryRow(key)
   436  	record, err := s.rowToRecord(row)
   437  	if err != nil {
   438  		return nil, err
   439  	}
   440  	var records []*store.Record
   441  	return append(records, record), nil
   442  }
   443  
   444  // Read Many records
   445  func (s *sqlStore) read(key string, options store.ReadOptions) ([]*store.Record, error) {
   446  	pattern := "%"
   447  	if options.Prefix {
   448  		pattern = key + pattern
   449  	}
   450  	if options.Suffix {
   451  		pattern = pattern + key
   452  	}
   453  
   454  	var rows *sql.Rows
   455  	var st *sql.Stmt
   456  	var err error
   457  
   458  	if options.Limit != 0 {
   459  		st, err = s.prepare(options.Database, options.Table, "readOffset")
   460  		if err != nil {
   461  			return nil, err
   462  		}
   463  		defer st.Close()
   464  
   465  		rows, err = st.Query(pattern, options.Limit, options.Offset)
   466  	} else {
   467  		st, err = s.prepare(options.Database, options.Table, "readMany")
   468  		if err != nil {
   469  			return nil, err
   470  		}
   471  		defer st.Close()
   472  
   473  		rows, err = st.Query(pattern)
   474  	}
   475  	if err != nil {
   476  		if err == sql.ErrNoRows {
   477  			return []*store.Record{}, nil
   478  		}
   479  		return []*store.Record{}, errors.Wrap(err, "sqlStore.read failed")
   480  	}
   481  
   482  	defer rows.Close()
   483  
   484  	records, err := s.rowsToRecords(rows)
   485  	if err != nil {
   486  		return nil, err
   487  	}
   488  	rowErr := rows.Close()
   489  	if rowErr != nil {
   490  		// transaction rollback or something
   491  		return records, rowErr
   492  	}
   493  	if err := rows.Err(); err != nil {
   494  		return records, err
   495  	}
   496  
   497  	return records, nil
   498  }
   499  
   500  // Write records
   501  func (s *sqlStore) Write(r *store.Record, opts ...store.WriteOption) error {
   502  	var options store.WriteOptions
   503  	for _, o := range opts {
   504  		o(&options)
   505  	}
   506  
   507  	// create the db if not exists
   508  	if err := s.createDB(options.Database, options.Table); err != nil {
   509  		return err
   510  	}
   511  
   512  	st, err := s.prepare(options.Database, options.Table, "write")
   513  	if err != nil {
   514  		return err
   515  	}
   516  	defer st.Close()
   517  
   518  	metadata := make(Metadata)
   519  	for k, v := range r.Metadata {
   520  		metadata[k] = v
   521  	}
   522  
   523  	var expiry time.Time
   524  	if r.Expiry != 0 {
   525  		expiry = time.Now().Add(r.Expiry)
   526  	}
   527  
   528  	if expiry.IsZero() {
   529  		_, err = st.Exec(r.Key, r.Value, metadata, nil)
   530  	} else {
   531  		_, err = st.Exec(r.Key, r.Value, metadata, expiry)
   532  	}
   533  
   534  	if err != nil {
   535  		return errors.Wrap(err, "Couldn't insert record "+r.Key)
   536  	}
   537  
   538  	return nil
   539  }
   540  
   541  // Delete records with keys
   542  func (s *sqlStore) Delete(key string, opts ...store.DeleteOption) error {
   543  	var options store.DeleteOptions
   544  	for _, o := range opts {
   545  		o(&options)
   546  	}
   547  
   548  	// create the db if not exists
   549  	if err := s.createDB(options.Database, options.Table); err != nil {
   550  		return err
   551  	}
   552  
   553  	st, err := s.prepare(options.Database, options.Table, "delete")
   554  	if err != nil {
   555  		return err
   556  	}
   557  	defer st.Close()
   558  
   559  	result, err := st.Exec(key)
   560  	if err != nil {
   561  		return err
   562  	}
   563  
   564  	_, err = result.RowsAffected()
   565  	if err != nil {
   566  		return err
   567  	}
   568  
   569  	return nil
   570  }
   571  
   572  func (s *sqlStore) Options() store.Options {
   573  	return s.options
   574  }
   575  
   576  func (s *sqlStore) String() string {
   577  	return "cockroach"
   578  }
   579  
   580  // NewStore returns a new micro Store backed by sql
   581  func NewStore(opts ...store.Option) store.Store {
   582  	options := store.Options{
   583  		Database: DefaultDatabase,
   584  		Table:    DefaultTable,
   585  	}
   586  
   587  	for _, o := range opts {
   588  		o(&options)
   589  	}
   590  
   591  	// new store
   592  	s := new(sqlStore)
   593  	// set the options
   594  	s.options = options
   595  	// mark known databases
   596  	s.databases = make(map[string]bool)
   597  	// best-effort configure the store
   598  	if err := s.configure(); err != nil {
   599  		if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   600  			logger.Error("Error configuring store ", err)
   601  		}
   602  	}
   603  	go s.expiryLoop()
   604  	// return store
   605  	return s
   606  }
   607  
   608  func (s *sqlStore) expiryLoop() {
   609  	for {
   610  		s.expireRows()
   611  		time.Sleep(1 * time.Hour)
   612  	}
   613  }
   614  
   615  func (s *sqlStore) expireRows() error {
   616  	db, err := s.db()
   617  	if err != nil {
   618  		logger.Errorf("Error getting DB connection %s", err)
   619  		return err
   620  	}
   621  	stmt, err := db.Prepare(statements["showTables"])
   622  	if err != nil {
   623  		logger.Errorf("Error prepping show tables query %s", err)
   624  		return err
   625  	}
   626  	defer stmt.Close()
   627  	rows, err := stmt.Query()
   628  	if err != nil {
   629  		logger.Errorf("Error running show tables query %s", err)
   630  		return err
   631  	}
   632  	defer rows.Close()
   633  	for rows.Next() {
   634  		var schemaName, tableName string
   635  		if err := rows.Scan(&schemaName, &tableName); err != nil {
   636  			logger.Errorf("Error parsing result %s", err)
   637  			return err
   638  		}
   639  		db, err = s.db()
   640  		if err != nil {
   641  			logger.Errorf("Error prepping delete expired query %s", err)
   642  			return err
   643  		}
   644  		delStmt, err := db.Prepare(fmt.Sprintf(statements["deleteExpired"], schemaName, tableName))
   645  		if err != nil {
   646  			logger.Errorf("Error prepping delete expired query %s", err)
   647  			return err
   648  		}
   649  		defer delStmt.Close()
   650  		res, err := delStmt.Exec()
   651  		if err != nil {
   652  			logger.Errorf("Error cleaning up %s", err)
   653  			return err
   654  		}
   655  
   656  		r, _ := res.RowsAffected()
   657  		logger.Infof("Cleaning up %s %s: %d rows deleted", schemaName, tableName, r)
   658  
   659  	}
   660  	return nil
   661  }