github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/doltcore/rowconv/joiner.go (about)

     1  // Copyright 2019 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package rowconv
    16  
    17  import (
    18  	"errors"
    19  
    20  	"github.com/dolthub/dolt/go/libraries/doltcore/row"
    21  	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
    22  	"github.com/dolthub/dolt/go/store/types"
    23  )
    24  
// ColNamingFunc defines a function signature which takes the name of a column, and returns the name that should be used
// for the column in the joined dataset.  NewJoiner looks up one ColNamingFunc per schema name and calls it once for
// every column of that schema.
type ColNamingFunc func(colName string) string
    28  
// stringUint64Tuple pairs a string with a uint64.  The Joiner uses it as the value type of its reverse tag map, where
// str holds the name of a source schema and u64 holds a column tag within that schema.
type stringUint64Tuple struct {
	// str is the string half of the pair
	str string
	// u64 is the uint64 half of the pair
	u64 uint64
}
    33  
// NamedSchema is an object that associates a schema with a string.  NewJoiner uses the name as the key under which the
// schema, its ColNamingFunc, and its tag mapping are stored.
type NamedSchema struct {
	// Name is the name given to the schema
	Name string

	// Sch is the schema being named
	Sch schema.Schema
}
    42  
// Joiner is an object that can be used to join multiple rows together into a single row (See Join), and also to reverse
// this operation by taking a joined row and getting back a map of rows (See Split).
type Joiner struct {
	// srcSchemas holds the original schemas, keyed by schema name
	srcSchemas map[string]schema.Schema
	// tagMaps maps, for each schema name, a column tag in that source schema to its tag in the joined schema
	tagMaps map[string]map[uint64]uint64
	// revTagMap maps a tag in the joined schema back to the source schema name and the tag within that schema
	revTagMap map[uint64]stringUint64Tuple
	// joined is the schema that every joined row has
	joined schema.Schema
}
    51  
    52  // NewJoiner creates a joiner from a slice of NamedSchemas and a map of ColNamingFuncs.  A new schema for joined rows will
    53  // be created, and the columns for joined schemas will be named according to the ColNamingFunc associated with each schema
    54  // name.
    55  func NewJoiner(namedSchemas []NamedSchema, namers map[string]ColNamingFunc) (*Joiner, error) {
    56  	tags := make(map[string][]uint64)
    57  	revTagMap := make(map[uint64]stringUint64Tuple)
    58  	tagMaps := make(map[string]map[uint64]uint64, len(namedSchemas))
    59  	srcSchemas := make(map[string]schema.Schema)
    60  	for _, namedSch := range namedSchemas {
    61  		tagMaps[namedSch.Name] = make(map[uint64]uint64)
    62  		srcSchemas[namedSch.Name] = namedSch.Sch
    63  	}
    64  
    65  	var cols []schema.Column
    66  	var destTag uint64
    67  	var err error
    68  	for _, namedSch := range namedSchemas {
    69  		sch := namedSch.Sch
    70  		name := namedSch.Name
    71  		allCols := sch.GetAllCols()
    72  		namer := namers[name]
    73  		allCols.IterInSortedOrder(func(srcTag uint64, col schema.Column) (stop bool) {
    74  			newColName := namer(col.Name)
    75  			var newCol schema.Column
    76  			newCol, err = schema.NewColumnWithTypeInfo(newColName, destTag, col.TypeInfo, false, col.Default, false, col.Comment)
    77  			if err != nil {
    78  				return true
    79  			}
    80  			cols = append(cols, newCol)
    81  			tagMaps[name][srcTag] = destTag
    82  			revTagMap[destTag] = stringUint64Tuple{str: name, u64: srcTag}
    83  			tags[name] = append(tags[name], destTag)
    84  			destTag++
    85  
    86  			return false
    87  		})
    88  	}
    89  	if err != nil {
    90  		return nil, err
    91  	}
    92  
    93  	colColl := schema.NewColCollection(cols...)
    94  
    95  	joined := schema.UnkeyedSchemaFromCols(colColl)
    96  
    97  	return &Joiner{srcSchemas, tagMaps, revTagMap, joined}, nil
    98  }
    99  
   100  // Join takes a map from schema name to row which has that schema, and returns a single joined row containing all the
   101  // data
   102  func (j *Joiner) Join(namedRows map[string]row.Row) (row.Row, error) {
   103  	var nbf *types.NomsBinFormat
   104  	colVals := make(row.TaggedValues)
   105  	for name, r := range namedRows {
   106  		if r == nil {
   107  			continue
   108  		}
   109  
   110  		if nbf == nil {
   111  			nbf = r.Format()
   112  		} else if nbf.VersionString() != r.Format().VersionString() {
   113  			return nil, errors.New("not all rows have the same format")
   114  		}
   115  
   116  		_, err := r.IterCols(func(tag uint64, val types.Value) (stop bool, err error) {
   117  			destTag := j.tagMaps[name][tag]
   118  			colVals[destTag] = val
   119  
   120  			return false, nil
   121  		})
   122  
   123  		if err != nil {
   124  			return nil, err
   125  		}
   126  	}
   127  
   128  	return row.New(nbf, j.joined, colVals)
   129  }
   130  
   131  // Split takes a row which has the created joined schema, and splits it into a map of rows where the key of the map is
   132  // the name of the schema for the associated row.
   133  func (j *Joiner) Split(r row.Row) (map[string]row.Row, error) {
   134  	colVals := make(map[string]row.TaggedValues, len(j.srcSchemas))
   135  	for name := range j.srcSchemas {
   136  		colVals[name] = make(row.TaggedValues)
   137  	}
   138  
   139  	_, err := r.IterCols(func(tag uint64, val types.Value) (stop bool, err error) {
   140  		schemaNameAndTag := j.revTagMap[tag]
   141  		colVals[schemaNameAndTag.str][schemaNameAndTag.u64] = val
   142  
   143  		return false, nil
   144  	})
   145  
   146  	if err != nil {
   147  		return nil, err
   148  	}
   149  
   150  	rows := make(map[string]row.Row, len(colVals))
   151  	for name, sch := range j.srcSchemas {
   152  		var err error
   153  
   154  		currColVals := colVals[name]
   155  
   156  		if len(currColVals) == 0 {
   157  			continue
   158  		}
   159  
   160  		rows[name], err = row.New(r.Format(), sch, currColVals)
   161  
   162  		if err != nil {
   163  			return nil, err
   164  		}
   165  	}
   166  
   167  	return rows, nil
   168  }
   169  
   170  // GetSchema returns the schema which all joined rows will have, and any row passed into split should have.
   171  func (j *Joiner) GetSchema() schema.Schema {
   172  	return j.joined
   173  }
   174  
   175  // SchemaForName retrieves the original schema which has the given name.
   176  func (j *Joiner) SchemaForName(name string) schema.Schema {
   177  	return j.srcSchemas[name]
   178  }