github.com/tuhaihe/gpbackup@v1.0.3/options/options.go (about)

     1  package options
     2  
     3  import (
     4  	"fmt"
     5  	"strings"
     6  
     7  	"github.com/tuhaihe/gp-common-go-libs/dbconn"
     8  	"github.com/tuhaihe/gp-common-go-libs/gplog"
     9  	"github.com/tuhaihe/gp-common-go-libs/iohelper"
    10  	"github.com/tuhaihe/gpbackup/utils"
    11  	"github.com/pkg/errors"
    12  	"github.com/spf13/pflag"
    13  )
    14  
// Options holds the table and schema filters (and related settings) derived
// from the command-line flags. It is meant to be treated as read only: values
// inside should only be modified by its setter/mutator methods or by the
// initialization function NewOptions. Keeping these values behind this type is
// meant to make mocking flags easier.
type Options struct {
	// IncludedRelations is the current include list of "schema.table" names.
	// AddIncludedRelation (used by ExpandIncludesForPartitions) appends to it and
	// QuoteIncludeRelations replaces it with quote_ident-ed names.
	IncludedRelations []string

	// ExcludedRelations is the exclude list of "schema.table" names;
	// QuoteExcludeRelations replaces it with quote_ident-ed names.
	ExcludedRelations []string

	// isLeafPartitionData is the value of the LEAF_PARTITION_DATA flag.
	isLeafPartitionData bool

	// ExcludedSchemas and IncludedSchemas are the schema filter lists used by
	// schemaFilterClause.
	ExcludedSchemas []string
	IncludedSchemas []string

	// originalIncludedRelations is the include list as NewOptions built it,
	// before any later expansion or quoting of IncludedRelations.
	originalIncludedRelations []string

	// RedirectSchema is the REDIRECT_SCHEMA flag value, or "" when that flag is
	// not registered on the FlagSet passed to NewOptions.
	RedirectSchema string
}
    27  
    28  func NewOptions(initialFlags *pflag.FlagSet) (*Options, error) {
    29  	includedRelations, err := setFiltersFromFile(initialFlags, INCLUDE_RELATION, INCLUDE_RELATION_FILE)
    30  	if err != nil {
    31  		return nil, err
    32  	}
    33  	err = utils.ValidateFQNs(includedRelations)
    34  	if err != nil {
    35  		return nil, err
    36  	}
    37  
    38  	excludedRelations, err := setFiltersFromFile(initialFlags, EXCLUDE_RELATION, EXCLUDE_RELATION_FILE)
    39  	if err != nil {
    40  		return nil, err
    41  	}
    42  	err = utils.ValidateFQNs(excludedRelations)
    43  	if err != nil {
    44  		return nil, err
    45  	}
    46  
    47  	includedSchemas, err := setFiltersFromFile(initialFlags, INCLUDE_SCHEMA, INCLUDE_SCHEMA_FILE)
    48  	if err != nil {
    49  		return nil, err
    50  	}
    51  
    52  	excludedSchemas, err := setFiltersFromFile(initialFlags, EXCLUDE_SCHEMA, EXCLUDE_SCHEMA_FILE)
    53  	if err != nil {
    54  		return nil, err
    55  	}
    56  
    57  	leafPartitionData, err := initialFlags.GetBool(LEAF_PARTITION_DATA)
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  
    62  	redirectSchema := ""
    63  	if initialFlags.Lookup(REDIRECT_SCHEMA) != nil {
    64  		redirectSchema, err = initialFlags.GetString(REDIRECT_SCHEMA)
    65  		if err != nil {
    66  			return nil, err
    67  		}
    68  	}
    69  
    70  	return &Options{
    71  		IncludedRelations:         includedRelations,
    72  		ExcludedRelations:         excludedRelations,
    73  		IncludedSchemas:           includedSchemas,
    74  		ExcludedSchemas:           excludedSchemas,
    75  		isLeafPartitionData:       leafPartitionData,
    76  		originalIncludedRelations: includedRelations,
    77  		RedirectSchema:            redirectSchema,
    78  	}, nil
    79  }
    80  
    81  func setFiltersFromFile(initialFlags *pflag.FlagSet, filterFlag string, filterFileFlag string) ([]string, error) {
    82  	filters, err := initialFlags.GetStringArray(filterFlag)
    83  	if err != nil {
    84  		return nil, err
    85  	}
    86  	// values obtained from file filterFileFlag are copied to values in filterFlag
    87  	// values are mutually exclusive so this is not an overwrite, it is a "fresh" setting
    88  	filename, err := initialFlags.GetString(filterFileFlag)
    89  	if err != nil {
    90  		return nil, err
    91  	}
    92  	if filename != "" {
    93  		filterLines, err := iohelper.ReadLinesFromFile(filename)
    94  		if err != nil {
    95  			return nil, err
    96  		}
    97  		// copy any values for flag filterFileFlag into global flag for filterFlag
    98  		for _, fqn := range filterLines {
    99  			if fqn != "" {
   100  				filters = append(filters, fqn)          //This appends filter to options
   101  				err = initialFlags.Set(filterFlag, fqn) //This appends to the slice underlying the flag.
   102  				if err != nil {
   103  					return nil, err
   104  				}
   105  			}
   106  		}
   107  		if err != nil {
   108  			return nil, err
   109  		}
   110  	}
   111  	return filters, nil
   112  }
   113  
   114  func (o Options) GetIncludedTables() []string {
   115  	return o.IncludedRelations
   116  }
   117  
   118  func (o Options) GetOriginalIncludedTables() []string {
   119  	return o.originalIncludedRelations
   120  }
   121  
   122  func (o Options) GetExcludedTables() []string {
   123  	return o.ExcludedRelations
   124  }
   125  
   126  func (o Options) IsLeafPartitionData() bool {
   127  	return o.isLeafPartitionData
   128  }
   129  
   130  func (o Options) GetIncludedSchemas() []string {
   131  	return o.IncludedSchemas
   132  }
   133  
   134  func (o Options) GetExcludedSchemas() []string {
   135  	return o.ExcludedSchemas
   136  }
   137  
   138  func (o *Options) AddIncludedRelation(relation string) {
   139  	o.IncludedRelations = append(o.IncludedRelations, relation)
   140  }
   141  
// FqnStruct is a fully qualified relation name split into its schema and table
// parts. It is also used as the row type for queries in this package whose
// result columns are aliased schemaname and tablename (presumably mapped onto
// these fields by dbconn's Select through case-insensitive name matching —
// confirm before renaming the fields).
type FqnStruct struct {
	SchemaName string
	TableName  string
}
   146  
   147  func QuoteTableNames(conn *dbconn.DBConn, tableNames []string) ([]string, error) {
   148  	if len(tableNames) == 0 {
   149  		return []string{}, nil
   150  	}
   151  
   152  	// Properly escape single quote before running quote ident. Postgres
   153  	// quote_ident escapes single quotes by doubling them
   154  	escapedTables := make([]string, 0)
   155  	for _, v := range tableNames {
   156  		escapedTables = append(escapedTables, utils.EscapeSingleQuotes(v))
   157  	}
   158  
   159  	fqnSlice, err := SeparateSchemaAndTable(escapedTables)
   160  	if err != nil {
   161  		return nil, err
   162  	}
   163  	result := make([]string, 0)
   164  
   165  	quoteIdentTableFQNQuery := `SELECT quote_ident('%s') AS schemaname, quote_ident('%s') AS tablename`
   166  	for _, fqn := range fqnSlice {
   167  		queryResultTable := make([]FqnStruct, 0)
   168  		query := fmt.Sprintf(quoteIdentTableFQNQuery, fqn.SchemaName, fqn.TableName)
   169  		err := conn.Select(&queryResultTable, query)
   170  		if err != nil {
   171  			return nil, err
   172  		}
   173  		quoted := queryResultTable[0].SchemaName + "." + queryResultTable[0].TableName
   174  		result = append(result, quoted)
   175  	}
   176  
   177  	return result, nil
   178  }
   179  
   180  func SeparateSchemaAndTable(tableNames []string) ([]FqnStruct, error) {
   181  	fqnSlice := make([]FqnStruct, 0)
   182  	for _, fqn := range tableNames {
   183  		parts := strings.Split(fqn, ".")
   184  		if len(parts) > 2 {
   185  			return nil, errors.Errorf("cannot process an Fully Qualified Name with embedded dots yet: %s", fqn)
   186  		}
   187  		if len(parts) < 2 {
   188  			return nil, errors.Errorf("Fully Qualified Names require a minimum of one dot, specifying the schema and table. Cannot process: %s", fqn)
   189  		}
   190  		schema := parts[0]
   191  		table := parts[1]
   192  		if schema == "" || table == "" {
   193  			return nil, errors.Errorf("Fully Qualified Names must specify the schema and table. Cannot process: %s", fqn)
   194  		}
   195  
   196  		currFqn := FqnStruct{
   197  			SchemaName: schema,
   198  			TableName:  table,
   199  		}
   200  
   201  		fqnSlice = append(fqnSlice, currFqn)
   202  	}
   203  
   204  	return fqnSlice, nil
   205  }
   206  
   207  func (o *Options) ExpandIncludesForPartitions(conn *dbconn.DBConn, flags *pflag.FlagSet) error {
   208  	if len(o.GetIncludedTables()) == 0 {
   209  		return nil
   210  	}
   211  
   212  	quotedIncludeRelations, err := QuoteTableNames(conn, o.GetIncludedTables())
   213  	if err != nil {
   214  		return err
   215  	}
   216  
   217  	allFqnStructs, err := o.getUserTableRelationsWithIncludeFiltering(conn, quotedIncludeRelations)
   218  	if err != nil {
   219  		return err
   220  	}
   221  
   222  	includeSet := map[string]bool{}
   223  	for _, include := range o.GetIncludedTables() {
   224  		includeSet[include] = true
   225  	}
   226  
   227  	allFqnSet := map[string]bool{}
   228  	for _, fqnStruct := range allFqnStructs {
   229  		fqn := fmt.Sprintf("%s.%s", fqnStruct.SchemaName, fqnStruct.TableName)
   230  		allFqnSet[fqn] = true
   231  	}
   232  
   233  	// set arithmetic: find difference
   234  	diff := make([]string, 0)
   235  	for key := range allFqnSet {
   236  		_, keyExists := includeSet[key]
   237  		if !keyExists {
   238  			diff = append(diff, key)
   239  		}
   240  	}
   241  
   242  	for _, fqn := range diff {
   243  		err = flags.Set(INCLUDE_RELATION, fqn)
   244  		if err != nil {
   245  			return err
   246  		}
   247  		o.AddIncludedRelation(fqn)
   248  	}
   249  
   250  	return nil
   251  }
   252  
   253  func (o *Options) QuoteIncludeRelations(conn *dbconn.DBConn) error {
   254  	var err error
   255  	o.IncludedRelations, err = QuoteTableNames(conn, o.GetIncludedTables())
   256  	if err != nil {
   257  		return err
   258  	}
   259  
   260  	return nil
   261  }
   262  
   263  func (o *Options) QuoteExcludeRelations(conn *dbconn.DBConn) error {
   264  	var err error
   265  	o.ExcludedRelations, err = QuoteTableNames(conn, o.GetExcludedTables())
   266  	if err != nil {
   267  		return err
   268  	}
   269  
   270  	return nil
   271  }
   272  
   273  // given a set of table oids, return a deduplicated set of other tables that EITHER depend
   274  // on them, OR that they depend on. The behavior for which is set with recurseDirection.
   275  func (o *Options) recurseTableDepend(conn *dbconn.DBConn, includeOids []string, recurseSource string) ([]string, error) {
   276  	var err error
   277  	var dependQuery string
   278  
   279  	expandedIncludeOids := make(map[string]bool)
   280  	for _, oid := range includeOids {
   281  		expandedIncludeOids[oid] = true
   282  	}
   283  
   284  	if recurseSource == "child" {
   285  		dependQuery = `
   286  			SELECT dep.refobjid
   287  			FROM
   288  				pg_depend dep
   289  				INNER JOIN pg_class cls ON dep.refobjid = cls.oid
   290  			WHERE
   291  				dep.objid in (%s)
   292  				AND cls.relkind in ('r', 'p', 'f')`
   293  	} else if recurseSource == "parent" {
   294  		dependQuery = `
   295  			SELECT dep.objid
   296  			FROM
   297  				pg_depend dep
   298  				INNER JOIN pg_class cls ON dep.objid = cls.oid
   299  			WHERE
   300  				dep.refobjid in (%s)
   301  				AND cls.relkind in ('r', 'p', 'f')`
   302  	} else {
   303  		gplog.Error("Please fix calling of this function recurseTableDepend. Argument recurseSource only accepts 'parent' or 'child'.")
   304  	}
   305  
   306  	// here we loop until no further table dependencies are found.  implemented iteratively, but functions like a recursion
   307  	foundDeps := true
   308  	loopOids := includeOids
   309  	for foundDeps {
   310  		foundDeps = false
   311  		depOids := make([]string, 0)
   312  		loopDepQuery := fmt.Sprintf(dependQuery, strings.Join(loopOids, ", "))
   313  		err = conn.Select(&depOids, loopDepQuery)
   314  		if err != nil {
   315  			gplog.Warn("Table dependency query failed: %s", loopDepQuery)
   316  			return nil, err
   317  		}
   318  
   319  		// confirm that any table dependencies are found
   320  		// save the table dependencies for both output and for next recursion
   321  		loopOids = loopOids[:]
   322  		for _, depOid := range depOids {
   323  			// must exclude oids already captured to avoid circular dependencies
   324  			// causing an infinite loop
   325  			if !expandedIncludeOids[depOid] {
   326  				foundDeps = true
   327  				loopOids = append(loopOids, depOid)
   328  				expandedIncludeOids[depOid] = true
   329  			}
   330  		}
   331  	}
   332  
   333  	// capture deduplicated oids from map keys, return as array
   334  	// done as a direct array assignment loop because it's faster and we know the length
   335  	expandedIncludeOidsArr := make([]string, len(expandedIncludeOids))
   336  	arrayIdx := 0
   337  	for idx := range expandedIncludeOids {
   338  		expandedIncludeOidsArr[arrayIdx] = idx
   339  		arrayIdx++
   340  	}
   341  	return expandedIncludeOidsArr, err
   342  }
   343  
   344  func (o Options) getUserTableRelationsWithIncludeFiltering(connectionPool *dbconn.DBConn, includedRelationsQuoted []string) ([]FqnStruct, error) {
   345  	includeOids, err := getOidsFromRelationList(connectionPool, includedRelationsQuoted)
   346  	if err != nil {
   347  		return nil, err
   348  	}
   349  
   350  	oidStr := strings.Join(includeOids, ", ")
   351  	childPartitionFilter := ""
   352  	parentAndExternalPartitionFilter := ""
   353  	// GPDB7+ reworks the nature of partition tables.  It is no longer sufficient
   354  	// to pull parents and children in one step.  Instead we must recursively climb/descend
   355  	// the pg_depend ladder, filtering to only members of pg_class at each step, until the
   356  	// full hierarchy has been retrieved
   357  	childOids, err := o.recurseTableDepend(connectionPool, includeOids, "parent")
   358  	if err != nil {
   359  		return nil, err
   360  	}
   361  	if len(childOids) > 0 {
   362  		childPartitionFilter = fmt.Sprintf(`OR c.oid IN (%s)`, strings.Join(childOids, ", "))
   363  	}
   364  
   365  	parentOids, err := o.recurseTableDepend(connectionPool, includeOids, "child")
   366  	if err != nil {
   367  		return nil, err
   368  	}
   369  	if len(parentOids) > 0 {
   370  		parentAndExternalPartitionFilter = fmt.Sprintf(`OR c.oid IN (%s)`, strings.Join(parentOids, ", "))
   371  	}
   372  
   373  	query := fmt.Sprintf(`
   374  SELECT
   375  	n.nspname AS schemaname,
   376  	c.relname AS tablename
   377  FROM pg_class c
   378  JOIN pg_namespace n
   379  	ON c.relnamespace = n.oid
   380  WHERE %s
   381  AND (
   382  	-- Get tables in the include list
   383  	c.oid IN (%s)
   384  	%s
   385  	%s
   386  )
   387  AND relkind IN ('r', 'f', 'p')
   388  AND %s
   389  ORDER BY c.oid;`, o.schemaFilterClause("n"), oidStr, parentAndExternalPartitionFilter, childPartitionFilter, ExtensionFilterClause("c"))
   390  
   391  	results := make([]FqnStruct, 0)
   392  	err = connectionPool.Select(&results, query)
   393  
   394  	return results, err
   395  }
   396  
   397  func getOidsFromRelationList(connectionPool *dbconn.DBConn, quotedRelationNames []string) ([]string, error) {
   398  	relList := utils.SliceToQuotedString(quotedRelationNames)
   399  	query := fmt.Sprintf(`
   400  SELECT
   401  	c.oid AS string
   402  FROM pg_class c
   403  JOIN pg_namespace n ON c.relnamespace = n.oid
   404  WHERE quote_ident(n.nspname) || '.' || quote_ident(c.relname) IN (%s)`, relList)
   405  
   406  	return dbconn.SelectStringSlice(connectionPool, query)
   407  }
   408  
   409  // A list of schemas we don't want to back up, formatted for use in a WHERE clause
   410  func (o Options) schemaFilterClause(namespace string) string {
   411  	schemaFilterClauseStr := ""
   412  	if len(o.GetIncludedSchemas()) > 0 {
   413  		schemaFilterClauseStr = fmt.Sprintf("\nAND %s.nspname IN (%s)", namespace, utils.SliceToQuotedString(o.GetIncludedSchemas()))
   414  	}
   415  	if len(o.GetExcludedSchemas()) > 0 {
   416  		schemaFilterClauseStr = fmt.Sprintf("\nAND %s.nspname NOT IN (%s)", namespace, utils.SliceToQuotedString(o.GetExcludedSchemas()))
   417  	}
   418  	return fmt.Sprintf(`%s.nspname NOT LIKE 'pg_temp_%%' AND %s.nspname NOT LIKE 'pg_toast%%' AND %s.nspname NOT IN ('gp_toolkit', 'information_schema', 'pg_aoseg', 'pg_bitmapindex', 'pg_catalog') %s`, namespace, namespace, namespace, schemaFilterClauseStr)
   419  }
   420  
   421  func ExtensionFilterClause(namespace string) string {
   422  	oidStr := "oid"
   423  	if namespace != "" {
   424  		oidStr = fmt.Sprintf("%s.oid", namespace)
   425  	}
   426  
   427  	return fmt.Sprintf("%s NOT IN (select objid from pg_depend where deptype = 'e')", oidStr)
   428  }