github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/rowconv/row_converter.go (about)

     1  // Copyright 2019 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package rowconv
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  
    21  	"github.com/dolthub/dolt/go/libraries/doltcore/row"
    22  	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
    23  	"github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo"
    24  	"github.com/dolthub/dolt/go/store/types"
    25  
    26  	"github.com/dolthub/go-mysql-server/sql"
    27  )
    28  
    29  var IdentityConverter = &RowConverter{nil, true, nil}
    30  
    31  // WarnFunction is a callback function that callers can optionally provide during row conversion
    32  // to take an extra action when a value cannot be automatically converted to the output data type.
    33  type WarnFunction func(int, string, ...interface{})
    34  
    35  const DatatypeCoercionFailureWarning = "unable to coerce value from field '%s' into latest column schema"
    36  const DatatypeCoercionFailureWarningCode int = 1105 // Since this our own custom warning we'll use 1105, the code for an unknown error
    37  
    38  const TruncatedOutOfRangeValueWarning = "Truncated %v value: %v"
    39  const TruncatedOutOfRangeValueWarningCode = 1292
    40  
    41  // RowConverter converts rows from one schema to another
    42  type RowConverter struct {
    43  	// FieldMapping is a mapping from source column to destination column
    44  	*FieldMapping
    45  	// IdentityConverter is a bool which is true if the converter is doing nothing.
    46  	IdentityConverter bool
    47  	ConvFuncs         map[uint64]types.MarshalCallback
    48  }
    49  
    50  func newIdentityConverter(mapping *FieldMapping) *RowConverter {
    51  	return &RowConverter{mapping, true, nil}
    52  }
    53  
    54  // NewRowConverter creates a row converter from a given FieldMapping.
    55  func NewRowConverter(ctx context.Context, vrw types.ValueReadWriter, mapping *FieldMapping) (*RowConverter, error) {
    56  	if nec, err := IsNecessary(mapping.SrcSch, mapping.DestSch, mapping.SrcToDest); err != nil {
    57  		return nil, err
    58  	} else if !nec {
    59  		return newIdentityConverter(mapping), nil
    60  	}
    61  
    62  	// Panic if there are any duplicate columns mapped to the same destination tag.
    63  	panicOnDuplicateMappings(mapping)
    64  
    65  	convFuncs := make(map[uint64]types.MarshalCallback, len(mapping.SrcToDest))
    66  	for srcTag, destTag := range mapping.SrcToDest {
    67  		destCol, destOk := mapping.DestSch.GetAllCols().GetByTag(destTag)
    68  		srcCol, srcOk := mapping.SrcSch.GetAllCols().GetByTag(srcTag)
    69  
    70  		if !destOk || !srcOk {
    71  			return nil, fmt.Errorf("Could not find column being mapped. src tag: %d, dest tag: %d", srcTag, destTag)
    72  		}
    73  
    74  		tc, _, err := typeinfo.GetTypeConverter(ctx, srcCol.TypeInfo, destCol.TypeInfo)
    75  		if err != nil {
    76  			return nil, err
    77  		}
    78  		convFuncs[srcTag] = func(v types.Value) (types.Value, error) {
    79  			return tc(ctx, vrw, v)
    80  		}
    81  	}
    82  
    83  	return &RowConverter{mapping, false, convFuncs}, nil
    84  }
    85  
    86  // panicOnDuplicateMappings checks if more than one input field is mapped to the same output field.
    87  // Multiple input fields mapped to the same output field results in a race condition.
    88  func panicOnDuplicateMappings(mapping *FieldMapping) {
    89  	destToSrcMapping := make(map[uint64]uint64, len(mapping.SrcToDest))
    90  	for srcTag, destTag := range mapping.SrcToDest {
    91  		if _, found := destToSrcMapping[destTag]; found {
    92  			panic("multiple columns mapped to the same destination tag '" + types.Uint(destTag).HumanReadableString() + "'")
    93  		}
    94  		destToSrcMapping[destTag] = srcTag
    95  	}
    96  }
    97  
    98  // ConvertWithWarnings takes an input row, maps its columns to their destination columns, performing any type
    99  // conversions needed to create a row of the expected destination schema, and uses the optional WarnFunction
   100  // callback to let callers handle logging a warning when a field cannot be cleanly converted.
   101  func (rc *RowConverter) ConvertWithWarnings(inRow row.Row, warnFn WarnFunction) (row.Row, error) {
   102  	return rc.convert(inRow, warnFn)
   103  }
   104  
   105  // convert takes a row and maps its columns to their destination columns, automatically performing any type conversion
   106  // needed, and using the optional WarnFunction to let callers log warnings on any type conversion errors.
   107  func (rc *RowConverter) convert(inRow row.Row, warnFn WarnFunction) (row.Row, error) {
   108  	if rc.IdentityConverter {
   109  		return inRow, nil
   110  	}
   111  
   112  	outTaggedVals := make(row.TaggedValues, len(rc.SrcToDest))
   113  	_, err := inRow.IterCols(func(tag uint64, val types.Value) (stop bool, err error) {
   114  		convFunc, ok := rc.ConvFuncs[tag]
   115  
   116  		if ok {
   117  			outTag := rc.SrcToDest[tag]
   118  			outVal, err := convFunc(val)
   119  
   120  			if sql.ErrInvalidValue.Is(err) && warnFn != nil {
   121  				col, _ := rc.SrcSch.GetAllCols().GetByTag(tag)
   122  				warnFn(DatatypeCoercionFailureWarningCode, DatatypeCoercionFailureWarning, col.Name)
   123  				outVal = types.NullValue
   124  				err = nil
   125  			}
   126  
   127  			if err != nil {
   128  				return false, err
   129  			}
   130  
   131  			if types.IsNull(outVal) {
   132  				return false, nil
   133  			}
   134  
   135  			outTaggedVals[outTag] = outVal
   136  		}
   137  
   138  		return false, nil
   139  	})
   140  
   141  	if err != nil {
   142  		return nil, err
   143  	}
   144  
   145  	return row.New(inRow.Format(), rc.DestSch, outTaggedVals)
   146  }
   147  
   148  func IsNecessary(srcSch, destSch schema.Schema, destToSrc map[uint64]uint64) (bool, error) {
   149  	srcCols := srcSch.GetAllCols()
   150  	destCols := destSch.GetAllCols()
   151  
   152  	if len(destToSrc) != srcCols.Size() || len(destToSrc) != destCols.Size() {
   153  		return true, nil
   154  	}
   155  
   156  	for k, v := range destToSrc {
   157  		if k != v {
   158  			return true, nil
   159  		}
   160  
   161  		srcCol, srcOk := srcCols.GetByTag(v)
   162  		destCol, destOk := destCols.GetByTag(k)
   163  
   164  		if !srcOk || !destOk {
   165  			panic("There is a bug.  FieldMapping creation should prevent this from happening")
   166  		}
   167  
   168  		if srcCol.IsPartOfPK != destCol.IsPartOfPK {
   169  			return true, nil
   170  		}
   171  
   172  		if !srcCol.TypeInfo.Equals(destCol.TypeInfo) {
   173  			return true, nil
   174  		}
   175  	}
   176  
   177  	srcPKCols := srcSch.GetPKCols()
   178  	destPKCols := destSch.GetPKCols()
   179  
   180  	if srcPKCols.Size() != destPKCols.Size() {
   181  		return true, nil
   182  	}
   183  
   184  	i := 0
   185  	err := destPKCols.Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
   186  		srcPKCol := srcPKCols.GetByIndex(i)
   187  
   188  		if srcPKCol.Tag != col.Tag {
   189  			return true, nil
   190  		}
   191  
   192  		i++
   193  		return false, nil
   194  	})
   195  
   196  	if err != nil {
   197  		return false, err
   198  	}
   199  
   200  	return false, nil
   201  }