github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/doltcore/rowconv/row_converter.go (about) 1 // Copyright 2019 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package rowconv 16 17 import ( 18 "context" 19 "fmt" 20 21 "github.com/dolthub/dolt/go/libraries/doltcore/row" 22 "github.com/dolthub/dolt/go/libraries/doltcore/schema" 23 "github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo" 24 "github.com/dolthub/dolt/go/store/types" 25 ) 26 27 var IdentityConverter = &RowConverter{nil, true, nil} 28 29 // RowConverter converts rows from one schema to another 30 type RowConverter struct { 31 // FieldMapping is a mapping from source column to destination column 32 *FieldMapping 33 // IdentityConverter is a bool which is true if the converter is doing nothing. 34 IdentityConverter bool 35 ConvFuncs map[uint64]types.MarshalCallback 36 } 37 38 func newIdentityConverter(mapping *FieldMapping) *RowConverter { 39 return &RowConverter{mapping, true, nil} 40 } 41 42 // NewRowConverter creates a row converter from a given FieldMapping. 43 func NewRowConverter(ctx context.Context, vrw types.ValueReadWriter, mapping *FieldMapping) (*RowConverter, error) { 44 if nec, err := isNecessary(mapping.SrcSch, mapping.DestSch, mapping.SrcToDest); err != nil { 45 return nil, err 46 } else if !nec { 47 return newIdentityConverter(mapping), nil 48 } 49 50 convFuncs := make(map[uint64]types.MarshalCallback, len(mapping.SrcToDest)) 51 for srcTag, destTag := range mapping.SrcToDest { 52 destCol, destOk := mapping.DestSch.GetAllCols().GetByTag(destTag) 53 srcCol, srcOk := mapping.SrcSch.GetAllCols().GetByTag(srcTag) 54 55 if !destOk || !srcOk { 56 return nil, fmt.Errorf("Could not find column being mapped. src tag: %d, dest tag: %d", srcTag, destTag) 57 } 58 59 if srcCol.TypeInfo.Equals(destCol.TypeInfo) { 60 convFuncs[srcTag] = func(v types.Value) (types.Value, error) { 61 return v, nil 62 } 63 } 64 if typeinfo.IsStringType(destCol.TypeInfo) { 65 convFuncs[srcTag] = func(v types.Value) (types.Value, error) { 66 val, err := srcCol.TypeInfo.FormatValue(v) 67 if err != nil { 68 return nil, err 69 } 70 if val == nil { 71 return types.NullValue, nil 72 } 73 return types.String(*val), nil 74 } 75 } else { 76 convFuncs[srcTag] = func(v types.Value) (types.Value, error) { 77 return typeinfo.Convert(ctx, vrw, v, srcCol.TypeInfo, destCol.TypeInfo) 78 } 79 } 80 } 81 82 return &RowConverter{mapping, false, convFuncs}, nil 83 } 84 85 // NewImportRowConverter creates a row converter from a given FieldMapping specifically for importing. 86 func NewImportRowConverter(ctx context.Context, vrw types.ValueReadWriter, mapping *FieldMapping) (*RowConverter, error) { 87 if nec, err := isNecessary(mapping.SrcSch, mapping.DestSch, mapping.SrcToDest); err != nil { 88 return nil, err 89 } else if !nec { 90 return newIdentityConverter(mapping), nil 91 } 92 93 convFuncs := make(map[uint64]types.MarshalCallback, len(mapping.SrcToDest)) 94 for srcTag, destTag := range mapping.SrcToDest { 95 destCol, destOk := mapping.DestSch.GetAllCols().GetByTag(destTag) 96 srcCol, srcOk := mapping.SrcSch.GetAllCols().GetByTag(srcTag) 97 98 if !destOk || !srcOk { 99 return nil, fmt.Errorf("Could not find column being mapped. src tag: %d, dest tag: %d", srcTag, destTag) 100 } 101 102 if srcCol.TypeInfo.Equals(destCol.TypeInfo) { 103 convFuncs[srcTag] = func(v types.Value) (types.Value, error) { 104 return v, nil 105 } 106 } 107 if typeinfo.IsStringType(destCol.TypeInfo) { 108 convFuncs[srcTag] = func(v types.Value) (types.Value, error) { 109 val, err := srcCol.TypeInfo.FormatValue(v) 110 if err != nil { 111 return nil, err 112 } 113 if val == nil { 114 return types.NullValue, nil 115 } 116 return types.String(*val), nil 117 } 118 } else if destCol.TypeInfo.Equals(typeinfo.PseudoBoolType) || destCol.TypeInfo.Equals(typeinfo.Int8Type) { 119 // BIT(1) and BOOLEAN (MySQL alias for TINYINT or Int8) are both logical stand-ins for a bool type 120 convFuncs[srcTag] = func(v types.Value) (types.Value, error) { 121 intermediateVal, err := typeinfo.Convert(ctx, vrw, v, srcCol.TypeInfo, typeinfo.BoolType) 122 if err != nil { 123 return nil, err 124 } 125 return typeinfo.Convert(ctx, vrw, intermediateVal, typeinfo.BoolType, destCol.TypeInfo) 126 } 127 } else { 128 convFuncs[srcTag] = func(v types.Value) (types.Value, error) { 129 return typeinfo.Convert(ctx, vrw, v, srcCol.TypeInfo, destCol.TypeInfo) 130 } 131 } 132 } 133 134 return &RowConverter{mapping, false, convFuncs}, nil 135 } 136 137 // Convert takes a row maps its columns to their destination columns, and performs any type conversion needed to create 138 // a row of the expected destination schema. 139 func (rc *RowConverter) Convert(inRow row.Row) (row.Row, error) { 140 if rc.IdentityConverter { 141 return inRow, nil 142 } 143 144 outTaggedVals := make(row.TaggedValues, len(rc.SrcToDest)) 145 _, err := inRow.IterCols(func(tag uint64, val types.Value) (stop bool, err error) { 146 convFunc, ok := rc.ConvFuncs[tag] 147 148 if ok { 149 outTag := rc.SrcToDest[tag] 150 outVal, err := convFunc(val) 151 152 if err != nil { 153 return false, err 154 } 155 156 outTaggedVals[outTag] = outVal 157 } 158 159 return false, nil 160 }) 161 162 if err != nil { 163 return nil, err 164 } 165 166 return row.New(inRow.Format(), rc.DestSch, outTaggedVals) 167 } 168 169 func isNecessary(srcSch, destSch schema.Schema, destToSrc map[uint64]uint64) (bool, error) { 170 srcCols := srcSch.GetAllCols() 171 destCols := destSch.GetAllCols() 172 173 if len(destToSrc) != srcCols.Size() || len(destToSrc) != destCols.Size() { 174 return true, nil 175 } 176 177 for k, v := range destToSrc { 178 if k != v { 179 return true, nil 180 } 181 182 srcCol, srcOk := srcCols.GetByTag(v) 183 destCol, destOk := destCols.GetByTag(k) 184 185 if !srcOk || !destOk { 186 panic("There is a bug. FieldMapping creation should prevent this from happening") 187 } 188 189 if srcCol.IsPartOfPK != destCol.IsPartOfPK { 190 return true, nil 191 } 192 193 if !srcCol.TypeInfo.Equals(destCol.TypeInfo) { 194 return true, nil 195 } 196 } 197 198 srcPKCols := srcSch.GetPKCols() 199 destPKCols := destSch.GetPKCols() 200 201 if srcPKCols.Size() != destPKCols.Size() { 202 return true, nil 203 } 204 205 i := 0 206 err := destPKCols.Iter(func(tag uint64, col schema.Column) (stop bool, err error) { 207 srcPKCol := srcPKCols.GetByIndex(i) 208 209 if srcPKCol.Tag != col.Tag { 210 return true, nil 211 } 212 213 i++ 214 return false, nil 215 }) 216 217 if err != nil { 218 return false, err 219 } 220 221 return false, nil 222 }