github.com/unigraph-dev/dgraph@v1.1.1-0.20200923154953-8b52b426f765/dgraph/cmd/bulk/mapper.go (about)

     1  /*
     2   * Copyright 2017-2018 Dgraph Labs, Inc. and Contributors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package bulk
    18  
    19  import (
    20  	"bufio"
    21  	"bytes"
    22  	"compress/gzip"
    23  	"encoding/binary"
    24  	"fmt"
    25  	"log"
    26  	"math"
    27  	"os"
    28  	"path/filepath"
    29  	"sort"
    30  	"strconv"
    31  	"strings"
    32  	"sync"
    33  	"sync/atomic"
    34  
    35  	"github.com/dgraph-io/dgo/protos/api"
    36  	"github.com/dgraph-io/dgraph/chunker"
    37  	"github.com/dgraph-io/dgraph/gql"
    38  	"github.com/dgraph-io/dgraph/posting"
    39  	"github.com/dgraph-io/dgraph/protos/pb"
    40  	"github.com/dgraph-io/dgraph/tok"
    41  	"github.com/dgraph-io/dgraph/types"
    42  	"github.com/dgraph-io/dgraph/types/facets"
    43  	"github.com/dgraph-io/dgraph/x"
    44  	farm "github.com/dgryski/go-farm"
    45  )
    46  
    47  type mapper struct {
    48  	*state
    49  	shards []shardState // shard is based on predicate
    50  	mePool *sync.Pool
    51  }
    52  
    53  type shardState struct {
    54  	// Buffer up map entries until we have a sufficient amount, then sort and
    55  	// write them to file.
    56  	entries     []*pb.MapEntry
    57  	encodedSize uint64
    58  	mu          sync.Mutex // Allow only 1 write per shard at a time.
    59  }
    60  
    61  func newMapper(st *state) *mapper {
    62  	return &mapper{
    63  		state:  st,
    64  		shards: make([]shardState, st.opt.MapShards),
    65  		mePool: &sync.Pool{
    66  			New: func() interface{} {
    67  				return &pb.MapEntry{}
    68  			},
    69  		},
    70  	}
    71  }
    72  
    73  func less(lhs, rhs *pb.MapEntry) bool {
    74  	if keyCmp := bytes.Compare(lhs.Key, rhs.Key); keyCmp != 0 {
    75  		return keyCmp < 0
    76  	}
    77  	lhsUID := lhs.Uid
    78  	rhsUID := rhs.Uid
    79  	if lhs.Posting != nil {
    80  		lhsUID = lhs.Posting.Uid
    81  	}
    82  	if rhs.Posting != nil {
    83  		rhsUID = rhs.Posting.Uid
    84  	}
    85  	return lhsUID < rhsUID
    86  }
    87  
    88  func (m *mapper) openOutputFile(shardIdx int) (*os.File, error) {
    89  	fileNum := atomic.AddUint32(&m.mapFileId, 1)
    90  	filename := filepath.Join(
    91  		m.opt.TmpDir,
    92  		"shards",
    93  		fmt.Sprintf("%03d", shardIdx),
    94  		fmt.Sprintf("%06d.map.gz", fileNum),
    95  	)
    96  	x.Check(os.MkdirAll(filepath.Dir(filename), 0755))
    97  	return os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
    98  }
    99  
   100  func (m *mapper) writeMapEntriesToFile(entries []*pb.MapEntry, encodedSize uint64, shardIdx int) {
   101  	defer m.shards[shardIdx].mu.Unlock() // Locked by caller.
   102  
   103  	sort.Slice(entries, func(i, j int) bool {
   104  		return less(entries[i], entries[j])
   105  	})
   106  
   107  	f, err := m.openOutputFile(shardIdx)
   108  	x.Check(err)
   109  
   110  	defer func() {
   111  		x.Check(f.Sync())
   112  		x.Check(f.Close())
   113  	}()
   114  
   115  	gzWriter := gzip.NewWriter(f)
   116  	w := bufio.NewWriter(gzWriter)
   117  	defer func() {
   118  		x.Check(w.Flush())
   119  		x.Check(gzWriter.Flush())
   120  		x.Check(gzWriter.Close())
   121  	}()
   122  
   123  	sizeBuf := make([]byte, binary.MaxVarintLen64)
   124  	for _, me := range entries {
   125  		n := binary.PutUvarint(sizeBuf, uint64(me.Size()))
   126  		_, err := w.Write(sizeBuf[:n])
   127  		x.Check(err)
   128  
   129  		meBuf, err := me.Marshal()
   130  		x.Check(err)
   131  		_, err = w.Write(meBuf)
   132  		x.Check(err)
   133  		m.mePool.Put(me)
   134  	}
   135  }
   136  
   137  func (m *mapper) run(inputFormat chunker.InputFormat) {
   138  	chunker := chunker.NewChunker(inputFormat, 1000)
   139  	nquads := chunker.NQuads()
   140  	go func() {
   141  		for chunkBuf := range m.readerChunkCh {
   142  			if err := chunker.Parse(chunkBuf); err != nil {
   143  				atomic.AddInt64(&m.prog.errCount, 1)
   144  				if !m.opt.IgnoreErrors {
   145  					x.Check(err)
   146  				}
   147  			}
   148  		}
   149  		nquads.Flush()
   150  	}()
   151  
   152  	for nqs := range nquads.Ch() {
   153  		for _, nq := range nqs {
   154  			if err := facets.SortAndValidate(nq.Facets); err != nil {
   155  				atomic.AddInt64(&m.prog.errCount, 1)
   156  				if !m.opt.IgnoreErrors {
   157  					x.Check(err)
   158  				}
   159  			}
   160  
   161  			m.processNQuad(gql.NQuad{NQuad: nq})
   162  			atomic.AddInt64(&m.prog.nquadCount, 1)
   163  		}
   164  
   165  		for i := range m.shards {
   166  			sh := &m.shards[i]
   167  			if sh.encodedSize >= m.opt.MapBufSize {
   168  				sh.mu.Lock() // One write at a time.
   169  				go m.writeMapEntriesToFile(sh.entries, sh.encodedSize, i)
   170  				// Clear the entries and encodedSize for the next batch.
   171  				// Proactively allocate 32 slots to bootstrap the entries slice.
   172  				sh.entries = make([]*pb.MapEntry, 0, 32)
   173  				sh.encodedSize = 0
   174  			}
   175  		}
   176  	}
   177  
   178  	for i := range m.shards {
   179  		sh := &m.shards[i]
   180  		if len(sh.entries) > 0 {
   181  			sh.mu.Lock() // One write at a time.
   182  			m.writeMapEntriesToFile(sh.entries, sh.encodedSize, i)
   183  		}
   184  		m.shards[i].mu.Lock() // Ensure that the last file write finishes.
   185  	}
   186  }
   187  
   188  func (m *mapper) addMapEntry(key []byte, p *pb.Posting, shard int) {
   189  	atomic.AddInt64(&m.prog.mapEdgeCount, 1)
   190  
   191  	me := m.mePool.Get().(*pb.MapEntry)
   192  	*me = pb.MapEntry{Key: key}
   193  
   194  	if p.PostingType != pb.Posting_REF || len(p.Facets) > 0 {
   195  		me.Posting = p
   196  	} else {
   197  		me.Uid = p.Uid
   198  	}
   199  	sh := &m.shards[shard]
   200  
   201  	var err error
   202  	sh.entries = append(sh.entries, me)
   203  	sh.encodedSize += uint64(me.Size())
   204  	x.Check(err)
   205  }
   206  
   207  func (m *mapper) processNQuad(nq gql.NQuad) {
   208  	sid := m.uid(nq.GetSubject())
   209  	var oid uint64
   210  	var de *pb.DirectedEdge
   211  	if nq.GetObjectValue() == nil {
   212  		oid = m.uid(nq.GetObjectId())
   213  		de = nq.CreateUidEdge(sid, oid)
   214  	} else {
   215  		var err error
   216  		de, err = nq.CreateValueEdge(sid)
   217  		x.Check(err)
   218  	}
   219  
   220  	fwd, rev := m.createPostings(nq, de)
   221  	shard := m.state.shards.shardFor(nq.Predicate)
   222  	key := x.DataKey(nq.Predicate, sid)
   223  	m.addMapEntry(key, fwd, shard)
   224  
   225  	if rev != nil {
   226  		key = x.ReverseKey(nq.Predicate, oid)
   227  		m.addMapEntry(key, rev, shard)
   228  	}
   229  	m.addIndexMapEntries(nq, de)
   230  }
   231  
   232  func (m *mapper) uid(xid string) uint64 {
   233  	if !m.opt.NewUids {
   234  		if uid, err := strconv.ParseUint(xid, 0, 64); err == nil {
   235  			m.xids.BumpTo(uid)
   236  			return uid
   237  		}
   238  	}
   239  
   240  	return m.lookupUid(xid)
   241  }
   242  
   243  func (m *mapper) lookupUid(xid string) uint64 {
   244  	uid := m.xids.AssignUid(xid)
   245  	if !m.opt.StoreXids {
   246  		return uid
   247  	}
   248  	if strings.HasPrefix(xid, "_:") {
   249  		// Don't store xids for blank nodes.
   250  		return uid
   251  	}
   252  	nq := gql.NQuad{NQuad: &api.NQuad{
   253  		Subject:   xid,
   254  		Predicate: "xid",
   255  		ObjectValue: &api.Value{
   256  			Val: &api.Value_StrVal{StrVal: xid},
   257  		},
   258  	}}
   259  	m.processNQuad(nq)
   260  	return uid
   261  }
   262  
   263  func (m *mapper) createPostings(nq gql.NQuad,
   264  	de *pb.DirectedEdge) (*pb.Posting, *pb.Posting) {
   265  
   266  	m.schema.validateType(de, nq.ObjectValue == nil)
   267  
   268  	p := posting.NewPosting(de)
   269  	sch := m.schema.getSchema(nq.GetPredicate())
   270  	if nq.GetObjectValue() != nil {
   271  		if lang := de.GetLang(); len(lang) > 0 {
   272  			p.Uid = farm.Fingerprint64([]byte(lang))
   273  		} else if sch.List {
   274  			p.Uid = farm.Fingerprint64(de.Value)
   275  		} else {
   276  			p.Uid = math.MaxUint64
   277  		}
   278  	}
   279  	p.Facets = nq.Facets
   280  
   281  	// Early exit for no reverse edge.
   282  	if sch.GetDirective() != pb.SchemaUpdate_REVERSE {
   283  		return p, nil
   284  	}
   285  
   286  	// Reverse predicate
   287  	x.AssertTruef(nq.GetObjectValue() == nil, "only has reverse schema if object is UID")
   288  	de.Entity, de.ValueId = de.ValueId, de.Entity
   289  	m.schema.validateType(de, true)
   290  	rp := posting.NewPosting(de)
   291  
   292  	de.Entity, de.ValueId = de.ValueId, de.Entity // de reused so swap back.
   293  
   294  	return p, rp
   295  }
   296  
   297  func (m *mapper) addIndexMapEntries(nq gql.NQuad, de *pb.DirectedEdge) {
   298  	if nq.GetObjectValue() == nil {
   299  		return // Cannot index UIDs
   300  	}
   301  
   302  	sch := m.schema.getSchema(nq.GetPredicate())
   303  	for _, tokerName := range sch.GetTokenizer() {
   304  		// Find tokeniser.
   305  		toker, ok := tok.GetTokenizer(tokerName)
   306  		if !ok {
   307  			log.Fatalf("unknown tokenizer %q", tokerName)
   308  		}
   309  
   310  		// Create storage value.
   311  		storageVal := types.Val{
   312  			Tid:   types.TypeID(de.GetValueType()),
   313  			Value: de.GetValue(),
   314  		}
   315  
   316  		// Convert from storage type to schema type.
   317  		schemaVal, err := types.Convert(storageVal, types.TypeID(sch.GetValueType()))
   318  		// Shouldn't error, since we've already checked for convertibility when
   319  		// doing edge postings. So okay to be fatal.
   320  		x.Check(err)
   321  
   322  		// Extract tokens.
   323  		toks, err := tok.BuildTokens(schemaVal.Value, tok.GetLangTokenizer(toker, nq.Lang))
   324  		x.Check(err)
   325  
   326  		// Store index posting.
   327  		for _, t := range toks {
   328  			m.addMapEntry(
   329  				x.IndexKey(nq.Predicate, t),
   330  				&pb.Posting{
   331  					Uid:         de.GetEntity(),
   332  					PostingType: pb.Posting_REF,
   333  				},
   334  				m.state.shards.shardFor(nq.Predicate),
   335  			)
   336  		}
   337  	}
   338  }