github.com/treeverse/lakefs@v1.24.1-0.20240520134607-95648127bfb0/pkg/local/sync.go (about)

     1  package local
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net/http"
    10  	"os"
    11  	"path/filepath"
    12  	"strconv"
    13  	"strings"
    14  	"sync/atomic"
    15  	"syscall"
    16  	"time"
    17  
    18  	"github.com/go-openapi/swag"
    19  	"github.com/treeverse/lakefs/pkg/api/apigen"
    20  	"github.com/treeverse/lakefs/pkg/api/apiutil"
    21  	"github.com/treeverse/lakefs/pkg/api/helpers"
    22  	"github.com/treeverse/lakefs/pkg/fileutil"
    23  	"github.com/treeverse/lakefs/pkg/uri"
    24  	"golang.org/x/sync/errgroup"
    25  )
    26  
    27  const (
    28  	DefaultDirectoryMask   = 0o755
    29  	ClientMtimeMetadataKey = apiutil.LakeFSMetadataPrefix + "client-mtime"
    30  )
    31  
    32  type SyncFlags struct {
    33  	Parallelism      int
    34  	Presign          bool
    35  	PresignMultipart bool
    36  }
    37  
    38  func getMtimeFromStats(stats apigen.ObjectStats) (int64, error) {
    39  	if stats.Metadata == nil {
    40  		return stats.Mtime, nil
    41  	}
    42  	clientMtime, hasClientMtime := stats.Metadata.Get(ClientMtimeMetadataKey)
    43  	if hasClientMtime {
    44  		// parse
    45  		return strconv.ParseInt(clientMtime, 10, 64)
    46  	}
    47  	return stats.Mtime, nil
    48  }
    49  
    50  type Tasks struct {
    51  	Downloaded uint64
    52  	Uploaded   uint64
    53  	Removed    uint64
    54  }
    55  
    56  type SyncManager struct {
    57  	ctx         context.Context
    58  	client      *apigen.ClientWithResponses
    59  	httpClient  *http.Client
    60  	progressBar *ProgressPool
    61  	flags       SyncFlags
    62  	tasks       Tasks
    63  }
    64  
    65  func NewSyncManager(ctx context.Context, client *apigen.ClientWithResponses, flags SyncFlags) *SyncManager {
    66  	return &SyncManager{
    67  		ctx:         ctx,
    68  		client:      client,
    69  		httpClient:  http.DefaultClient,
    70  		progressBar: NewProgressPool(),
    71  		flags:       flags,
    72  	}
    73  }
    74  
    75  // Sync - sync changes between remote and local directory given the Changes channel.
    76  // For each change, will apply download, upload or delete according to the change type and change source
    77  func (s *SyncManager) Sync(rootPath string, remote *uri.URI, changeSet <-chan *Change) error {
    78  	s.progressBar.Start()
    79  	defer s.progressBar.Stop()
    80  
    81  	wg, ctx := errgroup.WithContext(s.ctx)
    82  	for i := 0; i < s.flags.Parallelism; i++ {
    83  		wg.Go(func() error {
    84  			for change := range changeSet {
    85  				if err := s.apply(ctx, rootPath, remote, change); err != nil {
    86  					return err
    87  				}
    88  			}
    89  			return nil
    90  		})
    91  	}
    92  	if err := wg.Wait(); err != nil {
    93  		return err
    94  	}
    95  	_, err := fileutil.PruneEmptyDirectories(rootPath)
    96  	return err
    97  }
    98  
    99  func (s *SyncManager) apply(ctx context.Context, rootPath string, remote *uri.URI, change *Change) error {
   100  	switch change.Type {
   101  	case ChangeTypeAdded, ChangeTypeModified:
   102  		switch change.Source {
   103  		case ChangeSourceRemote:
   104  			// remotely changed something, download it!
   105  			if err := s.download(ctx, rootPath, remote, change.Path); err != nil {
   106  				return fmt.Errorf("download %s failed: %w", change.Path, err)
   107  			}
   108  		case ChangeSourceLocal:
   109  			// we wrote something, upload it!
   110  			if err := s.upload(ctx, rootPath, remote, change.Path); err != nil {
   111  				return fmt.Errorf("upload %s failed: %w", change.Path, err)
   112  			}
   113  		default:
   114  			panic("invalid change source")
   115  		}
   116  	case ChangeTypeRemoved:
   117  		if change.Source == ChangeSourceRemote {
   118  			// remote deleted something, delete it locally!
   119  			if err := s.deleteLocal(rootPath, change); err != nil {
   120  				return fmt.Errorf("delete local %s failed: %w", change.Path, err)
   121  			}
   122  		} else {
   123  			// we deleted something, delete it on remote!
   124  			if err := s.deleteRemote(ctx, remote, change); err != nil {
   125  				return fmt.Errorf("delete remote %s failed: %w", change.Path, err)
   126  			}
   127  		}
   128  	case ChangeTypeConflict:
   129  		return ErrConflict
   130  	default:
   131  		panic("invalid change type")
   132  	}
   133  	return nil
   134  }
   135  
   136  func (s *SyncManager) download(ctx context.Context, rootPath string, remote *uri.URI, path string) error {
   137  	if err := fileutil.VerifyRelPath(strings.TrimPrefix(path, uri.PathSeparator), rootPath); err != nil {
   138  		return err
   139  	}
   140  	destination := filepath.Join(rootPath, path)
   141  	destinationDirectory := filepath.Dir(destination)
   142  	if err := os.MkdirAll(destinationDirectory, DefaultDirectoryMask); err != nil {
   143  		return err
   144  	}
   145  	statResp, err := s.client.StatObjectWithResponse(ctx, remote.Repository, remote.Ref, &apigen.StatObjectParams{
   146  		Path:         filepath.ToSlash(filepath.Join(remote.GetPath(), path)),
   147  		Presign:      swag.Bool(s.flags.Presign),
   148  		UserMetadata: swag.Bool(true),
   149  	})
   150  	if err != nil {
   151  		return err
   152  	}
   153  	if statResp.StatusCode() != http.StatusOK {
   154  		httpErr := apigen.Error{Message: "no content"}
   155  		_ = json.Unmarshal(statResp.Body, &httpErr)
   156  		return fmt.Errorf("(stat: HTTP %d, message: %s): %w", statResp.StatusCode(), httpErr.Message, ErrDownloadingFile)
   157  	}
   158  	// get mtime
   159  	mtimeSecs, err := getMtimeFromStats(*statResp.JSON200)
   160  	if err != nil {
   161  		return err
   162  	}
   163  
   164  	if strings.HasSuffix(path, uri.PathSeparator) {
   165  		// Directory marker - skip
   166  		return nil
   167  	}
   168  
   169  	lastModified := time.Unix(mtimeSecs, 0)
   170  	sizeBytes := swag.Int64Value(statResp.JSON200.SizeBytes)
   171  	f, err := os.Create(destination)
   172  	if err != nil {
   173  		// Sometimes we get a file that is actually a directory marker (Spark loves writing those).
   174  		// If we already have the directory, we can skip it.
   175  		if errors.Is(err, syscall.EISDIR) && sizeBytes == 0 {
   176  			return nil // no further action required!
   177  		}
   178  		return fmt.Errorf("could not create file '%s': %w", destination, err)
   179  	}
   180  	defer func() {
   181  		err = f.Close()
   182  	}()
   183  
   184  	if sizeBytes == 0 { // if size is empty just create file
   185  		spinner := s.progressBar.AddSpinner("download " + path)
   186  		atomic.AddUint64(&s.tasks.Downloaded, 1)
   187  		defer spinner.Done()
   188  	} else { // Download file
   189  		// make request
   190  		var body io.Reader
   191  		if s.flags.Presign {
   192  			resp, err := s.httpClient.Get(statResp.JSON200.PhysicalAddress)
   193  			if err != nil {
   194  				return err
   195  			}
   196  			defer func() {
   197  				_ = resp.Body.Close()
   198  			}()
   199  			if resp.StatusCode != http.StatusOK {
   200  				return fmt.Errorf("%s (pre-signed GET: HTTP %d): %w", path, resp.StatusCode, ErrDownloadingFile)
   201  			}
   202  			body = resp.Body
   203  		} else {
   204  			resp, err := s.client.GetObject(ctx, remote.Repository, remote.Ref, &apigen.GetObjectParams{
   205  				Path: filepath.ToSlash(filepath.Join(remote.GetPath(), path)),
   206  			})
   207  			if err != nil {
   208  				return err
   209  			}
   210  			defer func() {
   211  				_ = resp.Body.Close()
   212  			}()
   213  			if resp.StatusCode != http.StatusOK {
   214  				return fmt.Errorf("%s (GetObject: HTTP %d): %w", path, resp.StatusCode, ErrDownloadingFile)
   215  			}
   216  			body = resp.Body
   217  		}
   218  
   219  		b := s.progressBar.AddReader(fmt.Sprintf("download %s", path), sizeBytes)
   220  		barReader := b.Reader(body)
   221  		defer func() {
   222  			if err != nil {
   223  				b.Error()
   224  			} else {
   225  				atomic.AddUint64(&s.tasks.Downloaded, 1)
   226  				b.Done()
   227  			}
   228  		}()
   229  		_, err = io.Copy(f, barReader)
   230  
   231  		if err != nil {
   232  			return fmt.Errorf("could not write file '%s': %w", destination, err)
   233  		}
   234  	}
   235  
   236  	// set mtime to the server returned one
   237  	err = os.Chtimes(destination, time.Now(), lastModified) // Explicit to catch in deferred func
   238  	return err
   239  }
   240  
   241  func (s *SyncManager) upload(ctx context.Context, rootPath string, remote *uri.URI, path string) error {
   242  	source := filepath.Join(rootPath, path)
   243  	if err := fileutil.VerifySafeFilename(source); err != nil {
   244  		return err
   245  	}
   246  	dest := filepath.ToSlash(filepath.Join(remote.GetPath(), path))
   247  
   248  	f, err := os.Open(source)
   249  	if err != nil {
   250  		return err
   251  	}
   252  	defer func() {
   253  		_ = f.Close()
   254  	}()
   255  
   256  	fileStat, err := f.Stat()
   257  	if err != nil {
   258  		return err
   259  	}
   260  
   261  	b := s.progressBar.AddReader(fmt.Sprintf("upload %s", path), fileStat.Size())
   262  	defer func() {
   263  		if err != nil {
   264  			b.Error()
   265  		} else {
   266  			atomic.AddUint64(&s.tasks.Uploaded, 1)
   267  			b.Done()
   268  		}
   269  	}()
   270  
   271  	metadata := map[string]string{
   272  		ClientMtimeMetadataKey: strconv.FormatInt(fileStat.ModTime().Unix(), 10),
   273  	}
   274  	reader := fileWrapper{
   275  		file:   f,
   276  		reader: b.Reader(f),
   277  	}
   278  	if s.flags.Presign {
   279  		_, err = helpers.ClientUploadPreSign(
   280  			ctx, s.client, remote.Repository, remote.Ref, dest, metadata, "", reader, s.flags.PresignMultipart)
   281  		return err
   282  	}
   283  	// not pre-signed
   284  	_, err = helpers.ClientUpload(
   285  		ctx, s.client, remote.Repository, remote.Ref, dest, metadata, "", reader)
   286  	return err
   287  }
   288  
   289  func (s *SyncManager) deleteLocal(rootPath string, change *Change) (err error) {
   290  	b := s.progressBar.AddSpinner("delete local: " + change.Path)
   291  	defer func() {
   292  		defer func() {
   293  			if err != nil {
   294  				b.Error()
   295  			} else {
   296  				atomic.AddUint64(&s.tasks.Removed, 1)
   297  				b.Done()
   298  			}
   299  		}()
   300  	}()
   301  	source := filepath.Join(rootPath, change.Path)
   302  	err = fileutil.RemoveFile(source)
   303  	if err != nil {
   304  		return err
   305  	}
   306  	return nil
   307  }
   308  
   309  func (s *SyncManager) deleteRemote(ctx context.Context, remote *uri.URI, change *Change) (err error) {
   310  	b := s.progressBar.AddSpinner("delete remote path: " + change.Path)
   311  	defer func() {
   312  		if err != nil {
   313  			b.Error()
   314  		} else {
   315  			atomic.AddUint64(&s.tasks.Removed, 1)
   316  			b.Done()
   317  		}
   318  	}()
   319  	dest := filepath.ToSlash(filepath.Join(remote.GetPath(), change.Path))
   320  	resp, err := s.client.DeleteObjectWithResponse(ctx, remote.Repository, remote.Ref, &apigen.DeleteObjectParams{
   321  		Path: dest,
   322  	})
   323  	if err != nil {
   324  		return
   325  	}
   326  	if resp.StatusCode() != http.StatusNoContent {
   327  		return fmt.Errorf("could not delete object: HTTP %d: %w", resp.StatusCode(), helpers.ErrRequestFailed)
   328  	}
   329  	return
   330  }
   331  
   332  func (s *SyncManager) Summary() Tasks {
   333  	return s.tasks
   334  }