github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/ext/dsort/dsort_mem.go

// Package dsort provides distributed massively parallel resharding for very large datasets.
/*
 * Copyright (c) 2018-2023, NVIDIA CORPORATION. All rights reserved.
 */
package dsort

import (
	"context"
	"fmt"
	"io"
	"sync"

	"github.com/NVIDIA/aistore/cmn"
	"github.com/NVIDIA/aistore/cmn/atomic"
	"github.com/NVIDIA/aistore/cmn/cos"
	"github.com/NVIDIA/aistore/cmn/debug"
	"github.com/NVIDIA/aistore/cmn/nlog"
	"github.com/NVIDIA/aistore/core"
	"github.com/NVIDIA/aistore/core/meta"
	"github.com/NVIDIA/aistore/ext/dsort/shard"
	"github.com/NVIDIA/aistore/fs"
	"github.com/NVIDIA/aistore/memsys"
	"github.com/NVIDIA/aistore/sys"
	"github.com/NVIDIA/aistore/transport"
	"github.com/NVIDIA/aistore/transport/bundle"
	jsoniter "github.com/json-iterator/go"
	"github.com/pkg/errors"
	"golang.org/x/sync/errgroup"
)

// This dsorter implementation focuses on the creation phase and on maximizing
// memory usage during that phase. It uses an active push mechanism: instead of
// waiting for requests, a target sends all the record objects it holds for the
// shard that another target is building. Using this dsorter requires a lot of
// available memory - during the creation phase a target must be able to build
// an entire shard in memory; otherwise it would be easy to deadlock, since
// targets may send record objects in an arbitrary order.

const (
	MemType = "dsort_mem"
)
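// High-level flow of the creation phase, in brief (summary of the functions
// defined below):
//
//  1. preShardCreation broadcasts a buildingShardInfo message ("this target
//     is now building shard X") over the `builder` streams and enqueues the
//     shard name locally on requestedShards.
//  2. Each target, having learned about shard X (recvReq, or localRead for
//     the local queue), pushes the record objects it holds for X to the
//     target building it (dsmExtractShard -> connectOrSend), using the
//     `records` streams.
//  3. On the building target, recvResp hands every incoming record to the
//     rwConnector, which pairs it with the shard writer registered via Load.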
type (
	rwConnection struct {
		r   io.Reader
		wgr *cos.TimeoutGroup
		// In case the reader is the first to connect, the data is copied into
		// an SGL so that the reader does not block on the connection.
		sgl *memsys.SGL

		w   io.Writer
		wgw *sync.WaitGroup

		n int64
	}

	rwConnector struct {
		mu          sync.Mutex
		m           *Manager
		connections map[string]*rwConnection
	}

	dsorterMem struct {
		m       *Manager
		streams struct {
			cleanupDone atomic.Bool
			builder     *bundle.Streams // streams for sending information about building shards
			records     *bundle.Streams // streams for sending the records
		}
		creationPhase struct {
			connector       *rwConnector // used to connect readers (streams, local data) with writers (shards)
			requestedShards chan string

			adjuster struct {
				read  *concAdjuster
				write *concAdjuster
			}
		}
	}
)

// interface guard
var _ dsorter = (*dsorterMem)(nil)

func newRWConnection(r io.Reader, w io.Writer) *rwConnection {
	debug.Assert(r != nil || w != nil)
	wgr := cos.NewTimeoutGroup()
	wgr.Add(1)
	wgw := &sync.WaitGroup{}
	wgw.Add(1)
	return &rwConnection{
		r:   r,
		w:   w,
		wgr: wgr,
		wgw: wgw,
	}
}

func newRWConnector(m *Manager) *rwConnector {
	return &rwConnector{
		m:           m,
		connections: make(map[string]*rwConnection, 1000),
	}
}

func (c *rwConnector) free() {
	c.mu.Lock()
	for _, v := range c.connections {
		if v.sgl != nil {
			v.sgl.Free()
		}
	}
	c.mu.Unlock()
}

func (c *rwConnector) connect(key string, r io.Reader, w io.Writer) (rwc *rwConnection, all bool) {
	var ok bool

	if rwc, ok = c.connections[key]; !ok {
		rwc = newRWConnection(r, w)
		c.connections[key] = rwc
	} else {
		if rwc.r == nil {
			rwc.r = r
		}
		if rwc.w == nil {
			rwc.w = w
		}
	}
	all = rwc.r != nil && rwc.w != nil
	return
}

func (c *rwConnector) connectReader(key string, r io.Reader, size int64) (err error) {
	c.mu.Lock()
	rw, all := c.connect(key, r, nil)
	c.mu.Unlock()

	if !all {
		rw.sgl = g.mm.NewSGL(size)
		_, err = io.Copy(rw.sgl, r)
		rw.wgr.Done()
		return
	}

	rw.wgr.Done()
	rw.wgw.Wait()
	return
}

func (c *rwConnector) connectWriter(key string, w io.Writer) (int64, error) {
	c.mu.Lock()
	rw, all := c.connect(key, nil, w)
	c.mu.Unlock()
	defer rw.wgw.Done() // inform the reader that the copying has finished

	timed, stopped := rw.wgr.WaitTimeoutWithStop(c.m.callTimeout, c.m.listenAborted()) // wait for the reader
	if timed {
		return 0, errors.Errorf("%s: timed out waiting for remote content", core.T)
	}
	if stopped {
		return 0, errors.Errorf("%s: aborted waiting for remote content", core.T)
	}

	if all { // reader connected and left an SGL with the content
		n, err := io.Copy(rw.w, rw.sgl)
		rw.sgl.Free()
		rw.sgl = nil
		return n, err
	}

	n, err := io.CopyBuffer(rw.w, rw.r, nil)
	rw.n = n
	return n, err
}
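// To make the rendezvous above concrete - a minimal usage sketch (names are
// illustrative): the push side and the shard-building side call into the
// connector with the same key, in either order:
//
//	uname := rec.MakeUniqueName(obj)
//
//	// push side (recvResp, or connectOrSend for local records):
//	err := connector.connectReader(uname, objReader, size)
//
//	// shard-building side (Load):
//	n, err := connector.connectWriter(uname, shardWriter)
//
// If the reader arrives first, connectReader drains it into an SGL so the
// transport stream is not blocked; connectWriter then copies from that SGL.
// If the writer arrives first, it waits (with timeout and abort handling) on
// wgr and then copies directly from the reader.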
func newDsorterMem(m *Manager) *dsorterMem {
	return &dsorterMem{
		m: m,
	}
}

func (*dsorterMem) name() string { return MemType }

func (ds *dsorterMem) init() error {
	ds.creationPhase.connector = newRWConnector(ds.m)
	ds.creationPhase.requestedShards = make(chan string, 10000)

	ds.creationPhase.adjuster.read = newConcAdjuster(
		ds.m.Pars.CreateConcMaxLimit,
		1, /*goroutineLimitCoef*/
	)
	ds.creationPhase.adjuster.write = newConcAdjuster(
		ds.m.Pars.CreateConcMaxLimit,
		1, /*goroutineLimitCoef*/
	)
	return nil
}

func (ds *dsorterMem) start() error {
	// Requests are usually small packets, no more than 1KB, which is why we
	// want to use the intraControl network.
	config := cmn.GCO.Get()
	reqNetwork := cmn.NetIntraControl
	// Responses to the other targets are objects, which is why we want to use
	// the intraData network.
	respNetwork := cmn.NetIntraData

	client := transport.NewIntraDataClient()

	trname := fmt.Sprintf(recvReqStreamNameFmt, ds.m.ManagerUUID)
	reqSbArgs := bundle.Args{
		Multiplier: ds.m.Pars.SbundleMult,
		Net:        reqNetwork,
		Trname:     trname,
		Ntype:      core.Targets,
		Extra: &transport.Extra{
			Config: config,
		},
	}
	if err := transport.Handle(trname, ds.recvReq); err != nil {
		return errors.WithStack(err)
	}

	trname = fmt.Sprintf(recvRespStreamNameFmt, ds.m.ManagerUUID)
	respSbArgs := bundle.Args{
		Multiplier: ds.m.Pars.SbundleMult,
		Net:        respNetwork,
		Trname:     trname,
		Ntype:      core.Targets,
		Extra: &transport.Extra{
			Compression: config.Dsort.Compression,
			Config:      config,
		},
	}
	if err := transport.Handle(trname, ds.recvResp); err != nil {
		return errors.WithStack(err)
	}

	ds.streams.builder = bundle.New(client, reqSbArgs)
	ds.streams.records = bundle.New(client, respSbArgs)
	return nil
}
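// Stream wiring in start(), summarized: the `builder` bundle carries small
// control messages (packed buildingShardInfo, handled by recvReq) over the
// intraControl network, while the `records` bundle carries record payloads
// (RemoteResponse header plus object data, handled by recvResp) over the
// intraData network, optionally compressed per config.Dsort.Compression.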
func (ds *dsorterMem) cleanupStreams() (err error) {
	if !ds.streams.cleanupDone.CAS(false, true) {
		return nil
	}

	if ds.streams.builder != nil {
		trname := fmt.Sprintf(recvReqStreamNameFmt, ds.m.ManagerUUID)
		if unhandleErr := transport.Unhandle(trname); unhandleErr != nil {
			err = errors.WithStack(unhandleErr)
		}
	}

	if ds.streams.records != nil {
		trname := fmt.Sprintf(recvRespStreamNameFmt, ds.m.ManagerUUID)
		if unhandleErr := transport.Unhandle(trname); unhandleErr != nil {
			err = errors.WithStack(unhandleErr)
		}
	}

	for _, streamBundle := range []*bundle.Streams{ds.streams.builder, ds.streams.records} {
		if streamBundle != nil {
			// NOTE: We don't want the stream to send a message at this point, as the
			// receiver might have already closed its corresponding stream.
			streamBundle.Close(false /*gracefully*/)
		}
	}

	return err
}

func (*dsorterMem) cleanup() {}

func (ds *dsorterMem) finalCleanup() error {
	err := ds.cleanupStreams()
	close(ds.creationPhase.requestedShards)
	ds.creationPhase.connector.free()
	ds.creationPhase.connector = nil
	return err
}

func (*dsorterMem) postRecordDistribution() {}

func (ds *dsorterMem) preShardCreation(shardName string, mi *fs.Mountpath) error {
	bsi := &buildingShardInfo{
		shardName: shardName,
	}
	o := transport.AllocSend()
	o.Hdr.Opaque = bsi.NewPack(core.T.ByteMM())
	if ds.m.smap.HasActiveTs(core.T.SID() /*except*/) {
		if err := ds.streams.builder.Send(o, nil); err != nil {
			return err
		}
	}
	ds.creationPhase.requestedShards <- shardName // we also need to inform ourselves
	ds.creationPhase.adjuster.write.acquireSema(mi)
	return nil
}

func (ds *dsorterMem) postShardCreation(mi *fs.Mountpath) {
	ds.creationPhase.adjuster.write.releaseSema(mi)
}

func (ds *dsorterMem) Load(w io.Writer, rec *shard.Record, obj *shard.RecordObj) (int64, error) {
	if ds.m.aborted() {
		return 0, ds.m.newErrAborted()
	}
	return ds.creationPhase.connector.connectWriter(rec.MakeUniqueName(obj), w)
}

// createShardsLocally waits until it's given the signal to start creating
// shards, then creates shards in parallel.
func (ds *dsorterMem) createShardsLocally() error {
	phaseInfo := &ds.m.creationPhase

	ds.creationPhase.adjuster.read.start()
	ds.creationPhase.adjuster.write.start()

	metrics := ds.m.Metrics.Creation
	metrics.begin()
	metrics.mu.Lock()
	metrics.ToCreate = int64(len(phaseInfo.metadata.Shards))
	metrics.mu.Unlock()

	var (
		mem    sys.MemStat
		wg     = &sync.WaitGroup{}
		errCh  = make(chan error, 2)
		stopCh = &cos.StopCh{}
	)
	stopCh.Init()

	// cleanup
	defer func(metrics *ShardCreation, stopCh *cos.StopCh) {
		stopCh.Close()
		metrics.finish()
		ds.creationPhase.adjuster.write.stop()
		ds.creationPhase.adjuster.read.stop()
	}(metrics, stopCh)

	if err := mem.Get(); err != nil {
		return err
	}
	maxMemoryToUse := calcMaxMemoryUsage(ds.m.Pars.MaxMemUsage, &mem)
	sa := newInmemShardAllocator(maxMemoryToUse - mem.ActualUsed)

	// read
	wg.Add(1)
	go func() {
		ds.localRead(stopCh, errCh)
		wg.Done()
	}()

	// write
	wg.Add(1)
	go func() {
		ds.localWrite(sa, stopCh, errCh)
		wg.Done()
	}()

	wg.Wait()

	close(errCh)
	for err := range errCh {
		if err != nil {
			return err
		}
	}
	return nil
}

func (ds *dsorterMem) localRead(stopCh *cos.StopCh, errCh chan error) {
	var (
		phaseInfo  = &ds.m.creationPhase
		group, ctx = errgroup.WithContext(context.Background())
	)
outer:
	for {
		// If this was the last shard to send, break out and wait for the result.
		if len(phaseInfo.metadata.SendOrder) == 0 {
			break outer
		}

		select {
		case shardName := <-ds.creationPhase.requestedShards:
			shard, ok := phaseInfo.metadata.SendOrder[shardName]
			if !ok {
				break
			}

			ds.creationPhase.adjuster.read.acquireGoroutineSema()
			es := &dsmExtractShard{ds, shard}
			group.Go(es.do)

			delete(phaseInfo.metadata.SendOrder, shardName)
		case <-ds.m.listenAborted():
			stopCh.Close()
			group.Wait()
			errCh <- ds.m.newErrAborted()
			return
		case <-ctx.Done(): // context was canceled, therefore we have an error
			stopCh.Close()
			break outer
		case <-stopCh.Listen(): // the writing side stopped, so we stop too
			break outer
		}
	}

	errCh <- group.Wait()
}

func (ds *dsorterMem) localWrite(sa *inmemShardAllocator, stopCh *cos.StopCh, errCh chan error) {
	var (
		phaseInfo  = &ds.m.creationPhase
		group, ctx = errgroup.WithContext(context.Background())
	)
outer:
	for _, s := range phaseInfo.metadata.Shards {
		select {
		case <-ds.m.listenAborted():
			stopCh.Close()
			group.Wait()
			errCh <- ds.m.newErrAborted()
			return
		case <-ctx.Done(): // context was canceled, therefore we have an error
			stopCh.Close()
			break outer
		case <-stopCh.Listen():
			break outer // the reading side stopped, so we stop too
		default:
		}

		sa.alloc(uint64(s.Size))

		ds.creationPhase.adjuster.write.acquireGoroutineSema()
		cs := &dsmCreateShard{ds, s, sa}
		group.Go(cs.do)
	}

	errCh <- group.Wait()
}
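// The read/write halves above follow the same coordination pattern: each side
// runs its own errgroup, and the shared stopCh propagates failure or abort
// from one side to the other. A minimal sketch of that pattern (simplified;
// `tasks` and `task.do` are illustrative names only):
//
//	group, ctx := errgroup.WithContext(context.Background())
//	for _, task := range tasks {
//		select {
//		case <-ctx.Done(): // a previously spawned task failed
//			stopCh.Close() // tell the other side to wind down
//			return group.Wait()
//		case <-stopCh.Listen(): // the other side stopped
//			return group.Wait()
//		default:
//		}
//		group.Go(task.do)
//	}
//	return group.Wait()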
func (ds *dsorterMem) connectOrSend(rec *shard.Record, obj *shard.RecordObj, tsi *meta.Snode) error {
	debug.Assert(core.T.SID() == rec.DaemonID, core.T.SID()+" vs "+rec.DaemonID)
	var (
		resp = &dsmCS{
			ds:  ds,
			tsi: tsi,
			rsp: RemoteResponse{Record: rec, RecordObj: obj},
		}
		fullContentPath = ds.m.recm.FullContentPath(obj)
	)
	ct, err := core.NewCTFromBO(&ds.m.Pars.OutputBck, fullContentPath, nil)
	ds.creationPhase.adjuster.read.acquireSema(ct.Mountpath())
	defer func() {
		if !resp.decRef {
			ds.m.decrementRef(1)
		}
		ds.creationPhase.adjuster.read.releaseSema(ct.Mountpath())
	}()

	if err != nil {
		return err
	}
	if ds.m.aborted() {
		return ds.m.newErrAborted()
	}

	resp.hdr.Opaque = cos.MustMarshal(resp.rsp)
	if ds.m.Pars.DryRun {
		lr := cos.NopReader(obj.MetadataSize + obj.Size)
		r := cos.NopOpener(io.NopCloser(lr))
		resp.hdr.ObjAttrs.Size = obj.MetadataSize + obj.Size
		return resp.connectOrSend(r)
	}

	switch obj.StoreType {
	case shard.OffsetStoreType:
		resp.hdr.ObjAttrs.Size = obj.MetadataSize + obj.Size
		r, err := cos.NewFileSectionHandle(fullContentPath, obj.Offset-obj.MetadataSize, resp.hdr.ObjAttrs.Size)
		if err != nil {
			return err
		}
		return resp.connectOrSend(r)
	case shard.DiskStoreType:
		f, err := cos.NewFileHandle(fullContentPath)
		if err != nil {
			return err
		}
		fi, err := f.Stat()
		if err != nil {
			cos.Close(f)
			return err
		}
		resp.hdr.ObjAttrs.Size = fi.Size()
		return resp.connectOrSend(f)
	default:
		debug.Assert(false, obj.StoreType)
		return nil
	}
}

func (ds *dsorterMem) sentCallback(_ *transport.ObjHdr, rc io.ReadCloser, x any, err error) {
	if sgl, ok := rc.(*memsys.SGL); ok {
		sgl.Free()
	}
	ds.m.decrementRef(1)
	if err != nil {
		req := x.(*RemoteResponse)
		nlog.Errorf("%s: [dsort] %s failed to send remote-rsp %s: %v - aborting...",
			core.T, ds.m.ManagerUUID, req.Record.MakeUniqueName(req.RecordObj), err)
		ds.m.abort(err)
	}
}

func (*dsorterMem) postExtraction() {}

// implements receiver i/f
func (ds *dsorterMem) recvReq(hdr *transport.ObjHdr, objReader io.Reader, err error) error {
	ds.m.inFlightInc()
	defer func() {
		transport.DrainAndFreeReader(objReader)
		ds.m.inFlightDec()
	}()

	if err != nil {
		ds.m.abort(err)
		return err
	}

	unpacker := cos.NewUnpacker(hdr.Opaque)
	req := buildingShardInfo{}
	if err := unpacker.ReadAny(&req); err != nil {
		ds.m.abort(err)
		return err
	}

	if ds.m.aborted() {
		return ds.m.newErrAborted()
	}

	ds.creationPhase.requestedShards <- req.shardName
	return nil
}

func (ds *dsorterMem) recvResp(hdr *transport.ObjHdr, object io.Reader, err error) error {
	ds.m.inFlightInc()
	defer func() {
		transport.DrainAndFreeReader(object)
		ds.m.inFlightDec()
	}()

	if err != nil {
		ds.m.abort(err)
		return err
	}

	req := RemoteResponse{}
	if err := jsoniter.Unmarshal(hdr.Opaque, &req); err != nil {
		ds.m.abort(err)
		return err
	}

	if ds.m.aborted() {
		return ds.m.newErrAborted()
	}

	uname := req.Record.MakeUniqueName(req.RecordObj)
	if err := ds.creationPhase.connector.connectReader(uname, object, hdr.ObjAttrs.Size); err != nil {
		ds.m.abort(err)
		return err
	}

	return nil
}
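// Both receive handlers above bracket their work with inFlightInc/inFlightDec
// and drain-and-free the transport reader on the way out - the counter lets
// the manager account for messages still being processed (e.g., when
// aborting), and draining returns the transport's slab buffers.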
func (*dsorterMem) preShardExtraction(uint64) bool { return true }
func (*dsorterMem) postShardExtraction(uint64)     {}
func (ds *dsorterMem) onAbort()                    { _ = ds.cleanupStreams() }

////////////////////
// dsmCreateShard //
////////////////////

type dsmCreateShard struct {
	ds    *dsorterMem
	shard *shard.Shard
	sa    *inmemShardAllocator
}

func (cs *dsmCreateShard) do() (err error) {
	lom := core.AllocLOM(cs.shard.Name)
	err = cs.ds.m.createShard(cs.shard, lom)
	core.FreeLOM(lom)
	cs.ds.creationPhase.adjuster.write.releaseGoroutineSema()
	cs.sa.free(uint64(cs.shard.Size))
	return
}

/////////////////////
// dsmExtractShard //
/////////////////////

type dsmExtractShard struct {
	ds    *dsorterMem
	shard *shard.Shard
}

func (es *dsmExtractShard) do() error {
	ds, shard := es.ds, es.shard
	defer ds.creationPhase.adjuster.read.releaseGoroutineSema()

	bck := meta.NewBck(ds.m.Pars.OutputBck.Name, ds.m.Pars.OutputBck.Provider, cmn.NsGlobal)
	if err := bck.Init(core.T.Bowner()); err != nil {
		return err
	}
	smap := core.T.Sowner().Get()
	tsi, err := smap.HrwName2T(bck.MakeUname(shard.Name))
	if err != nil {
		return err
	}
	for _, rec := range shard.Records.All() {
		for _, obj := range rec.Objects {
			if err := ds.connectOrSend(rec, obj, tsi); err != nil {
				return err
			}
		}
	}
	return nil
}

///////////
// dsmCS //
///////////

type dsmCS struct {
	ds     *dsorterMem
	tsi    *meta.Snode
	rsp    RemoteResponse
	hdr    transport.ObjHdr
	decRef bool
}

func (resp *dsmCS) connectOrSend(r cos.ReadOpenCloser) (err error) {
	if resp.tsi.ID() == core.T.SID() {
		uname := resp.rsp.Record.MakeUniqueName(resp.rsp.RecordObj)
		err = resp.ds.creationPhase.connector.connectReader(uname, r, resp.hdr.ObjAttrs.Size)
		cos.Close(r)
	} else {
		o := transport.AllocSend()
		o.Hdr = resp.hdr
		o.Callback, o.CmplArg = resp.ds.sentCallback, &resp.rsp
		err = resp.ds.streams.records.Send(o, r, resp.tsi)
		resp.decRef = true // sentCallback will call decrementRef
	}
	return
}
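// A note on reference counting in connectOrSend/dsmCS: for a remote send,
// resp.decRef is set and the deferred decrementRef in dsorterMem.connectOrSend
// is skipped - sentCallback performs the decrement (and frees the SGL, if any)
// once the transport completes. In the local case the content is consumed
// synchronously by connectReader, so the defer decrements right away.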