github.com/mckael/restic@v0.8.3/internal/restic/restorer_test.go (about)

     1  package restic_test
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"io/ioutil"
     7  	"os"
     8  	"path/filepath"
     9  	"strings"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/restic/restic/internal/fs"
    14  	"github.com/restic/restic/internal/repository"
    15  	"github.com/restic/restic/internal/restic"
    16  	rtest "github.com/restic/restic/internal/test"
    17  )
    18  
    19  type Node interface{}
    20  
    21  type Snapshot struct {
    22  	Nodes  map[string]Node
    23  	treeID restic.ID
    24  }
    25  
    26  type File struct {
    27  	Data string
    28  }
    29  
    30  type Dir struct {
    31  	Nodes map[string]Node
    32  	Mode  os.FileMode
    33  }
    34  
    35  func saveFile(t testing.TB, repo restic.Repository, node File) restic.ID {
    36  	ctx, cancel := context.WithCancel(context.Background())
    37  	defer cancel()
    38  
    39  	id, err := repo.SaveBlob(ctx, restic.DataBlob, []byte(node.Data), restic.ID{})
    40  	if err != nil {
    41  		t.Fatal(err)
    42  	}
    43  
    44  	return id
    45  }
    46  
    47  func saveDir(t testing.TB, repo restic.Repository, nodes map[string]Node) restic.ID {
    48  	ctx, cancel := context.WithCancel(context.Background())
    49  	defer cancel()
    50  
    51  	tree := &restic.Tree{}
    52  	for name, n := range nodes {
    53  		var id restic.ID
    54  		switch node := n.(type) {
    55  		case File:
    56  			id = saveFile(t, repo, node)
    57  			tree.Insert(&restic.Node{
    58  				Type:    "file",
    59  				Mode:    0644,
    60  				Name:    name,
    61  				UID:     uint32(os.Getuid()),
    62  				GID:     uint32(os.Getgid()),
    63  				Content: []restic.ID{id},
    64  			})
    65  		case Dir:
    66  			id = saveDir(t, repo, node.Nodes)
    67  
    68  			mode := node.Mode
    69  			if mode == 0 {
    70  				mode = 0755
    71  			}
    72  
    73  			tree.Insert(&restic.Node{
    74  				Type:    "dir",
    75  				Mode:    mode,
    76  				Name:    name,
    77  				UID:     uint32(os.Getuid()),
    78  				GID:     uint32(os.Getgid()),
    79  				Subtree: &id,
    80  			})
    81  		default:
    82  			t.Fatalf("unknown node type %T", node)
    83  		}
    84  	}
    85  
    86  	id, err := repo.SaveTree(ctx, tree)
    87  	if err != nil {
    88  		t.Fatal(err)
    89  	}
    90  
    91  	return id
    92  }
    93  
    94  func saveSnapshot(t testing.TB, repo restic.Repository, snapshot Snapshot) (restic.Repository, restic.ID) {
    95  	ctx, cancel := context.WithCancel(context.Background())
    96  	defer cancel()
    97  
    98  	treeID := saveDir(t, repo, snapshot.Nodes)
    99  
   100  	err := repo.Flush(ctx)
   101  	if err != nil {
   102  		t.Fatal(err)
   103  	}
   104  
   105  	err = repo.SaveIndex(ctx)
   106  	if err != nil {
   107  		t.Fatal(err)
   108  	}
   109  
   110  	sn, err := restic.NewSnapshot([]string{"test"}, nil, "", time.Now())
   111  	if err != nil {
   112  		t.Fatal(err)
   113  	}
   114  
   115  	sn.Tree = &treeID
   116  	id, err := repo.SaveJSONUnpacked(ctx, restic.SnapshotFile, sn)
   117  	if err != nil {
   118  		t.Fatal(err)
   119  	}
   120  
   121  	return repo, id
   122  }
   123  
   124  // toSlash converts the OS specific path dir to a slash-separated path.
   125  func toSlash(dir string) string {
   126  	data := strings.Split(dir, string(filepath.Separator))
   127  	return strings.Join(data, "/")
   128  }
   129  
   130  func TestRestorer(t *testing.T) {
   131  	var tests = []struct {
   132  		Snapshot
   133  		Files      map[string]string
   134  		ErrorsMust map[string]string
   135  		ErrorsMay  map[string]string
   136  	}{
   137  		// valid test cases
   138  		{
   139  			Snapshot: Snapshot{
   140  				Nodes: map[string]Node{
   141  					"foo": File{"content: foo\n"},
   142  					"dirtest": Dir{
   143  						Nodes: map[string]Node{
   144  							"file": File{"content: file\n"},
   145  						},
   146  					},
   147  				},
   148  			},
   149  			Files: map[string]string{
   150  				"foo":          "content: foo\n",
   151  				"dirtest/file": "content: file\n",
   152  			},
   153  		},
   154  		{
   155  			Snapshot: Snapshot{
   156  				Nodes: map[string]Node{
   157  					"top": File{"toplevel file"},
   158  					"dir": Dir{
   159  						Nodes: map[string]Node{
   160  							"file": File{"file in dir"},
   161  							"subdir": Dir{
   162  								Nodes: map[string]Node{
   163  									"file": File{"file in subdir"},
   164  								},
   165  							},
   166  						},
   167  					},
   168  				},
   169  			},
   170  			Files: map[string]string{
   171  				"top":             "toplevel file",
   172  				"dir/file":        "file in dir",
   173  				"dir/subdir/file": "file in subdir",
   174  			},
   175  		},
   176  		{
   177  			Snapshot: Snapshot{
   178  				Nodes: map[string]Node{
   179  					"dir": Dir{
   180  						Mode: 0444,
   181  					},
   182  					"file": File{"top-level file"},
   183  				},
   184  			},
   185  			Files: map[string]string{
   186  				"file": "top-level file",
   187  			},
   188  		},
   189  		{
   190  			Snapshot: Snapshot{
   191  				Nodes: map[string]Node{
   192  					"dir": Dir{
   193  						Mode: 0555,
   194  						Nodes: map[string]Node{
   195  							"file": File{"file in dir"},
   196  						},
   197  					},
   198  				},
   199  			},
   200  			Files: map[string]string{
   201  				"dir/file": "file in dir",
   202  			},
   203  		},
   204  
   205  		// test cases with invalid/constructed names
   206  		{
   207  			Snapshot: Snapshot{
   208  				Nodes: map[string]Node{
   209  					`..\test`:                      File{"foo\n"},
   210  					`..\..\foo\..\bar\..\xx\test2`: File{"test2\n"},
   211  				},
   212  			},
   213  			ErrorsMay: map[string]string{
   214  				`/#..\test`:                      "node has invalid name",
   215  				`/#..\..\foo\..\bar\..\xx\test2`: "node has invalid name",
   216  			},
   217  		},
   218  		{
   219  			Snapshot: Snapshot{
   220  				Nodes: map[string]Node{
   221  					`../test`:                      File{"foo\n"},
   222  					`../../foo/../bar/../xx/test2`: File{"test2\n"},
   223  				},
   224  			},
   225  			ErrorsMay: map[string]string{
   226  				`/#../test`:                      "node has invalid name",
   227  				`/#../../foo/../bar/../xx/test2`: "node has invalid name",
   228  			},
   229  		},
   230  		{
   231  			Snapshot: Snapshot{
   232  				Nodes: map[string]Node{
   233  					"top": File{"toplevel file"},
   234  					"x": Dir{
   235  						Nodes: map[string]Node{
   236  							"file1": File{"file1"},
   237  							"..": Dir{
   238  								Nodes: map[string]Node{
   239  									"file2": File{"file2"},
   240  									"..": Dir{
   241  										Nodes: map[string]Node{
   242  											"file2": File{"file2"},
   243  										},
   244  									},
   245  								},
   246  							},
   247  						},
   248  					},
   249  				},
   250  			},
   251  			Files: map[string]string{
   252  				"top": "toplevel file",
   253  			},
   254  			ErrorsMust: map[string]string{
   255  				`/x#..`: "node has invalid name",
   256  			},
   257  		},
   258  	}
   259  
   260  	for _, test := range tests {
   261  		t.Run("", func(t *testing.T) {
   262  			repo, cleanup := repository.TestRepository(t)
   263  			defer cleanup()
   264  			_, id := saveSnapshot(t, repo, test.Snapshot)
   265  			t.Logf("snapshot saved as %v", id.Str())
   266  
   267  			res, err := restic.NewRestorer(repo, id)
   268  			if err != nil {
   269  				t.Fatal(err)
   270  			}
   271  
   272  			tempdir, cleanup := rtest.TempDir(t)
   273  			defer cleanup()
   274  
   275  			res.SelectFilter = func(item, dstpath string, node *restic.Node) (selectedForRestore bool, childMayBeSelected bool) {
   276  				t.Logf("restore %v to %v", item, dstpath)
   277  				if !fs.HasPathPrefix(tempdir, dstpath) {
   278  					t.Errorf("would restore %v to %v, which is not within the target dir %v",
   279  						item, dstpath, tempdir)
   280  					return false, false
   281  				}
   282  				return true, true
   283  			}
   284  
   285  			errors := make(map[string]string)
   286  			res.Error = func(dir string, node *restic.Node, err error) error {
   287  				t.Logf("restore returned error for %q in dir %v: %v", node.Name, dir, err)
   288  				dir = toSlash(dir)
   289  				errors[dir+"#"+node.Name] = err.Error()
   290  				return nil
   291  			}
   292  
   293  			ctx, cancel := context.WithCancel(context.Background())
   294  			defer cancel()
   295  
   296  			err = res.RestoreTo(ctx, tempdir)
   297  			if err != nil {
   298  				t.Fatal(err)
   299  			}
   300  
   301  			for filename, errorMessage := range test.ErrorsMust {
   302  				msg, ok := errors[filename]
   303  				if !ok {
   304  					t.Errorf("expected error for %v, found none", filename)
   305  					continue
   306  				}
   307  
   308  				if msg != "" && msg != errorMessage {
   309  					t.Errorf("wrong error message for %v: got %q, want %q",
   310  						filename, msg, errorMessage)
   311  				}
   312  
   313  				delete(errors, filename)
   314  			}
   315  
   316  			for filename, errorMessage := range test.ErrorsMay {
   317  				msg, ok := errors[filename]
   318  				if !ok {
   319  					continue
   320  				}
   321  
   322  				if msg != "" && msg != errorMessage {
   323  					t.Errorf("wrong error message for %v: got %q, want %q",
   324  						filename, msg, errorMessage)
   325  				}
   326  
   327  				delete(errors, filename)
   328  			}
   329  
   330  			for filename, err := range errors {
   331  				t.Errorf("unexpected error for %v found: %v", filename, err)
   332  			}
   333  
   334  			for filename, content := range test.Files {
   335  				data, err := ioutil.ReadFile(filepath.Join(tempdir, filepath.FromSlash(filename)))
   336  				if err != nil {
   337  					t.Errorf("unable to read file %v: %v", filename, err)
   338  					continue
   339  				}
   340  
   341  				if !bytes.Equal(data, []byte(content)) {
   342  					t.Errorf("file %v has wrong content: want %q, got %q", filename, content, data)
   343  				}
   344  			}
   345  		})
   346  	}
   347  }
   348  
   349  func chdir(t testing.TB, target string) func() {
   350  	prev, err := os.Getwd()
   351  	if err != nil {
   352  		t.Fatal(err)
   353  	}
   354  
   355  	t.Logf("chdir to %v", target)
   356  	err = os.Chdir(target)
   357  	if err != nil {
   358  		t.Fatal(err)
   359  	}
   360  
   361  	return func() {
   362  		t.Logf("chdir back to %v", prev)
   363  		err = os.Chdir(prev)
   364  		if err != nil {
   365  			t.Fatal(err)
   366  		}
   367  	}
   368  }
   369  
   370  func TestRestorerRelative(t *testing.T) {
   371  	var tests = []struct {
   372  		Snapshot
   373  		Files map[string]string
   374  	}{
   375  		{
   376  			Snapshot: Snapshot{
   377  				Nodes: map[string]Node{
   378  					"foo": File{"content: foo\n"},
   379  					"dirtest": Dir{
   380  						Nodes: map[string]Node{
   381  							"file": File{"content: file\n"},
   382  						},
   383  					},
   384  				},
   385  			},
   386  			Files: map[string]string{
   387  				"foo":          "content: foo\n",
   388  				"dirtest/file": "content: file\n",
   389  			},
   390  		},
   391  	}
   392  
   393  	for _, test := range tests {
   394  		t.Run("", func(t *testing.T) {
   395  			repo, cleanup := repository.TestRepository(t)
   396  			defer cleanup()
   397  
   398  			_, id := saveSnapshot(t, repo, test.Snapshot)
   399  			t.Logf("snapshot saved as %v", id.Str())
   400  
   401  			res, err := restic.NewRestorer(repo, id)
   402  			if err != nil {
   403  				t.Fatal(err)
   404  			}
   405  
   406  			tempdir, cleanup := rtest.TempDir(t)
   407  			defer cleanup()
   408  
   409  			cleanup = chdir(t, tempdir)
   410  			defer cleanup()
   411  
   412  			errors := make(map[string]string)
   413  			res.Error = func(dir string, node *restic.Node, err error) error {
   414  				t.Logf("restore returned error for %q in dir %v: %v", node.Name, dir, err)
   415  				dir = toSlash(dir)
   416  				errors[dir+"#"+node.Name] = err.Error()
   417  				return nil
   418  			}
   419  
   420  			ctx, cancel := context.WithCancel(context.Background())
   421  			defer cancel()
   422  
   423  			err = res.RestoreTo(ctx, "restore")
   424  			if err != nil {
   425  				t.Fatal(err)
   426  			}
   427  
   428  			for filename, err := range errors {
   429  				t.Errorf("unexpected error for %v found: %v", filename, err)
   430  			}
   431  
   432  			for filename, content := range test.Files {
   433  				data, err := ioutil.ReadFile(filepath.Join(tempdir, "restore", filepath.FromSlash(filename)))
   434  				if err != nil {
   435  					t.Errorf("unable to read file %v: %v", filename, err)
   436  					continue
   437  				}
   438  
   439  				if !bytes.Equal(data, []byte(content)) {
   440  					t.Errorf("file %v has wrong content: want %q, got %q", filename, content, data)
   441  				}
   442  			}
   443  		})
   444  	}
   445  }