github.com/sqlitebrowser/dio@v0.0.0-20240125125356-b587368e5c6b/cmd/pull.go (about)

     1  package cmd
     2  
     3  import (
     4  	"crypto/sha256"
     5  	"encoding/hex"
     6  	"fmt"
     7  	"io/ioutil"
     8  	"os"
     9  	"path/filepath"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/pkg/errors"
    14  	"github.com/spf13/cobra"
    15  )
    16  
    17  var (
    18  	pullCmdBranch, pullCmdCommit string
    19  	pullForce                    *bool
    20  )
    21  
    22  // Downloads a database from DBHub.io.
    23  var pullCmd = &cobra.Command{
    24  	Use:   "pull [database name]",
    25  	Short: "Download a database from DBHub.io",
    26  	RunE: func(cmd *cobra.Command, args []string) error {
    27  		return pull(args)
    28  	},
    29  }
    30  
    31  func init() {
    32  	RootCmd.AddCommand(pullCmd)
    33  	pullCmd.Flags().StringVar(&pullCmdBranch, "branch", "",
    34  		"Remote branch the database will be downloaded from")
    35  	pullCmd.Flags().StringVar(&pullCmdCommit, "commit", "",
    36  		"Commit ID of the database to download")
    37  	pullForce = pullCmd.Flags().BoolP("force", "f", false,
    38  		"Overwrite unsaved changes to the database?")
    39  }
    40  
    41  func pull(args []string) error {
    42  	// Ensure a database file was given
    43  	var db, defDB string
    44  	var err error
    45  	if len(args) == 0 {
    46  		db, err = getDefaultDatabase()
    47  		if err != nil {
    48  			return err
    49  		}
    50  		if db == "" {
    51  			// No database name was given on the command line, and we don't have a default database selected
    52  			return errors.New("No database file specified")
    53  		}
    54  	} else {
    55  		db = args[0]
    56  	}
    57  
    58  	// TODO: Allow giving multiple database files on the command line.  Hopefully just needs turning this
    59  	// TODO  into a for loop
    60  	if len(args) > 1 {
    61  		return errors.New("Only one database can be downloaded at a time (for now)")
    62  	}
    63  
    64  	// TODO: Add a --licence option, for automatically grabbing the licence as well
    65  	//       * Probably save it as <database name>-<license short name>.txt/html
    66  
    67  	// Ensure we weren't given potentially conflicting info on what to pull down
    68  	if pullCmdBranch != "" && pullCmdCommit != "" {
    69  		return errors.New("Either a branch name or commit ID can be given.  Not both at the same time!")
    70  	}
    71  
    72  	// Retrieve metadata for the database
    73  	var meta metaData
    74  	meta, err = updateMetadata(db, false) // Don't store the metadata to disk yet, in case the download fails
    75  	if err != nil {
    76  		return err
    77  	}
    78  
    79  	// If the database file already exists locally, check whether the file has changed since the last commit, and let
    80  	// the user know.  The --force option on the command line overrides this
    81  	if _, err = os.Stat(db); err == nil {
    82  		if *pullForce == false {
    83  			changed, err := dbChanged(db, meta)
    84  			if err != nil {
    85  				return err
    86  			}
    87  			if changed {
    88  				_, err = fmt.Fprintf(fOut, "%s has been changed since the last commit.  Use --force if you "+
    89  					"really want to overwrite it\n", db)
    90  				return err
    91  			}
    92  		}
    93  	}
    94  
    95  	// If given, make sure the requested branch exists
    96  	if pullCmdBranch != "" {
    97  		if _, ok := meta.Branches[pullCmdBranch]; ok == false {
    98  			return errors.New("The requested branch doesn't exist")
    99  		}
   100  	}
   101  
   102  	// If no specific branch nor commit were requested, we use the active branch set in the metadata
   103  	if pullCmdBranch == "" && pullCmdCommit == "" {
   104  		pullCmdBranch = meta.ActiveBranch
   105  	}
   106  
   107  	// If given, make sure the requested commit exists
   108  	var lastMod time.Time
   109  	var ok bool
   110  	var thisSha string
   111  	var thisCommit commitEntry
   112  	if pullCmdCommit != "" {
   113  		thisCommit, ok = meta.Commits[pullCmdCommit]
   114  		if ok == false {
   115  			return errors.New("The requested commit doesn't exist")
   116  		}
   117  		thisSha = thisCommit.Tree.Entries[0].Sha256
   118  		lastMod = thisCommit.Tree.Entries[0].LastModified
   119  	} else {
   120  		// Determine the sha256 of the database file
   121  		c := meta.Branches[pullCmdBranch].Commit
   122  		thisCommit, ok = meta.Commits[c]
   123  		if ok == false {
   124  			return errors.New("The requested commit doesn't exist")
   125  		}
   126  		thisSha = thisCommit.Tree.Entries[0].Sha256
   127  		lastMod = thisCommit.Tree.Entries[0].LastModified
   128  	}
   129  
   130  	// Check if the database file already exists in local cache
   131  	if thisSha != "" {
   132  		if _, err = os.Stat(filepath.Join(".dio", db, "db", thisSha)); err == nil {
   133  			// The database is already in the local cache, so use that instead of downloading from DBHub.io
   134  			var b []byte
   135  			b, err = ioutil.ReadFile(filepath.Join(".dio", db, "db", thisSha))
   136  			if err != nil {
   137  				return err
   138  			}
   139  			err = ioutil.WriteFile(db, b, 0644)
   140  			if err != nil {
   141  				return err
   142  			}
   143  			err = os.Chtimes(db, time.Now(), lastMod)
   144  			if err != nil {
   145  				return err
   146  			}
   147  
   148  			_, err = fmt.Fprintf(fOut, "Database '%s' refreshed from local cache\n", db)
   149  			if err != nil {
   150  				return err
   151  			}
   152  			if pullCmdBranch != "" {
   153  				_, err = fmt.Fprintf(fOut, "  * Branch: '%s'\n", pullCmdBranch)
   154  				if err != nil {
   155  					return err
   156  				}
   157  			}
   158  			if pullCmdCommit != "" {
   159  				_, err = fmt.Fprintf(fOut, "  * Commit: %s\n", pullCmdCommit)
   160  				if err != nil {
   161  					return err
   162  				}
   163  			}
   164  			_, err = numFormat.Fprintf(fOut, "  * Size: %d bytes\n", len(b))
   165  			if err != nil {
   166  				return err
   167  			}
   168  
   169  			// Update the branch metadata with the commit info
   170  			var oldBranch branchEntry
   171  			if pullCmdBranch == "" {
   172  				oldBranch = meta.Branches[meta.ActiveBranch]
   173  			} else {
   174  				oldBranch = meta.Branches[pullCmdBranch]
   175  			}
   176  			commitCount := 1
   177  			z := meta.Commits[thisCommit.ID]
   178  			for z.Parent != "" {
   179  				commitCount++
   180  				z = meta.Commits[z.Parent]
   181  			}
   182  			newBranch := branchEntry{
   183  				Commit:      thisCommit.ID,
   184  				CommitCount: commitCount,
   185  				Description: oldBranch.Description,
   186  			}
   187  			if pullCmdBranch == "" {
   188  				meta.Branches[meta.ActiveBranch] = newBranch
   189  			} else {
   190  				meta.Branches[pullCmdBranch] = newBranch
   191  			}
   192  
   193  			// Save the updated metadata to disk
   194  			err = saveMetadata(db, meta)
   195  			if err != nil {
   196  				return err
   197  			}
   198  
   199  			// If a default database isn't already selected, we use this one as the default
   200  			defDB, err = getDefaultDatabase()
   201  			if err != nil {
   202  				return err
   203  			}
   204  			if defDB == "" {
   205  				err = saveDefaultDatabase(db)
   206  				if err != nil {
   207  					return err
   208  				}
   209  			}
   210  			return nil
   211  		}
   212  	}
   213  
   214  	// Download the database file
   215  	// TODO: Use a streaming download approach, so download progress can be shown.  Something like this should help:
   216  	//         https://stackoverflow.com/questions/22108519/how-do-i-read-a-streaming-response-body-using-golangs-net-http-package
   217  	_, err = fmt.Fprintf(fOut, "Downloading '%s' from %s...\n", db, cloud)
   218  	if err != nil {
   219  		return err
   220  	}
   221  	resp, body, err := retrieveDatabase(db, pullCmdBranch, pullCmdCommit)
   222  	if err != nil {
   223  		return err
   224  	}
   225  
   226  	// Create the local database cache directory, if it doesn't yet exist
   227  	if _, err = os.Stat(filepath.Join(".dio", db, "db")); os.IsNotExist(err) {
   228  		err = os.MkdirAll(filepath.Join(".dio", db, "db"), 0770)
   229  		if err != nil {
   230  			return err
   231  		}
   232  	}
   233  
   234  	// Calculate the sha256 of the database file
   235  	s := sha256.Sum256(body)
   236  	shaSum := hex.EncodeToString(s[:])
   237  
   238  	// Write the database file to disk in the cache directory
   239  	err = ioutil.WriteFile(filepath.Join(".dio", db, "db", shaSum), body, 0644)
   240  	if err != nil {
   241  		return err
   242  	}
   243  
   244  	// Write the database file to disk again, this time in the working directory
   245  	err = ioutil.WriteFile(db, body, 0644)
   246  	if err != nil {
   247  		return err
   248  	}
   249  
   250  	// If the headers included the modification-date parameter for the database, set the last accessed and last
   251  	// modified times on the new database file
   252  	if disp := resp.Header.Get("Content-Disposition"); disp != "" {
   253  		s := strings.Split(disp, ";")
   254  		if len(s) == 4 {
   255  			a := strings.TrimLeft(s[2], " ")
   256  			if strings.HasPrefix(a, "modification-date=") {
   257  				b := strings.Split(a, "=")
   258  				c := strings.Trim(b[1], "\"")
   259  				lastMod, err := time.Parse(time.RFC3339, c)
   260  				if err != nil {
   261  					return err
   262  				}
   263  				err = os.Chtimes(db, time.Now(), lastMod)
   264  				if err != nil {
   265  					return err
   266  				}
   267  			}
   268  		}
   269  	}
   270  
   271  	// If the server provided a branch name, add it to the local metadata cache
   272  	if branch := resp.Header.Get("Branch"); branch != "" {
   273  		meta.ActiveBranch = branch
   274  	}
   275  
   276  	// The download succeeded, so save the updated metadata to disk
   277  	err = saveMetadata(db, meta)
   278  	if err != nil {
   279  		return err
   280  	}
   281  
   282  	// If a default database isn't already selected, we use this one as the default
   283  	defDB, err = getDefaultDatabase()
   284  	if err != nil {
   285  		return err
   286  	}
   287  	if defDB == "" {
   288  		err = saveDefaultDatabase(db)
   289  		if err != nil {
   290  			return err
   291  		}
   292  	}
   293  
   294  	// Display success message to the user
   295  	comID := resp.Header.Get("Commit-Id")
   296  	_, err = fmt.Fprintln(fOut, "Downloaded complete")
   297  	if err != nil {
   298  		return err
   299  	}
   300  	if pullCmdBranch != "" {
   301  		_, err = fmt.Fprintf(fOut, "  * Branch: '%s'\n", pullCmdBranch)
   302  		if err != nil {
   303  			return err
   304  		}
   305  	}
   306  	if comID != "" {
   307  		_, err = fmt.Fprintf(fOut, "  * Commit: %s\n", comID)
   308  		if err != nil {
   309  			return err
   310  		}
   311  	}
   312  	_, err = numFormat.Fprintf(fOut, "  * Size: %d bytes\n", len(body))
   313  	return err
   314  }