github.com/dshekhar95/sub_dgraph@v0.0.0-20230424164411-6be28e40bbf1/dgraph/cmd/migrate/table_guide.go (about)

     1  /*
     2   * Copyright 2022 Dgraph Labs, Inc. and Contributors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package migrate
    18  
    19  import (
    20  	"database/sql"
    21  	"fmt"
    22  	"strings"
    23  
    24  	"github.com/go-sql-driver/mysql"
    25  	"github.com/pkg/errors"
    26  )
    27  
// separator joins the components (table name, column names, column values) of the
// blank node labels and predicate names built in this file.
var separator = "."
    29  
// A blankNode generates the unique blank node label that corresponds to a Dgraph uid.
// Values are passed to the generate method in the order of alphabetically sorted columns.
type blankNode interface {
	// generate returns the blank node label for one row of the table described by info;
	// values holds the column values of that row.
	generate(info *sqlTable, values []interface{}) string
}
    35  
// usingColumns generates blank node labels using values in the primary key columns.
type usingColumns struct {
	// primaryKeyIndices caches the positions of the primary key columns; it is filled
	// in by the first call to generate.
	primaryKeyIndices []*columnIdx
}
    40  
    41  // As an example, if the employee table has 3 columns (f_name, l_name, and title),
    42  // where f_name and l_name together form the primary key.
    43  // Then a row with values John (f_name), Doe (l_name), Software Engineer (title)
    44  // would generate a blank node label _:person_John_Doe using values from the primary key columns
    45  // in the alphabetic order, that is f_name, l_name in this case.
    46  func (g *usingColumns) generate(info *sqlTable, values []interface{}) string {
    47  	if g.primaryKeyIndices == nil {
    48  		g.primaryKeyIndices = getColumnIndices(info, func(info *sqlTable, column string) bool {
    49  			return info.columns[column].keyType == primary
    50  		})
    51  	}
    52  
    53  	// use the primary key indices to retrieve values in the current row
    54  	var parts []string
    55  	parts = append(parts, info.tableName)
    56  	for _, columnIndex := range g.primaryKeyIndices {
    57  		strVal, err := getValue(info.columns[columnIndex.name].dataType,
    58  			values[columnIndex.index])
    59  		if err != nil {
    60  			logger.Fatalf("Unable to get string value from primary key column %s", columnIndex.name)
    61  		}
    62  		parts = append(parts, strVal)
    63  	}
    64  
    65  	return fmt.Sprintf("_:%s", strings.Join(parts, separator))
    66  }
    67  
// A usingCounter generates blank node labels using a row counter.
type usingCounter struct {
	// rowCounter is the number of labels generated so far.
	rowCounter int
}
    72  
    73  func (g *usingCounter) generate(info *sqlTable, values []interface{}) string {
    74  	g.rowCounter++
    75  	return fmt.Sprintf("_:%s%s%d", info.tableName, separator, g.rowCounter)
    76  }
    77  
// A valuesRecorder remembers the mapping between a ref label and its blank node label.
// For example, if the person table has (fname, lname) as the primary key,
// and there are two unique indices on the columns "license" and "ssn" respectively,
// then for the row fname (John), lname (Doe), license (101), ssn (999-999-9999)
// the values recorder would remember the following mappings
// _:person.license.101 -> _:person.John.Doe
// _:person.ssn.999-999-9999 -> _:person.John.Doe
// It remembers these mappings so that if another table references the person table through
// foreign key constraints, there is a way to look up the blank node labels and use them to
// create a Dgraph link between the two rows in the two different tables.
type valuesRecorder interface {
	// record stores the ref labels of the given row, each mapped to blankNodeLabel.
	record(info *sqlTable, values []interface{}, blankNodeLabel string)
	// getBlankNode returns the blank node label recorded for indexLabel.
	getBlankNode(indexLabel string) string
}
    92  
    93  // for a given SQL row, the fkValuesRecorder records mappings from its foreign key target columns to
    94  // the blank node of the row
    95  type fkValuesRecorder struct {
    96  	refToBlank map[string]string
    97  }
    98  
    99  func (r *fkValuesRecorder) getBlankNode(indexLabel string) string {
   100  	return r.refToBlank[indexLabel]
   101  }
   102  
   103  // record keeps track of the mapping between referenced foreign columns and the blank node label
   104  // Consider the "person" table
   105  // fname varchar(50)
   106  // lname varchar(50)
   107  // company varchar(50)
   108  // employee_id int
   109  // primary key (fname, lname)
   110  // index unique (company, employee_id)
   111  
   112  // and it is referenced by the "salary" table
   113  // person_company varchar (50)
   114  // person_employee_id int
   115  // salary float
   116  // foreign key (person_company, person_employee_id) references person (company, employee_id)
   117  
   118  // Then the person table will have blank node label _:person_John_Doe for the row:
   119  // John (fname), Doe (lname), Google (company), 100 (employee_id)
   120  //
   121  // And we need to record the mapping from the refLabel to the blank node label
   122  // _:person_company_Google_employee_id_100 -> _:person_John_Doe
   123  // This mapping will be used later, when processing the salary table, to find the blank node label
   124  // _:person_John_Doe, which is used further to create the Dgraph link between a salary row
   125  // and the person row
   126  func (r *fkValuesRecorder) record(info *sqlTable, values []interface{},
   127  	blankNode string) {
   128  	for _, cst := range info.cstSources {
   129  		// for each foreign key constraint, there should be a mapping
   130  		cstColumns := getCstColumns(cst)
   131  		cstColumnIndices := getColumnIndices(info,
   132  			func(info *sqlTable, column string) bool {
   133  				_, ok := cstColumns[column]
   134  				return ok
   135  			})
   136  
   137  		refLabel, err := createLabel(&ref{
   138  			allColumns:       info.columns,
   139  			refColumnIndices: cstColumnIndices,
   140  			tableName:        info.tableName,
   141  			colValues:        values,
   142  		})
   143  		if err != nil {
   144  			if !quiet {
   145  				logger.Printf("ignoring the constraint because of error "+
   146  					"when getting ref label: %+v\n", cst)
   147  			}
   148  			continue
   149  		}
   150  		r.refToBlank[refLabel] = blankNode
   151  	}
   152  }
   153  
   154  func getCstColumns(cst *fkConstraint) map[string]interface{} {
   155  	columnNames := make(map[string]interface{})
   156  	for _, part := range cst.parts {
   157  		columnNames[part.columnName] = struct{}{}
   158  	}
   159  	return columnNames
   160  }
   161  
   162  func getValue(dataType dataType, value interface{}) (string, error) {
   163  	if value == nil {
   164  		return "", errors.Errorf("nil value found")
   165  	}
   166  
   167  	switch dataType {
   168  	case stringType:
   169  		return fmt.Sprintf("%s", value), nil
   170  	case intType:
   171  		if !value.(sql.NullInt64).Valid {
   172  			return "", errors.Errorf("found invalid nullint")
   173  		}
   174  		intVal, _ := value.(sql.NullInt64).Value()
   175  		return fmt.Sprintf("%v", intVal), nil
   176  	case datetimeType:
   177  		if !value.(mysql.NullTime).Valid {
   178  			return "", errors.Errorf("found invalid nulltime")
   179  		}
   180  		dateVal, _ := value.(mysql.NullTime).Value()
   181  		return fmt.Sprintf("%v", dateVal), nil
   182  	case floatType:
   183  		if !value.(sql.NullFloat64).Valid {
   184  			return "", errors.Errorf("found invalid nullfloat")
   185  		}
   186  		floatVal, _ := value.(sql.NullFloat64).Value()
   187  		return fmt.Sprintf("%v", floatVal), nil
   188  	default:
   189  		return fmt.Sprintf("%v", value), nil
   190  	}
   191  }
   192  
// ref bundles the inputs createLabel needs to build the ref label of one row.
type ref struct {
	allColumns       map[string]*columnInfo // column info keyed by column name
	refColumnIndices []*columnIdx           // columns that contribute to the label
	tableName        string                 // table name, first component of the label
	colValues        []interface{}          // row values, indexed by columnIdx.index
}
   199  
   200  func createLabel(ref *ref) (string, error) {
   201  	parts := make([]string, 0)
   202  	parts = append(parts, ref.tableName)
   203  	for _, colIdx := range ref.refColumnIndices {
   204  		colVal, err := getValue(ref.allColumns[colIdx.name].dataType,
   205  			ref.colValues[colIdx.index])
   206  		if err != nil {
   207  			return "", err
   208  		}
   209  		parts = append(parts, colIdx.name, colVal)
   210  	}
   211  
   212  	return fmt.Sprintf("_:%s", strings.Join(parts, separator)), nil
   213  }
   214  
   215  // createDgraphSchema generates one Dgraph predicate per SQL column
   216  // and the type of the predicate is inferred from the SQL column type.
   217  func createDgraphSchema(info *sqlTable) []string {
   218  	dgraphIndices := make([]string, 0)
   219  
   220  	for _, column := range info.columnNames {
   221  		if info.isForeignKey[column] {
   222  			// we do not store the plain values in foreign key columns
   223  			continue
   224  		}
   225  		predicate := fmt.Sprintf("%s%s%s", info.tableName, separator, column)
   226  
   227  		dataType := info.columns[column].dataType
   228  
   229  		dgraphIndices = append(dgraphIndices, fmt.Sprintf("%s: %s .\n",
   230  			predicate, dataType))
   231  	}
   232  
   233  	for _, cst := range info.foreignKeyConstraints {
   234  		pred := getPredFromConstraint(info.tableName, separator, cst)
   235  		dgraphIndices = append(dgraphIndices, fmt.Sprintf("%s: [%s] .\n",
   236  			pred, uidType))
   237  	}
   238  	return dgraphIndices
   239  }
   240  
   241  func getPredFromConstraint(
   242  	tableName string, separator string, constraint *fkConstraint) string {
   243  	columnNames := make([]string, 0)
   244  	for _, part := range constraint.parts {
   245  		columnNames = append(columnNames, part.columnName)
   246  	}
   247  	return fmt.Sprintf("%s%s%s", tableName, separator,
   248  		strings.Join(columnNames, separator))
   249  }
   250  
   251  func predicateName(info *sqlTable, column string) string {
   252  	return fmt.Sprintf("%s%s%s", info.tableName, separator, column)
   253  }
   254  
// tableGuide groups the per-table helpers created by getTableGuides.
type tableGuide struct {
	blankNode      blankNode      // generates the blank node label of a row
	valuesRecorder valuesRecorder // remembers ref label -> blank node label mappings
}
   259  
   260  func getBlankNodeGen(ti *sqlTable) blankNode {
   261  	primaryKeyIndices := getColumnIndices(ti, func(info *sqlTable, column string) bool {
   262  		return info.columns[column].keyType == primary
   263  	})
   264  
   265  	if len(primaryKeyIndices) > 0 {
   266  		return &usingColumns{}
   267  	}
   268  	return &usingCounter{}
   269  }
   270  
   271  func getTableGuides(tables map[string]*sqlTable) map[string]*tableGuide {
   272  	tableGuides := make(map[string]*tableGuide)
   273  	for table, tableInfo := range tables {
   274  		guide := &tableGuide{
   275  			blankNode: getBlankNodeGen(tableInfo),
   276  			valuesRecorder: &fkValuesRecorder{
   277  				refToBlank: make(map[string]string),
   278  			},
   279  		}
   280  
   281  		tableGuides[table] = guide
   282  	}
   283  	return tableGuides
   284  }