github.com/dshekhar95/sub_dgraph@v0.0.0-20230424164411-6be28e40bbf1/dgraph/cmd/migrate/table_info.go (about)

     1  /*
     2   * Copyright 2022 Dgraph Labs, Inc. and Contributors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package migrate
    18  
    19  import (
    20  	"database/sql"
    21  	"fmt"
    22  	"strings"
    23  
    24  	"github.com/pkg/errors"
    25  
    26  	"github.com/dgraph-io/dgraph/x"
    27  )
    28  
    29  type keyType int
    30  
    31  const (
    32  	none keyType = iota
    33  	primary
    34  	secondary
    35  )
    36  
    37  type dataType int
    38  
    39  type columnInfo struct {
    40  	name     string
    41  	keyType  keyType
    42  	dataType dataType
    43  }
    44  
    45  // fkConstraint represents a foreign key constraint
    46  type fkConstraint struct {
    47  	parts []*constraintPart
    48  	// the referenced column names and their indices in the foreign table
    49  	foreignIndices []*columnIdx
    50  }
    51  
    52  type constraintPart struct {
    53  	// the local table name
    54  	tableName string
    55  	// the local column name
    56  	columnName string
    57  	// the remote table name can be either the source or target of a foreign key constraint
    58  	remoteTableName string
    59  	// the remote column name can be either the source or target of a foreign key constraint
    60  	remoteColumnName string
    61  }
    62  
    63  // a sqlTable contains a SQL table's metadata such as the table name,
    64  // the info of each column etc
    65  type sqlTable struct {
    66  	tableName string
    67  	columns   map[string]*columnInfo
    68  
    69  	// The following 3 columns are used by the rowMeta when converting rows
    70  	columnDataTypes []dataType
    71  	columnNames     []string
    72  	isForeignKey    map[string]bool // whether a given column is a foreign key
    73  	predNames       []string
    74  
    75  	// the referenced tables by the current table through foreign key constraints
    76  	dstTables map[string]interface{}
    77  
    78  	// a map from constraint names to constraints
    79  	foreignKeyConstraints map[string]*fkConstraint
    80  
    81  	// the list of foreign key constraints using this table as the target
    82  	cstSources []*fkConstraint
    83  }
    84  
    85  func getDataType(dbType string) dataType {
    86  	for prefix, goType := range sqlTypeToInternal {
    87  		if strings.HasPrefix(dbType, prefix) {
    88  			return goType
    89  		}
    90  	}
    91  	return unknownType
    92  }
    93  
    94  func getColumnInfo(fieldName string, dbType string) *columnInfo {
    95  	columnInfo := columnInfo{}
    96  	columnInfo.name = fieldName
    97  	columnInfo.dataType = getDataType(dbType)
    98  	return &columnInfo
    99  }
   100  
   101  func parseTables(pool *sql.DB, tableName string, database string) (*sqlTable, error) {
   102  	query := fmt.Sprintf(`select COLUMN_NAME,DATA_TYPE from INFORMATION_SCHEMA.
   103  COLUMNS where TABLE_NAME = "%s" AND TABLE_SCHEMA="%s" ORDER BY COLUMN_NAME`, tableName, database)
   104  	columns, err := pool.Query(query)
   105  	if err != nil {
   106  		return nil, err
   107  	}
   108  	defer columns.Close()
   109  
   110  	table := &sqlTable{
   111  		tableName:             tableName,
   112  		columns:               make(map[string]*columnInfo),
   113  		columnNames:           make([]string, 0),
   114  		isForeignKey:          make(map[string]bool),
   115  		columnDataTypes:       make([]dataType, 0),
   116  		predNames:             make([]string, 0),
   117  		dstTables:             make(map[string]interface{}),
   118  		foreignKeyConstraints: make(map[string]*fkConstraint),
   119  	}
   120  
   121  	for columns.Next() {
   122  		/*
   123  			each row represents info about a column, for example
   124  			+---------------+-----------+
   125  			| COLUMN_NAME   | DATA_TYPE |
   126  			+---------------+-----------+
   127  			| p_company     | varchar   |
   128  			| p_employee_id | int       |
   129  			| p_fname       | varchar   |
   130  			| p_lname       | varchar   |
   131  			| title         | varchar   |
   132  			+---------------+-----------+
   133  		*/
   134  		var fieldName, dbType string
   135  		if err := columns.Scan(&fieldName, &dbType); err != nil {
   136  			return nil, errors.Wrapf(err, "unable to scan table description result for table %s",
   137  				tableName)
   138  		}
   139  
   140  		// TODO, should store the column data types into the table info as an array
   141  		// and the RMI should simply get the data types from the table info
   142  		table.columns[fieldName] = getColumnInfo(fieldName, dbType)
   143  		table.columnNames = append(table.columnNames, fieldName)
   144  		table.columnDataTypes = append(table.columnDataTypes, getDataType(dbType))
   145  	}
   146  
   147  	// query indices
   148  	indexQuery := fmt.Sprintf(`select INDEX_NAME,COLUMN_NAME from INFORMATION_SCHEMA.`+
   149  		`STATISTICS where TABLE_NAME = "%s" AND index_schema="%s"`, tableName, database)
   150  	indices, err := pool.Query(indexQuery)
   151  	if err != nil {
   152  		return nil, err
   153  	}
   154  	defer indices.Close()
   155  	for indices.Next() {
   156  		var indexName, columnName string
   157  		err := indices.Scan(&indexName, &columnName)
   158  		if err != nil {
   159  			return nil, errors.Wrapf(err, "unable to scan index info for table %s", tableName)
   160  		}
   161  		switch indexName {
   162  		case "PRIMARY":
   163  			table.columns[columnName].keyType = primary
   164  		default:
   165  			table.columns[columnName].keyType = secondary
   166  		}
   167  
   168  	}
   169  
   170  	foreignKeysQuery := fmt.Sprintf(`select COLUMN_NAME,CONSTRAINT_NAME,REFERENCED_TABLE_NAME,
   171  		REFERENCED_COLUMN_NAME from INFORMATION_SCHEMA.KEY_COLUMN_USAGE where TABLE_NAME = "%s"
   172          AND CONSTRAINT_SCHEMA="%s" AND REFERENCED_TABLE_NAME IS NOT NULL`, tableName, database)
   173  	fkeys, err := pool.Query(foreignKeysQuery)
   174  	if err != nil {
   175  		return nil, err
   176  	}
   177  	defer fkeys.Close()
   178  	for fkeys.Next() {
   179  		/* example output from MySQL
   180  		+---------------+-----------------+-----------------------+------------------------+
   181  		| COLUMN_NAME   | CONSTRAINT_NAME | REFERENCED_TABLE_NAME | REFERENCED_COLUMN_NAME |
   182  		+---------------+-----------------+-----------------------+------------------------+
   183  		| p_fname       | role_ibfk_1     | person                | fname                  |
   184  		| p_lname       | role_ibfk_1     | person                | lname                  |
   185  		| p_company     | role_ibfk_2     | person                | company                |
   186  		| p_employee_id | role_ibfk_2     | person                | employee_id            |
   187  		+---------------+-----------------+-----------------------+------------------------+
   188  		*/
   189  		var col, constraintName, dstTable, dstCol string
   190  		if err := fkeys.Scan(&col, &constraintName, &dstTable, &dstCol); err != nil {
   191  			return nil, errors.Wrapf(err, "unable to scan usage info for table %s", tableName)
   192  		}
   193  
   194  		table.dstTables[dstTable] = struct{}{}
   195  		var constraint *fkConstraint
   196  		var ok bool
   197  		if constraint, ok = table.foreignKeyConstraints[constraintName]; !ok {
   198  			constraint = &fkConstraint{
   199  				parts: make([]*constraintPart, 0),
   200  			}
   201  			table.foreignKeyConstraints[constraintName] = constraint
   202  		}
   203  		constraint.parts = append(constraint.parts, &constraintPart{
   204  			tableName:        tableName,
   205  			columnName:       col,
   206  			remoteTableName:  dstTable,
   207  			remoteColumnName: dstCol,
   208  		})
   209  
   210  		table.isForeignKey[col] = true
   211  	}
   212  	return table, nil
   213  }
   214  
   215  // validateAndGetReverse flip the foreign key reference direction in a constraint.
   216  // For example, if the constraint's local table name is A, and it has 3 columns
   217  // col1, col2, col3 that references a remote table B's 3 columns col4, col5, col6,
   218  // then we return a reversed constraint whose local table name is B with local columns
   219  // col4, col5, col6 whose remote table name is A, and remote columns are
   220  // col1, col2 and col3
   221  func validateAndGetReverse(constraint *fkConstraint) (string, *fkConstraint) {
   222  	reverseParts := make([]*constraintPart, 0)
   223  	// verify that within one constraint, the remote table names are the same
   224  	var remoteTableName string
   225  	for _, part := range constraint.parts {
   226  		if len(remoteTableName) == 0 {
   227  			remoteTableName = part.remoteTableName
   228  		} else {
   229  			x.AssertTrue(part.remoteTableName == remoteTableName)
   230  		}
   231  		reverseParts = append(reverseParts, &constraintPart{
   232  			tableName:        part.remoteColumnName,
   233  			columnName:       part.remoteColumnName,
   234  			remoteTableName:  part.tableName,
   235  			remoteColumnName: part.columnName,
   236  		})
   237  	}
   238  	return remoteTableName, &fkConstraint{
   239  		parts: reverseParts,
   240  	}
   241  }
   242  
   243  // populateReferencedByColumns calculates the reverse links of
   244  // the data at tables[table name].foreignKeyReferences
   245  // and stores them in tables[table name].cstSources
   246  func populateReferencedByColumns(tables map[string]*sqlTable) {
   247  	for _, tableInfo := range tables {
   248  		for _, constraint := range tableInfo.foreignKeyConstraints {
   249  			reverseTable, reverseConstraint := validateAndGetReverse(constraint)
   250  
   251  			tables[reverseTable].cstSources = append(tables[reverseTable].cstSources,
   252  				reverseConstraint)
   253  		}
   254  	}
   255  }