github.com/fawick/restic@v0.1.1-0.20171126184616-c02923fbfc79/internal/restic/restorer_test.go (about)

     1  package restic_test
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"io/ioutil"
     7  	"os"
     8  	"path/filepath"
     9  	"strings"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/restic/restic/internal/fs"
    14  	"github.com/restic/restic/internal/repository"
    15  	"github.com/restic/restic/internal/restic"
    16  	rtest "github.com/restic/restic/internal/test"
    17  )
    18  
    19  type Node interface{}
    20  
    21  type Snapshot struct {
    22  	Nodes  map[string]Node
    23  	treeID restic.ID
    24  }
    25  
    26  type File struct {
    27  	Data string
    28  }
    29  
    30  type Dir struct {
    31  	Nodes map[string]Node
    32  }
    33  
    34  func saveFile(t testing.TB, repo restic.Repository, node File) restic.ID {
    35  	ctx, cancel := context.WithCancel(context.Background())
    36  	defer cancel()
    37  
    38  	id, err := repo.SaveBlob(ctx, restic.DataBlob, []byte(node.Data), restic.ID{})
    39  	if err != nil {
    40  		t.Fatal(err)
    41  	}
    42  
    43  	return id
    44  }
    45  
    46  func saveDir(t testing.TB, repo restic.Repository, nodes map[string]Node) restic.ID {
    47  	ctx, cancel := context.WithCancel(context.Background())
    48  	defer cancel()
    49  
    50  	tree := &restic.Tree{}
    51  	for name, n := range nodes {
    52  		var id restic.ID
    53  		switch node := n.(type) {
    54  		case File:
    55  			id = saveFile(t, repo, node)
    56  			tree.Insert(&restic.Node{
    57  				Type:    "file",
    58  				Mode:    0644,
    59  				Name:    name,
    60  				UID:     uint32(os.Getuid()),
    61  				GID:     uint32(os.Getgid()),
    62  				Content: []restic.ID{id},
    63  			})
    64  		case Dir:
    65  			id = saveDir(t, repo, node.Nodes)
    66  			tree.Insert(&restic.Node{
    67  				Type:    "dir",
    68  				Mode:    0755,
    69  				Name:    name,
    70  				UID:     uint32(os.Getuid()),
    71  				GID:     uint32(os.Getgid()),
    72  				Subtree: &id,
    73  			})
    74  		default:
    75  			t.Fatalf("unknown node type %T", node)
    76  		}
    77  	}
    78  
    79  	id, err := repo.SaveTree(ctx, tree)
    80  	if err != nil {
    81  		t.Fatal(err)
    82  	}
    83  
    84  	return id
    85  }
    86  
    87  func saveSnapshot(t testing.TB, repo restic.Repository, snapshot Snapshot) (restic.Repository, restic.ID) {
    88  	ctx, cancel := context.WithCancel(context.Background())
    89  	defer cancel()
    90  
    91  	treeID := saveDir(t, repo, snapshot.Nodes)
    92  
    93  	err := repo.Flush()
    94  	if err != nil {
    95  		t.Fatal(err)
    96  	}
    97  
    98  	err = repo.SaveIndex(ctx)
    99  	if err != nil {
   100  		t.Fatal(err)
   101  	}
   102  
   103  	sn, err := restic.NewSnapshot([]string{"test"}, nil, "", time.Now())
   104  	if err != nil {
   105  		t.Fatal(err)
   106  	}
   107  
   108  	sn.Tree = &treeID
   109  	id, err := repo.SaveJSONUnpacked(ctx, restic.SnapshotFile, sn)
   110  	if err != nil {
   111  		t.Fatal(err)
   112  	}
   113  
   114  	return repo, id
   115  }
   116  
   117  // toSlash converts the OS specific path dir to a slash-separated path.
   118  func toSlash(dir string) string {
   119  	data := strings.Split(dir, string(filepath.Separator))
   120  	return strings.Join(data, "/")
   121  }
   122  
   123  func TestRestorer(t *testing.T) {
   124  	var tests = []struct {
   125  		Snapshot
   126  		Files      map[string]string
   127  		ErrorsMust map[string]string
   128  		ErrorsMay  map[string]string
   129  	}{
   130  		// valid test cases
   131  		{
   132  			Snapshot: Snapshot{
   133  				Nodes: map[string]Node{
   134  					"foo": File{"content: foo\n"},
   135  					"dirtest": Dir{
   136  						Nodes: map[string]Node{
   137  							"file": File{"content: file\n"},
   138  						},
   139  					},
   140  				},
   141  			},
   142  			Files: map[string]string{
   143  				"foo":          "content: foo\n",
   144  				"dirtest/file": "content: file\n",
   145  			},
   146  		},
   147  		{
   148  			Snapshot: Snapshot{
   149  				Nodes: map[string]Node{
   150  					"top": File{"toplevel file"},
   151  					"dir": Dir{
   152  						Nodes: map[string]Node{
   153  							"file": File{"file in dir"},
   154  							"subdir": Dir{
   155  								Nodes: map[string]Node{
   156  									"file": File{"file in subdir"},
   157  								},
   158  							},
   159  						},
   160  					},
   161  				},
   162  			},
   163  			Files: map[string]string{
   164  				"top":             "toplevel file",
   165  				"dir/file":        "file in dir",
   166  				"dir/subdir/file": "file in subdir",
   167  			},
   168  		},
   169  
   170  		// test cases with invalid/constructed names
   171  		{
   172  			Snapshot: Snapshot{
   173  				Nodes: map[string]Node{
   174  					`..\test`:                      File{"foo\n"},
   175  					`..\..\foo\..\bar\..\xx\test2`: File{"test2\n"},
   176  				},
   177  			},
   178  			ErrorsMay: map[string]string{
   179  				`/#..\test`:                      "node has invalid name",
   180  				`/#..\..\foo\..\bar\..\xx\test2`: "node has invalid name",
   181  			},
   182  		},
   183  		{
   184  			Snapshot: Snapshot{
   185  				Nodes: map[string]Node{
   186  					`../test`:                      File{"foo\n"},
   187  					`../../foo/../bar/../xx/test2`: File{"test2\n"},
   188  				},
   189  			},
   190  			ErrorsMay: map[string]string{
   191  				`/#../test`:                      "node has invalid name",
   192  				`/#../../foo/../bar/../xx/test2`: "node has invalid name",
   193  			},
   194  		},
   195  		{
   196  			Snapshot: Snapshot{
   197  				Nodes: map[string]Node{
   198  					"top": File{"toplevel file"},
   199  					"x": Dir{
   200  						Nodes: map[string]Node{
   201  							"file1": File{"file1"},
   202  							"..": Dir{
   203  								Nodes: map[string]Node{
   204  									"file2": File{"file2"},
   205  									"..": Dir{
   206  										Nodes: map[string]Node{
   207  											"file2": File{"file2"},
   208  										},
   209  									},
   210  								},
   211  							},
   212  						},
   213  					},
   214  				},
   215  			},
   216  			Files: map[string]string{
   217  				"top": "toplevel file",
   218  			},
   219  			ErrorsMust: map[string]string{
   220  				`/x#..`: "node has invalid name",
   221  			},
   222  		},
   223  	}
   224  
   225  	for _, test := range tests {
   226  		t.Run("", func(t *testing.T) {
   227  			repo, cleanup := repository.TestRepository(t)
   228  			defer cleanup()
   229  			_, id := saveSnapshot(t, repo, test.Snapshot)
   230  			t.Logf("snapshot saved as %v", id.Str())
   231  
   232  			res, err := restic.NewRestorer(repo, id)
   233  			if err != nil {
   234  				t.Fatal(err)
   235  			}
   236  
   237  			tempdir, cleanup := rtest.TempDir(t)
   238  			defer cleanup()
   239  
   240  			res.SelectFilter = func(item, dstpath string, node *restic.Node) (selectedForRestore bool, childMayBeSelected bool) {
   241  				t.Logf("restore %v to %v", item, dstpath)
   242  				if !fs.HasPathPrefix(tempdir, dstpath) {
   243  					t.Errorf("would restore %v to %v, which is not within the target dir %v",
   244  						item, dstpath, tempdir)
   245  					return false, false
   246  				}
   247  				return true, true
   248  			}
   249  
   250  			errors := make(map[string]string)
   251  			res.Error = func(dir string, node *restic.Node, err error) error {
   252  				t.Logf("restore returned error for %q in dir %v: %v", node.Name, dir, err)
   253  				dir = toSlash(dir)
   254  				errors[dir+"#"+node.Name] = err.Error()
   255  				return nil
   256  			}
   257  
   258  			ctx, cancel := context.WithCancel(context.Background())
   259  			defer cancel()
   260  
   261  			err = res.RestoreTo(ctx, tempdir)
   262  			if err != nil {
   263  				t.Fatal(err)
   264  			}
   265  
   266  			for filename, errorMessage := range test.ErrorsMust {
   267  				msg, ok := errors[filename]
   268  				if !ok {
   269  					t.Errorf("expected error for %v, found none", filename)
   270  					continue
   271  				}
   272  
   273  				if msg != "" && msg != errorMessage {
   274  					t.Errorf("wrong error message for %v: got %q, want %q",
   275  						filename, msg, errorMessage)
   276  				}
   277  
   278  				delete(errors, filename)
   279  			}
   280  
   281  			for filename, errorMessage := range test.ErrorsMay {
   282  				msg, ok := errors[filename]
   283  				if !ok {
   284  					continue
   285  				}
   286  
   287  				if msg != "" && msg != errorMessage {
   288  					t.Errorf("wrong error message for %v: got %q, want %q",
   289  						filename, msg, errorMessage)
   290  				}
   291  
   292  				delete(errors, filename)
   293  			}
   294  
   295  			for filename, err := range errors {
   296  				t.Errorf("unexpected error for %v found: %v", filename, err)
   297  			}
   298  
   299  			for filename, content := range test.Files {
   300  				data, err := ioutil.ReadFile(filepath.Join(tempdir, filepath.FromSlash(filename)))
   301  				if err != nil {
   302  					t.Errorf("unable to read file %v: %v", filename, err)
   303  					continue
   304  				}
   305  
   306  				if !bytes.Equal(data, []byte(content)) {
   307  					t.Errorf("file %v has wrong content: want %q, got %q", filename, content, data)
   308  				}
   309  			}
   310  		})
   311  	}
   312  }
   313  
   314  func chdir(t testing.TB, target string) func() {
   315  	prev, err := os.Getwd()
   316  	if err != nil {
   317  		t.Fatal(err)
   318  	}
   319  
   320  	t.Logf("chdir to %v", target)
   321  	err = os.Chdir(target)
   322  	if err != nil {
   323  		t.Fatal(err)
   324  	}
   325  
   326  	return func() {
   327  		t.Logf("chdir back to %v", prev)
   328  		err = os.Chdir(prev)
   329  		if err != nil {
   330  			t.Fatal(err)
   331  		}
   332  	}
   333  }
   334  
   335  func TestRestorerRelative(t *testing.T) {
   336  	var tests = []struct {
   337  		Snapshot
   338  		Files map[string]string
   339  	}{
   340  		{
   341  			Snapshot: Snapshot{
   342  				Nodes: map[string]Node{
   343  					"foo": File{"content: foo\n"},
   344  					"dirtest": Dir{
   345  						Nodes: map[string]Node{
   346  							"file": File{"content: file\n"},
   347  						},
   348  					},
   349  				},
   350  			},
   351  			Files: map[string]string{
   352  				"foo":          "content: foo\n",
   353  				"dirtest/file": "content: file\n",
   354  			},
   355  		},
   356  	}
   357  
   358  	for _, test := range tests {
   359  		t.Run("", func(t *testing.T) {
   360  			repo, cleanup := repository.TestRepository(t)
   361  			defer cleanup()
   362  
   363  			_, id := saveSnapshot(t, repo, test.Snapshot)
   364  			t.Logf("snapshot saved as %v", id.Str())
   365  
   366  			res, err := restic.NewRestorer(repo, id)
   367  			if err != nil {
   368  				t.Fatal(err)
   369  			}
   370  
   371  			tempdir, cleanup := rtest.TempDir(t)
   372  			defer cleanup()
   373  
   374  			cleanup = chdir(t, tempdir)
   375  			defer cleanup()
   376  
   377  			errors := make(map[string]string)
   378  			res.Error = func(dir string, node *restic.Node, err error) error {
   379  				t.Logf("restore returned error for %q in dir %v: %v", node.Name, dir, err)
   380  				dir = toSlash(dir)
   381  				errors[dir+"#"+node.Name] = err.Error()
   382  				return nil
   383  			}
   384  
   385  			ctx, cancel := context.WithCancel(context.Background())
   386  			defer cancel()
   387  
   388  			err = res.RestoreTo(ctx, "restore")
   389  			if err != nil {
   390  				t.Fatal(err)
   391  			}
   392  
   393  			for filename, err := range errors {
   394  				t.Errorf("unexpected error for %v found: %v", filename, err)
   395  			}
   396  
   397  			for filename, content := range test.Files {
   398  				data, err := ioutil.ReadFile(filepath.Join(tempdir, "restore", filepath.FromSlash(filename)))
   399  				if err != nil {
   400  					t.Errorf("unable to read file %v: %v", filename, err)
   401  					continue
   402  				}
   403  
   404  				if !bytes.Equal(data, []byte(content)) {
   405  					t.Errorf("file %v has wrong content: want %q, got %q", filename, content, data)
   406  				}
   407  			}
   408  		})
   409  	}
   410  }