github.com/unigraph-dev/dgraph@v1.1.1-0.20200923154953-8b52b426f765/dgraph/cmd/bulk/reduce.go

/*
 * Copyright 2017-2018 Dgraph Labs, Inc. and Contributors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package bulk

import (
	"bufio"
	"bytes"
	"compress/gzip"
	"container/heap"
	"encoding/binary"
	"fmt"
	"io"
	"log"
	"os"
	"sync/atomic"

	"github.com/dgraph-io/badger"
	bo "github.com/dgraph-io/badger/options"
	bpb "github.com/dgraph-io/badger/pb"
	"github.com/dgraph-io/badger/y"
	"github.com/dgraph-io/dgraph/codec"
	"github.com/dgraph-io/dgraph/posting"
	"github.com/dgraph-io/dgraph/protos/pb"
	"github.com/dgraph-io/dgraph/x"
	"github.com/gogo/protobuf/proto"
)

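// reducer drives the reduce phase of the bulk loader. It merges the sorted map
// output into posting lists and streams them into Badger. streamId is bumped
// atomically so that each predicate gets its own stream in the StreamWriter.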
type reducer struct {
	*state
	streamId uint32
}

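// run processes all reduce shards, running at most NumReducers of them
// concurrently. For each shard it opens the map files under the shard's
// temporary directory, merges them, and streams the resulting posting lists
// into that shard's output Badger DB.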
func (r *reducer) run() error {
	dirs := shardDirs(r.opt.TmpDir)
	x.AssertTrue(len(dirs) == r.opt.ReduceShards)
	x.AssertTrue(len(r.opt.shardOutputDirs) == r.opt.ReduceShards)

	thr := y.NewThrottle(r.opt.NumReducers)
	for i := 0; i < r.opt.ReduceShards; i++ {
		if err := thr.Do(); err != nil {
			return err
		}
		go func(shardId int, db *badger.DB) {
			defer thr.Done(nil)

			mapFiles := filenamesInTree(dirs[shardId])
			var mapItrs []*mapIterator
			for _, mapFile := range mapFiles {
				itr := newMapIterator(mapFile)
				mapItrs = append(mapItrs, itr)
			}

			writer := db.NewStreamWriter()
			x.Check(writer.Prepare())

			ci := &countIndexer{reducer: r, writer: writer}
			r.reduce(mapItrs, ci)
			ci.wait()

			x.Check(writer.Flush())
			for _, itr := range mapItrs {
				if err := itr.Close(); err != nil {
					fmt.Printf("Error while closing iterator: %v\n", err)
				}
			}
		}(i, r.createBadger(i))
	}
	return thr.Finish()
}

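// createBadger opens the managed Badger instance that holds the output for the
// i-th reduce shard and tracks it in r.dbs.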
func (r *reducer) createBadger(i int) *badger.DB {
	opt := badger.DefaultOptions(r.opt.shardOutputDirs[i]).WithSyncWrites(false).
		WithTableLoadingMode(bo.MemoryMap).WithValueThreshold(1 << 10 /* 1 KB */).
		WithLogger(nil)
	db, err := badger.OpenManaged(opt)
	x.Check(err)
	r.dbs = append(r.dbs, db)
	return db
}

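// mapIterator reads back a single map file: a gzipped stream of pb.MapEntry
// records, each prefixed with its length as a uvarint.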
type mapIterator struct {
	fd     *os.File
	reader *bufio.Reader
	tmpBuf []byte
}

func (mi *mapIterator) Close() error {
	return mi.fd.Close()
}

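// Next returns the next map entry, or nil once the end of the file has been
// reached. It reads the uvarint length prefix, then reads and unmarshals that
// many bytes, reusing tmpBuf across calls.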
func (mi *mapIterator) Next() *pb.MapEntry {
	r := mi.reader
	buf, err := r.Peek(binary.MaxVarintLen64)
	if err == io.EOF && len(buf) == 0 {
		return nil
	}
	if err != io.EOF {
		x.Check(err)
	}
	sz, n := binary.Uvarint(buf)
	if n <= 0 {
		log.Fatalf("Could not read uvarint: %d", n)
	}
	x.Check2(r.Discard(n))

	if cap(mi.tmpBuf) < int(sz) {
		mi.tmpBuf = make([]byte, sz)
	}
	x.Check2(io.ReadFull(r, mi.tmpBuf[:sz]))

	me := new(pb.MapEntry)
	x.Check(proto.Unmarshal(mi.tmpBuf[:sz], me))
	return me
}

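// newMapIterator opens the named map file and sets up a buffered gzip reader
// over it.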
func newMapIterator(filename string) *mapIterator {
	fd, err := os.Open(filename)
	x.Check(err)
	gzReader, err := gzip.NewReader(fd)
	x.Check(err)

	return &mapIterator{fd: fd, reader: bufio.NewReaderSize(gzReader, 16<<10)}
}

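// encodeAndWrite receives batches of map entries on entryCh, converts them into
// Badger KVs via toList, tags every KV with a per-predicate stream id, and
// writes them out through the StreamWriter in roughly 4 MB chunks. closer.Done
// is called once entryCh has been closed and drained.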
func (r *reducer) encodeAndWrite(
	writer *badger.StreamWriter, entryCh chan []*pb.MapEntry, closer *y.Closer) {
	defer closer.Done()

	var listSize int
	list := &bpb.KVList{}

	preds := make(map[string]uint32)
	setStreamId := func(kv *bpb.KV) {
		pk, err := x.Parse(kv.Key)
		x.Check(err)
		x.AssertTrue(len(pk.Attr) > 0)

		// We don't need to consider the data prefix, count prefix, etc. because
		// the keys for each predicate arrive here in sorted order, the way they
		// were produced.
		streamId := preds[pk.Attr]
		if streamId == 0 {
			streamId = atomic.AddUint32(&r.streamId, 1)
			preds[pk.Attr] = streamId
		}
		// TODO: Having many stream ids can cause memory issues with StreamWriter. So, we
		// should build a way in StreamWriter to indicate that the stream is over, so the
		// table for that stream can be flushed and memory released.
		kv.StreamId = streamId
	}

	for batch := range entryCh {
		listSize += r.toList(batch, list)
		if listSize > 4<<20 {
			for _, kv := range list.Kv {
				setStreamId(kv)
			}
			x.Check(writer.Write(list))
			list = &bpb.KVList{}
			listSize = 0
		}
	}
	if len(list.Kv) > 0 {
		for _, kv := range list.Kv {
			setStreamId(kv)
		}
		x.Check(writer.Write(list))
	}
}

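// reduce performs a k-way merge over the map iterators: a min-heap always
// exposes the smallest remaining map entry, so entries are consumed in globally
// sorted key order. Entries are collected into batches for encodeAndWrite, and
// the count indexer is told how many entries were seen for each key.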
func (r *reducer) reduce(mapItrs []*mapIterator, ci *countIndexer) {
	entryCh := make(chan []*pb.MapEntry, 100)
	closer := y.NewCloser(1)
	defer closer.SignalAndWait()

	var ph postingHeap
	for _, itr := range mapItrs {
		me := itr.Next()
		if me != nil {
			heap.Push(&ph, heapNode{mapEntry: me, itr: itr})
		} else {
			fmt.Printf("NIL first map entry for %s\n", itr.fd.Name())
		}
	}

	writer := ci.writer
	go r.encodeAndWrite(writer, entryCh, closer)

	const batchSize = 10000
	const batchAlloc = batchSize * 11 / 10
	batch := make([]*pb.MapEntry, 0, batchAlloc)
	var prevKey []byte
	var plistLen int

	for len(ph.nodes) > 0 {
		node0 := &ph.nodes[0]
		me := node0.mapEntry
		node0.mapEntry = node0.itr.Next()
		if node0.mapEntry != nil {
			heap.Fix(&ph, 0)
		} else {
			heap.Pop(&ph)
		}

		keyChanged := !bytes.Equal(prevKey, me.Key)
		// Note that the keys are coming in sorted order from the heap. So, if
		// we see a new key, we should push out the number of entries we got
		// for the current key, so the count index can register that.
		if keyChanged && plistLen > 0 {
			ci.addUid(prevKey, plistLen)
			plistLen = 0
		}

		if len(batch) >= batchSize && keyChanged {
			entryCh <- batch
			batch = make([]*pb.MapEntry, 0, batchAlloc)
		}
		prevKey = me.Key
		batch = append(batch, me)
		plistLen++
	}
	if len(batch) > 0 {
		entryCh <- batch
	}
	if plistLen > 0 {
		ci.addUid(prevKey, plistLen)
	}
	close(entryCh)
}

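// heapNode pairs a map entry with the iterator it was read from, so the merge
// can pull the next entry from the same file once the current one is consumed.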
type heapNode struct {
	mapEntry *pb.MapEntry
	itr      *mapIterator
}

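// postingHeap is a min-heap of heapNodes, ordered by their map entries. It
// implements container/heap's interface and drives the k-way merge in reduce.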
type postingHeap struct {
	nodes []heapNode
}

func (h *postingHeap) Len() int {
	return len(h.nodes)
}
func (h *postingHeap) Less(i, j int) bool {
	return less(h.nodes[i].mapEntry, h.nodes[j].mapEntry)
}
func (h *postingHeap) Swap(i, j int) {
	h.nodes[i], h.nodes[j] = h.nodes[j], h.nodes[i]
}
func (h *postingHeap) Push(x interface{}) {
	h.nodes = append(h.nodes, x.(heapNode))
}
func (h *postingHeap) Pop() interface{} {
	elem := h.nodes[len(h.nodes)-1]
	h.nodes = h.nodes[:len(h.nodes)-1]
	return elem
}

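// toList converts a run of map entries (grouped by key, with consecutive
// duplicate UIDs dropped) into posting-list KVs appended to list, and returns
// the total size of the KVs it added.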
func (r *reducer) toList(mapEntries []*pb.MapEntry, list *bpb.KVList) int {
	var currentKey []byte
	var uids []uint64
	pl := new(pb.PostingList)
	var size int

	appendToList := func() {
		atomic.AddInt64(&r.prog.reduceKeyCount, 1)

		// For a UID-only posting list, the badger value is a delta packed UID
		// list. The UserMeta indicates to treat the value as a delta packed
		// list when the value is read by dgraph. For a value posting list,
		// the full pb.Posting type is used (and the posting list still contains
		// the delta packed UID list).
		if len(uids) == 0 {
			return
		}

		// If the schema is of type uid and not a list but we have more than one uid in this
		// list, we cannot enforce the constraint without losing data. Inform the user and
		// force the schema to be a list so that all the data can be found when Dgraph is started.
		// The user should fix their data once Dgraph is up.
		parsedKey, err := x.Parse(currentKey)
		x.Check(err)
		if parsedKey.IsData() {
			schema := r.state.schema.getSchema(parsedKey.Attr)
			if schema.GetValueType() == pb.Posting_UID && !schema.GetList() && len(uids) > 1 {
				fmt.Printf("Schema for pred %s specifies that this is not a list but more than "+
					"one UID has been found. Forcing the schema to be a list to avoid any "+
					"data loss. Please fix the data to your specifications once Dgraph is up.\n",
					parsedKey.Attr)
				r.state.schema.setSchemaAsList(parsedKey.Attr)
			}
		}

		pl.Pack = codec.Encode(uids, 256)
		val, err := pl.Marshal()
		x.Check(err)
		kv := &bpb.KV{
			Key:      y.Copy(currentKey),
			Value:    val,
			UserMeta: []byte{posting.BitCompletePosting},
			Version:  r.state.writeTs,
		}
		size += kv.Size()
		list.Kv = append(list.Kv, kv)
		uids = uids[:0]
		pl.Reset()
	}

	for _, mapEntry := range mapEntries {
		atomic.AddInt64(&r.prog.reduceEdgeCount, 1)

		if !bytes.Equal(mapEntry.Key, currentKey) && currentKey != nil {
			appendToList()
		}
		currentKey = mapEntry.Key

		uid := mapEntry.Uid
		if mapEntry.Posting != nil {
			uid = mapEntry.Posting.Uid
		}
		if len(uids) > 0 && uids[len(uids)-1] == uid {
			continue
		}
		uids = append(uids, uid)
		if mapEntry.Posting != nil {
			pl.Postings = append(pl.Postings, mapEntry.Posting)
		}
	}
	appendToList()
	return size
}