github.com/pachyderm/pachyderm@v1.13.4/src/client/pkg/shard/sharder.go (about) 1 package shard 2 3 import ( 4 "fmt" 5 "math" 6 "path" 7 "sort" 8 "strings" 9 "sync" 10 "time" 11 12 "github.com/gogo/protobuf/jsonpb" 13 "github.com/pachyderm/pachyderm/src/client/pkg/discovery" 14 "github.com/pachyderm/pachyderm/src/client/pkg/errors" 15 log "github.com/sirupsen/logrus" 16 17 "golang.org/x/sync/errgroup" 18 ) 19 20 // InvalidVersion is defined as -1 since valid versions are non-negative. 21 const InvalidVersion int64 = -1 22 23 var ( 24 holdTTL uint64 = 20 25 marshaler = &jsonpb.Marshaler{} 26 // ErrCancelled is returned when an action is cancelled by the user 27 ErrCancelled = errors.Errorf("cancelled by user") 28 errComplete = errors.Errorf("COMPLETE") 29 ) 30 31 type sharder struct { 32 discoveryClient discovery.Client 33 numShards uint64 34 namespace string 35 addresses map[int64]*Addresses 36 addressesLock sync.RWMutex 37 } 38 39 func newSharder(discoveryClient discovery.Client, numShards uint64, namespace string) *sharder { 40 return &sharder{discoveryClient, numShards, namespace, make(map[int64]*Addresses), sync.RWMutex{}} 41 } 42 43 func (a *sharder) GetAddress(shard uint64, version int64) (result string, ok bool, retErr error) { 44 addresses, err := a.getAddresses(version) 45 if err != nil { 46 return "", false, err 47 } 48 address, ok := addresses.Addresses[shard] 49 if !ok { 50 return "", false, nil 51 } 52 return address, true, nil 53 } 54 55 func (a *sharder) GetShardToAddress(version int64) (result map[uint64]string, retErr error) { 56 addresses, err := a.getAddresses(version) 57 if err != nil { 58 return nil, err 59 } 60 _result := make(map[uint64]string) 61 for shard, address := range addresses.Addresses { 62 _result[shard] = address 63 } 64 return _result, nil 65 } 66 67 func (a *sharder) Register(address string, servers []Server) (retErr error) { 68 var once sync.Once 69 versionChan := make(chan int64) 70 internalCancel := make(chan bool) 71 var wg sync.WaitGroup 72 wg.Add(3) 73 go func() { 74 defer wg.Done() 75 if err := a.announceServers(address, servers, versionChan, internalCancel); err != nil { 76 once.Do(func() { 77 retErr = err 78 close(internalCancel) 79 }) 80 } 81 }() 82 go func() { 83 defer wg.Done() 84 if err := a.fillRoles(address, servers, versionChan, internalCancel); err != nil { 85 once.Do(func() { 86 retErr = err 87 close(internalCancel) 88 }) 89 } 90 }() 91 go func() { 92 defer wg.Done() 93 <-internalCancel 94 }() 95 wg.Wait() 96 return 97 } 98 99 func (a *sharder) RegisterFrontends(address string, frontends []Frontend) (retErr error) { 100 var once sync.Once 101 versionChan := make(chan int64) 102 internalCancel := make(chan bool) 103 var wg sync.WaitGroup 104 wg.Add(3) 105 go func() { 106 defer wg.Done() 107 if err := a.announceFrontends(address, frontends, versionChan, internalCancel); err != nil { 108 once.Do(func() { 109 retErr = err 110 close(internalCancel) 111 }) 112 } 113 }() 114 go func() { 115 defer wg.Done() 116 if err := a.runFrontends(address, frontends, versionChan, internalCancel); err != nil { 117 once.Do(func() { 118 retErr = err 119 close(internalCancel) 120 }) 121 } 122 }() 123 go func() { 124 defer wg.Done() 125 <-internalCancel 126 }() 127 wg.Wait() 128 return 129 } 130 131 func (a *sharder) AssignRoles(address string) (retErr error) { 132 var unsafeAssignRolesCancel chan bool 133 errChan := make(chan error) 134 // oldValue is the last value we wrote, if it's not "" it means we have the 135 // lock since we're the ones who set it last 136 oldValue := "" 137 for { 138 if err := a.discoveryClient.CheckAndSet("lock", address, holdTTL, oldValue); err != nil { 139 if oldValue != "" { 140 // lock lost 141 oldValue = "" 142 close(unsafeAssignRolesCancel) 143 log.Errorf("sharder.AssignRoles error from unsafeAssignRolesCancel: %+v", <-errChan) 144 } 145 } else { 146 if oldValue == "" { 147 // lock acquired 148 oldValue = address 149 unsafeAssignRolesCancel = make(chan bool) 150 go func() { 151 errChan <- a.unsafeAssignRoles(unsafeAssignRolesCancel) 152 }() 153 } 154 } 155 <-time.After(time.Second * time.Duration(holdTTL/2)) 156 } 157 } 158 159 // unsafeAssignRoles should be run 160 func (a *sharder) unsafeAssignRoles(cancel chan bool) (retErr error) { 161 var version int64 162 oldServers := make(map[string]bool) 163 oldRoles := make(map[string]*ServerRole) 164 oldShards := make(map[uint64]string) 165 var oldMinVersion int64 166 // Reconstruct state from a previous run 167 serverRoles, err := a.discoveryClient.GetAll(a.serverRoleDir()) 168 if err != nil { 169 return err 170 } 171 for _, encodedServerRole := range serverRoles { 172 serverRole, err := decodeServerRole(encodedServerRole) 173 if err != nil { 174 return err 175 } 176 if oldServerRole, ok := oldRoles[serverRole.Address]; !ok || oldServerRole.Version < serverRole.Version { 177 oldRoles[serverRole.Address] = serverRole 178 oldServers[serverRole.Address] = true 179 } 180 if version < serverRole.Version+1 { 181 version = serverRole.Version + 1 182 } 183 } 184 for _, oldServerRole := range oldRoles { 185 for shard := range oldServerRole.Shards { 186 oldShards[shard] = oldServerRole.Address 187 } 188 } 189 err = a.discoveryClient.WatchAll(a.serverStateDir(), cancel, 190 func(encodedServerStates map[string]string) error { 191 if len(encodedServerStates) == 0 { 192 return nil 193 } 194 newServerStates := make(map[string]*ServerState) 195 newRoles := make(map[string]*ServerRole) 196 newShards := make(map[uint64]string) 197 shardsPerServer := a.numShards / uint64(len(encodedServerStates)) 198 shardsRemainder := a.numShards % uint64(len(encodedServerStates)) 199 for _, encodedServerState := range encodedServerStates { 200 serverState, err := decodeServerState(encodedServerState) 201 if err != nil { 202 return err 203 } 204 newServerStates[serverState.Address] = serverState 205 newRoles[serverState.Address] = &ServerRole{ 206 Address: serverState.Address, 207 Version: version, 208 Shards: make(map[uint64]bool), 209 } 210 } 211 // See if there's any roles we can delete 212 minVersion := int64(math.MaxInt64) 213 for _, serverState := range newServerStates { 214 if serverState.Version < minVersion { 215 minVersion = serverState.Version 216 } 217 } 218 // Delete roles that no servers are using anymore 219 if minVersion > oldMinVersion { 220 oldMinVersion = minVersion 221 if err := a.discoveryClient.WatchAll( 222 a.frontendStateDir(), 223 cancel, 224 func(encodedFrontendStates map[string]string) error { 225 for _, encodedFrontendState := range encodedFrontendStates { 226 frontendState, err := decodeFrontendState(encodedFrontendState) 227 if err != nil { 228 return err 229 } 230 if frontendState.Version < minVersion { 231 return nil 232 } 233 } 234 return errComplete 235 }); err != nil && !errors.Is(err, errComplete) { 236 return err 237 } 238 serverRoles, err := a.discoveryClient.GetAll(a.serverRoleDir()) 239 if err != nil { 240 return err 241 } 242 for key, encodedServerRole := range serverRoles { 243 serverRole, err := decodeServerRole(encodedServerRole) 244 if err != nil { 245 return err 246 } 247 if serverRole.Version < minVersion { 248 if err := a.discoveryClient.Delete(key); err != nil { 249 return err 250 } 251 } 252 } 253 } 254 // if the servers are identical to last time then we know we'll 255 // assign shards the same way 256 if sameServers(oldServers, newServerStates) { 257 return nil 258 } 259 Shard: 260 for shard := uint64(0); shard < a.numShards; shard++ { 261 if address, ok := oldShards[shard]; ok { 262 if assignShard(newRoles, newShards, address, shard, shardsPerServer, &shardsRemainder) { 263 continue Shard 264 } 265 } 266 for address := range newServerStates { 267 if assignShard(newRoles, newShards, address, shard, shardsPerServer, &shardsRemainder) { 268 continue Shard 269 } 270 } 271 log.Error(&FailedToAssignRoles{ 272 ServerStates: newServerStates, 273 NumShards: a.numShards, 274 }) 275 return nil 276 } 277 addresses := Addresses{ 278 Version: version, 279 Addresses: make(map[uint64]string), 280 } 281 for address, serverRole := range newRoles { 282 encodedServerRole, err := marshaler.MarshalToString(serverRole) 283 if err != nil { 284 return err 285 } 286 if err := a.discoveryClient.Set(a.serverRoleKeyVersion(address, version), encodedServerRole, 0); err != nil { 287 return err 288 } 289 address := newServerStates[address].Address 290 for shard := range serverRole.Shards { 291 addresses.Addresses[shard] = address 292 } 293 } 294 encodedAddresses, err := marshaler.MarshalToString(&addresses) 295 if err != nil { 296 return err 297 } 298 if err := a.discoveryClient.Set(a.addressesKey(version), encodedAddresses, 0); err != nil { 299 return err 300 } 301 version++ 302 oldServers = make(map[string]bool) 303 for address := range newServerStates { 304 oldServers[address] = true 305 } 306 oldRoles = newRoles 307 oldShards = newShards 308 return nil 309 }) 310 if errors.Is(err, discovery.ErrCancelled) { 311 return ErrCancelled 312 } 313 return err 314 } 315 316 func (a *sharder) WaitForAvailability(frontendAddresses []string, serverAddresses []string) error { 317 version := InvalidVersion 318 if err := a.discoveryClient.WatchAll(a.serverDir(), nil, 319 func(encodedServerStatesAndRoles map[string]string) error { 320 serverStates := make(map[string]*ServerState) 321 serverRoles := make(map[string]map[int64]*ServerRole) 322 for key, encodedServerStateOrRole := range encodedServerStatesAndRoles { 323 if strings.HasPrefix(key, a.serverStateDir()) { 324 serverState, err := decodeServerState(encodedServerStateOrRole) 325 if err != nil { 326 return err 327 } 328 serverStates[serverState.Address] = serverState 329 } 330 if strings.HasPrefix(key, a.serverRoleDir()) { 331 serverRole, err := decodeServerRole(encodedServerStateOrRole) 332 if err != nil { 333 return err 334 } 335 if _, ok := serverRoles[serverRole.Address]; !ok { 336 serverRoles[serverRole.Address] = make(map[int64]*ServerRole) 337 } 338 serverRoles[serverRole.Address][serverRole.Version] = serverRole 339 } 340 } 341 if len(serverStates) != len(serverAddresses) { 342 return nil 343 } 344 if len(serverRoles) != len(serverAddresses) { 345 return nil 346 } 347 for _, address := range serverAddresses { 348 if _, ok := serverStates[address]; !ok { 349 return nil 350 } 351 if _, ok := serverRoles[address]; !ok { 352 return nil 353 } 354 } 355 versions := make(map[int64]bool) 356 for _, serverState := range serverStates { 357 if serverState.Version == InvalidVersion { 358 return nil 359 } 360 versions[serverState.Version] = true 361 } 362 if len(versions) != 1 { 363 return nil 364 } 365 for _, versionToServerRole := range serverRoles { 366 if len(versionToServerRole) != 1 { 367 return nil 368 } 369 for version := range versionToServerRole { 370 if !versions[version] { 371 return nil 372 } 373 } 374 } 375 // This loop actually does something, it sets the outside 376 // version variable. 377 for version = range versions { 378 } 379 return errComplete 380 }); !errors.Is(err, errComplete) { 381 return err 382 } 383 384 if err := a.discoveryClient.WatchAll( 385 a.frontendStateDir(), 386 nil, 387 func(encodedFrontendStates map[string]string) error { 388 frontendStates := make(map[string]*FrontendState) 389 for _, encodedFrontendState := range encodedFrontendStates { 390 frontendState, err := decodeFrontendState(encodedFrontendState) 391 if err != nil { 392 return err 393 } 394 395 if frontendState.Version != version { 396 return nil 397 } 398 frontendStates[frontendState.Address] = frontendState 399 } 400 if len(frontendStates) != len(frontendAddresses) { 401 return nil 402 } 403 for _, address := range frontendAddresses { 404 if _, ok := frontendStates[address]; !ok { 405 return nil 406 } 407 } 408 return errComplete 409 }); err != nil && !errors.Is(err, errComplete) { 410 return err 411 } 412 return nil 413 } 414 415 type localSharder struct { 416 shardToAddress map[uint64]string 417 } 418 419 func newLocalSharder(addresses []string, numShards uint64) *localSharder { 420 result := &localSharder{shardToAddress: make(map[uint64]string)} 421 for i := uint64(0); i < numShards; i++ { 422 result.shardToAddress[i] = addresses[int(i)%len(addresses)] 423 } 424 return result 425 } 426 427 func (s *localSharder) GetAddress(shard uint64, version int64) (string, bool, error) { 428 address, ok := s.shardToAddress[shard] 429 return address, ok, nil 430 } 431 432 func (s *localSharder) GetShardToAddress(version int64) (map[uint64]string, error) { 433 return s.shardToAddress, nil 434 } 435 436 func (s *localSharder) Register(address string, servers []Server) error { 437 return nil 438 } 439 440 func (s *localSharder) RegisterFrontends(address string, frontends []Frontend) error { 441 return nil 442 } 443 444 func (s *localSharder) AssignRoles(string) error { 445 return nil 446 } 447 448 func (a *sharder) routeDir() string { 449 return fmt.Sprintf("%s/pfs/route", a.namespace) 450 } 451 452 func (a *sharder) serverDir() string { 453 return path.Join(a.routeDir(), "server") 454 } 455 456 func (a *sharder) serverStateDir() string { 457 return path.Join(a.serverDir(), "state") 458 } 459 460 func (a *sharder) serverStateKey(address string) string { 461 return path.Join(a.serverStateDir(), address) 462 } 463 464 func (a *sharder) serverRoleDir() string { 465 return path.Join(a.serverDir(), "role") 466 } 467 468 func (a *sharder) serverRoleKey(address string) string { 469 return path.Join(a.serverRoleDir(), address) 470 } 471 472 func (a *sharder) serverRoleKeyVersion(address string, version int64) string { 473 return path.Join(a.serverRoleKey(address), fmt.Sprint(version)) 474 } 475 476 func (a *sharder) frontendDir() string { 477 return path.Join(a.routeDir(), "frontend") 478 } 479 480 func (a *sharder) frontendStateDir() string { 481 return path.Join(a.frontendDir(), "state") 482 } 483 484 func (a *sharder) frontendStateKey(address string) string { 485 return path.Join(a.frontendStateDir(), address) 486 } 487 488 func (a *sharder) addressesDir() string { 489 return path.Join(a.routeDir(), "addresses") 490 } 491 492 func (a *sharder) addressesKey(version int64) string { 493 return path.Join(a.addressesDir(), fmt.Sprint(version)) 494 } 495 496 func decodeServerState(encodedServerState string) (*ServerState, error) { 497 var serverState ServerState 498 if err := jsonpb.UnmarshalString(encodedServerState, &serverState); err != nil { 499 return nil, err 500 } 501 return &serverState, nil 502 } 503 504 func decodeFrontendState(encodedFrontendState string) (*FrontendState, error) { 505 var frontendState FrontendState 506 if err := jsonpb.UnmarshalString(encodedFrontendState, &frontendState); err != nil { 507 return nil, err 508 } 509 return &frontendState, nil 510 } 511 512 func decodeServerRole(encodedServerRole string) (*ServerRole, error) { 513 var serverRole ServerRole 514 if err := jsonpb.UnmarshalString(encodedServerRole, &serverRole); err != nil { 515 return nil, err 516 } 517 return &serverRole, nil 518 } 519 520 func (a *sharder) getAddresses(version int64) (*Addresses, error) { 521 if version == InvalidVersion { 522 return nil, errors.Errorf("invalid version") 523 } 524 a.addressesLock.RLock() 525 if addresses, ok := a.addresses[version]; ok { 526 a.addressesLock.RUnlock() 527 return addresses, nil 528 } 529 a.addressesLock.RUnlock() 530 a.addressesLock.Lock() 531 defer a.addressesLock.Unlock() 532 encodedAddresses, err := a.discoveryClient.Get(a.addressesKey(version)) 533 if err != nil { 534 return nil, err 535 } 536 var addresses Addresses 537 if err := jsonpb.UnmarshalString(encodedAddresses, &addresses); err != nil { 538 return nil, err 539 } 540 a.addresses[version] = &addresses 541 return &addresses, nil 542 } 543 544 func hasShard(serverRole *ServerRole, shard uint64) bool { 545 return serverRole.Shards[shard] 546 } 547 548 func assignShard( 549 serverRoles map[string]*ServerRole, 550 shards map[uint64]string, 551 address string, 552 shard uint64, 553 shardsPerServer uint64, 554 shardsRemainder *uint64, 555 ) bool { 556 serverRole, ok := serverRoles[address] 557 if !ok { 558 return false 559 } 560 if uint64(len(serverRole.Shards)) > shardsPerServer { 561 return false 562 } 563 if uint64(len(serverRole.Shards)) == shardsPerServer && *shardsRemainder == 0 { 564 return false 565 } 566 if hasShard(serverRole, shard) { 567 return false 568 } 569 if uint64(len(serverRole.Shards)) == shardsPerServer && *shardsRemainder > 0 { 570 *shardsRemainder-- 571 } 572 serverRole.Shards[shard] = true 573 serverRoles[address] = serverRole 574 shards[shard] = address 575 return true 576 } 577 578 func (a *sharder) announceServers( 579 address string, 580 servers []Server, 581 versionChan chan int64, 582 cancel chan bool, 583 ) error { 584 serverState := &ServerState{ 585 Address: address, 586 Version: InvalidVersion, 587 } 588 for { 589 encodedServerState, err := marshaler.MarshalToString(serverState) 590 if err != nil { 591 return err 592 } 593 if err := a.discoveryClient.Set(a.serverStateKey(address), encodedServerState, holdTTL); err != nil { 594 log.Errorf("Error setting server state: %s", err.Error()) 595 } 596 select { 597 case <-cancel: 598 return nil 599 case version := <-versionChan: 600 serverState.Version = version 601 case <-time.After(time.Second * time.Duration(holdTTL/2)): 602 } 603 } 604 } 605 606 func (a *sharder) announceFrontends( 607 address string, 608 frontends []Frontend, 609 versionChan chan int64, 610 cancel chan bool, 611 ) error { 612 frontendState := &FrontendState{ 613 Address: address, 614 Version: InvalidVersion, 615 } 616 for { 617 encodedFrontendState, err := marshaler.MarshalToString(frontendState) 618 if err != nil { 619 return err 620 } 621 if err := a.discoveryClient.Set(a.frontendStateKey(address), encodedFrontendState, holdTTL); err != nil { 622 log.Errorf("Error setting server state: %s", err.Error()) 623 } 624 select { 625 case <-cancel: 626 return nil 627 case version := <-versionChan: 628 frontendState.Version = version 629 case <-time.After(time.Second * time.Duration(holdTTL/2)): 630 } 631 } 632 } 633 634 type int64Slice []int64 635 636 func (s int64Slice) Len() int { return len(s) } 637 func (s int64Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] } 638 func (s int64Slice) Less(i, j int) bool { return s[i] < s[j] } 639 640 func (a *sharder) fillRoles( 641 address string, 642 servers []Server, 643 versionChan chan int64, 644 cancel chan bool, 645 ) error { 646 oldRoles := make(map[int64]ServerRole) 647 return a.discoveryClient.WatchAll( 648 a.serverRoleKey(address), 649 cancel, 650 func(encodedServerRoles map[string]string) error { 651 roles := make(map[int64]ServerRole) 652 var versions int64Slice 653 // Decode the roles 654 for _, encodedServerRole := range encodedServerRoles { 655 var serverRole ServerRole 656 if err := jsonpb.UnmarshalString(encodedServerRole, &serverRole); err != nil { 657 return err 658 } 659 roles[serverRole.Version] = serverRole 660 versions = append(versions, serverRole.Version) 661 } 662 sort.Sort(versions) 663 if len(versions) > 2 { 664 versions = versions[0:2] 665 } 666 // For each new version bring the server up to date 667 for _, version := range versions { 668 if _, ok := oldRoles[version]; ok { 669 // we've already seen these roles, so nothing to do here 670 continue 671 } 672 serverRole := roles[version] 673 var wg sync.WaitGroup 674 var addShardErr error 675 for _, shard := range shards(serverRole) { 676 if !containsShard(oldRoles, shard) { 677 shard := shard 678 for _, server := range servers { 679 wg.Add(1) 680 server := server 681 go func() { 682 defer wg.Done() 683 if err := server.AddShard(shard); err != nil && addShardErr == nil { 684 addShardErr = err 685 } 686 }() 687 } 688 } 689 } 690 wg.Wait() 691 if addShardErr != nil { 692 return addShardErr 693 } 694 oldRoles[version] = serverRole 695 versionChan <- version 696 } 697 // See if there are any old roles that aren't needed 698 for version, serverRole := range oldRoles { 699 var wg sync.WaitGroup 700 var removeShardErr error 701 if _, ok := roles[version]; ok { 702 // these roles haven't expired yet, so nothing to do 703 continue 704 } 705 for _, shard := range shards(serverRole) { 706 if !containsShard(roles, shard) { 707 shard := shard 708 for _, server := range servers { 709 server := server 710 wg.Add(1) 711 go func(shard uint64) { 712 defer wg.Done() 713 if err := server.DeleteShard(shard); err != nil && removeShardErr == nil { 714 removeShardErr = err 715 } 716 }(shard) 717 } 718 } 719 } 720 wg.Wait() 721 if removeShardErr != nil { 722 log.Error(&RemoveServerRole{ 723 ServerRole: &serverRole, 724 Error: removeShardErr.Error(), 725 }) 726 return removeShardErr 727 } 728 } 729 oldRoles = make(map[int64]ServerRole) 730 for _, version := range versions { 731 oldRoles[version] = roles[version] 732 } 733 return nil 734 }, 735 ) 736 } 737 738 func (a *sharder) runFrontends( 739 address string, 740 frontends []Frontend, 741 versionChan chan int64, 742 cancel chan bool, 743 ) error { 744 version := InvalidVersion 745 return a.discoveryClient.WatchAll( 746 a.serverStateDir(), 747 cancel, 748 func(encodedServerStates map[string]string) error { 749 if len(encodedServerStates) == 0 { 750 return nil 751 } 752 minVersion := int64(math.MaxInt64) 753 for _, encodedServerState := range encodedServerStates { 754 serverState, err := decodeServerState(encodedServerState) 755 if err != nil { 756 return err 757 } 758 if serverState.Version < minVersion { 759 minVersion = serverState.Version 760 } 761 } 762 if minVersion > version { 763 var eg errgroup.Group 764 for _, frontend := range frontends { 765 frontend := frontend 766 eg.Go(func() error { return frontend.Version(minVersion) }) 767 } 768 if err := eg.Wait(); err != nil { 769 return err 770 } 771 version = minVersion 772 versionChan <- version 773 } 774 return nil 775 }) 776 } 777 778 func shards(serverRole ServerRole) []uint64 { 779 var result []uint64 780 for shard := range serverRole.Shards { 781 result = append(result, shard) 782 } 783 return result 784 } 785 786 func containsShard(roles map[int64]ServerRole, shard uint64) bool { 787 for _, serverRole := range roles { 788 if serverRole.Shards[shard] { 789 return true 790 } 791 } 792 return false 793 } 794 795 func sameServers(oldServers map[string]bool, newServerStates map[string]*ServerState) bool { 796 if len(oldServers) != len(newServerStates) { 797 return false 798 } 799 for address := range oldServers { 800 if _, ok := newServerStates[address]; !ok { 801 return false 802 } 803 } 804 return true 805 }