github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/doltcore/rowconv/row_converter.go (about)

     1  // Copyright 2019 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package rowconv
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  
    21  	"github.com/dolthub/dolt/go/libraries/doltcore/row"
    22  	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
    23  	"github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo"
    24  	"github.com/dolthub/dolt/go/store/types"
    25  )
    26  
    27  var IdentityConverter = &RowConverter{nil, true, nil}
    28  
// RowConverter converts rows from one schema to another. Instances are built by NewRowConverter or
// NewImportRowConverter; the field order is relied upon by unkeyed composite literals in this file.
type RowConverter struct {
	// FieldMapping is a mapping from source column to destination column. It is embedded, so the mapping's
	// fields are reachable directly on the converter.
	*FieldMapping
	// IdentityConverter is true when the converter does nothing: Convert then returns the input row as is.
	IdentityConverter bool
	ConvFuncs         map[uint64]types.MarshalCallback // per-column conversion callbacks, keyed by source column tag
}
    37  
    38  func newIdentityConverter(mapping *FieldMapping) *RowConverter {
    39  	return &RowConverter{mapping, true, nil}
    40  }
    41  
    42  // NewRowConverter creates a row converter from a given FieldMapping.
    43  func NewRowConverter(ctx context.Context, vrw types.ValueReadWriter, mapping *FieldMapping) (*RowConverter, error) {
    44  	if nec, err := isNecessary(mapping.SrcSch, mapping.DestSch, mapping.SrcToDest); err != nil {
    45  		return nil, err
    46  	} else if !nec {
    47  		return newIdentityConverter(mapping), nil
    48  	}
    49  
    50  	convFuncs := make(map[uint64]types.MarshalCallback, len(mapping.SrcToDest))
    51  	for srcTag, destTag := range mapping.SrcToDest {
    52  		destCol, destOk := mapping.DestSch.GetAllCols().GetByTag(destTag)
    53  		srcCol, srcOk := mapping.SrcSch.GetAllCols().GetByTag(srcTag)
    54  
    55  		if !destOk || !srcOk {
    56  			return nil, fmt.Errorf("Could not find column being mapped. src tag: %d, dest tag: %d", srcTag, destTag)
    57  		}
    58  
    59  		if srcCol.TypeInfo.Equals(destCol.TypeInfo) {
    60  			convFuncs[srcTag] = func(v types.Value) (types.Value, error) {
    61  				return v, nil
    62  			}
    63  		}
    64  		if typeinfo.IsStringType(destCol.TypeInfo) {
    65  			convFuncs[srcTag] = func(v types.Value) (types.Value, error) {
    66  				val, err := srcCol.TypeInfo.FormatValue(v)
    67  				if err != nil {
    68  					return nil, err
    69  				}
    70  				if val == nil {
    71  					return types.NullValue, nil
    72  				}
    73  				return types.String(*val), nil
    74  			}
    75  		} else {
    76  			convFuncs[srcTag] = func(v types.Value) (types.Value, error) {
    77  				return typeinfo.Convert(ctx, vrw, v, srcCol.TypeInfo, destCol.TypeInfo)
    78  			}
    79  		}
    80  	}
    81  
    82  	return &RowConverter{mapping, false, convFuncs}, nil
    83  }
    84  
    85  // NewImportRowConverter creates a row converter from a given FieldMapping specifically for importing.
    86  func NewImportRowConverter(ctx context.Context, vrw types.ValueReadWriter, mapping *FieldMapping) (*RowConverter, error) {
    87  	if nec, err := isNecessary(mapping.SrcSch, mapping.DestSch, mapping.SrcToDest); err != nil {
    88  		return nil, err
    89  	} else if !nec {
    90  		return newIdentityConverter(mapping), nil
    91  	}
    92  
    93  	convFuncs := make(map[uint64]types.MarshalCallback, len(mapping.SrcToDest))
    94  	for srcTag, destTag := range mapping.SrcToDest {
    95  		destCol, destOk := mapping.DestSch.GetAllCols().GetByTag(destTag)
    96  		srcCol, srcOk := mapping.SrcSch.GetAllCols().GetByTag(srcTag)
    97  
    98  		if !destOk || !srcOk {
    99  			return nil, fmt.Errorf("Could not find column being mapped. src tag: %d, dest tag: %d", srcTag, destTag)
   100  		}
   101  
   102  		if srcCol.TypeInfo.Equals(destCol.TypeInfo) {
   103  			convFuncs[srcTag] = func(v types.Value) (types.Value, error) {
   104  				return v, nil
   105  			}
   106  		}
   107  		if typeinfo.IsStringType(destCol.TypeInfo) {
   108  			convFuncs[srcTag] = func(v types.Value) (types.Value, error) {
   109  				val, err := srcCol.TypeInfo.FormatValue(v)
   110  				if err != nil {
   111  					return nil, err
   112  				}
   113  				if val == nil {
   114  					return types.NullValue, nil
   115  				}
   116  				return types.String(*val), nil
   117  			}
   118  		} else if destCol.TypeInfo.Equals(typeinfo.PseudoBoolType) || destCol.TypeInfo.Equals(typeinfo.Int8Type) {
   119  			// BIT(1) and BOOLEAN (MySQL alias for TINYINT or Int8) are both logical stand-ins for a bool type
   120  			convFuncs[srcTag] = func(v types.Value) (types.Value, error) {
   121  				intermediateVal, err := typeinfo.Convert(ctx, vrw, v, srcCol.TypeInfo, typeinfo.BoolType)
   122  				if err != nil {
   123  					return nil, err
   124  				}
   125  				return typeinfo.Convert(ctx, vrw, intermediateVal, typeinfo.BoolType, destCol.TypeInfo)
   126  			}
   127  		} else {
   128  			convFuncs[srcTag] = func(v types.Value) (types.Value, error) {
   129  				return typeinfo.Convert(ctx, vrw, v, srcCol.TypeInfo, destCol.TypeInfo)
   130  			}
   131  		}
   132  	}
   133  
   134  	return &RowConverter{mapping, false, convFuncs}, nil
   135  }
   136  
   137  // Convert takes a row maps its columns to their destination columns, and performs any type conversion needed to create
   138  // a row of the expected destination schema.
   139  func (rc *RowConverter) Convert(inRow row.Row) (row.Row, error) {
   140  	if rc.IdentityConverter {
   141  		return inRow, nil
   142  	}
   143  
   144  	outTaggedVals := make(row.TaggedValues, len(rc.SrcToDest))
   145  	_, err := inRow.IterCols(func(tag uint64, val types.Value) (stop bool, err error) {
   146  		convFunc, ok := rc.ConvFuncs[tag]
   147  
   148  		if ok {
   149  			outTag := rc.SrcToDest[tag]
   150  			outVal, err := convFunc(val)
   151  
   152  			if err != nil {
   153  				return false, err
   154  			}
   155  
   156  			outTaggedVals[outTag] = outVal
   157  		}
   158  
   159  		return false, nil
   160  	})
   161  
   162  	if err != nil {
   163  		return nil, err
   164  	}
   165  
   166  	return row.New(inRow.Format(), rc.DestSch, outTaggedVals)
   167  }
   168  
   169  func isNecessary(srcSch, destSch schema.Schema, destToSrc map[uint64]uint64) (bool, error) {
   170  	srcCols := srcSch.GetAllCols()
   171  	destCols := destSch.GetAllCols()
   172  
   173  	if len(destToSrc) != srcCols.Size() || len(destToSrc) != destCols.Size() {
   174  		return true, nil
   175  	}
   176  
   177  	for k, v := range destToSrc {
   178  		if k != v {
   179  			return true, nil
   180  		}
   181  
   182  		srcCol, srcOk := srcCols.GetByTag(v)
   183  		destCol, destOk := destCols.GetByTag(k)
   184  
   185  		if !srcOk || !destOk {
   186  			panic("There is a bug.  FieldMapping creation should prevent this from happening")
   187  		}
   188  
   189  		if srcCol.IsPartOfPK != destCol.IsPartOfPK {
   190  			return true, nil
   191  		}
   192  
   193  		if !srcCol.TypeInfo.Equals(destCol.TypeInfo) {
   194  			return true, nil
   195  		}
   196  	}
   197  
   198  	srcPKCols := srcSch.GetPKCols()
   199  	destPKCols := destSch.GetPKCols()
   200  
   201  	if srcPKCols.Size() != destPKCols.Size() {
   202  		return true, nil
   203  	}
   204  
   205  	i := 0
   206  	err := destPKCols.Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
   207  		srcPKCol := srcPKCols.GetByIndex(i)
   208  
   209  		if srcPKCol.Tag != col.Tag {
   210  			return true, nil
   211  		}
   212  
   213  		i++
   214  		return false, nil
   215  	})
   216  
   217  	if err != nil {
   218  		return false, err
   219  	}
   220  
   221  	return false, nil
   222  }