github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/apiserver/backup.go (about)

     1  // Copyright 2014 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package apiserver
     5  
     6  import (
     7  	"encoding/json"
     8  	"io"
     9  	"net/http"
    10  
    11  	"github.com/juju/errors"
    12  
    13  	apiservererrors "github.com/juju/juju/apiserver/errors"
    14  	"github.com/juju/juju/rpc/params"
    15  	"github.com/juju/juju/state/backups"
    16  )
    17  
    18  var newBackups = backups.NewBackups
    19  
    20  // backupHandler handles backup requests.
    21  type backupHandler struct {
    22  	ctxt httpContext
    23  }
    24  
    25  func (h *backupHandler) ServeHTTP(resp http.ResponseWriter, req *http.Request) {
    26  	// Validate before authenticate because the authentication is dependent
    27  	// on the state connection that is determined during the validation.
    28  	st, err := h.ctxt.stateForRequestAuthenticatedUser(req)
    29  	if err != nil {
    30  		h.sendError(resp, err)
    31  		return
    32  	}
    33  	defer st.Release()
    34  
    35  	if !st.IsController() {
    36  		h.sendError(resp, errors.New("requested model is not the controller model"))
    37  		return
    38  	}
    39  
    40  	switch req.Method {
    41  	case "GET":
    42  		logger.Infof("handling backups download request")
    43  		model, err := st.Model()
    44  		if err != nil {
    45  			h.sendError(resp, err)
    46  			return
    47  		}
    48  		modelConfig, err := model.ModelConfig()
    49  		if err != nil {
    50  			h.sendError(resp, err)
    51  			return
    52  		}
    53  		backupDir := backups.BackupDirToUse(modelConfig.BackupDir())
    54  		paths := &backups.Paths{
    55  			BackupDir: backupDir,
    56  		}
    57  		id, err := h.download(newBackups(paths), resp, req)
    58  		if err != nil {
    59  			h.sendError(resp, err)
    60  			return
    61  		}
    62  		logger.Infof("backups download request successful for %q", id)
    63  	default:
    64  		h.sendError(resp, errors.MethodNotAllowedf("unsupported method: %q", req.Method))
    65  	}
    66  }
    67  
    68  func (h *backupHandler) download(backups backups.Backups, resp http.ResponseWriter, req *http.Request) (string, error) {
    69  	args, err := h.parseGETArgs(req)
    70  	if err != nil {
    71  		return "", err
    72  	}
    73  	logger.Infof("backups download request for %q", args.ID)
    74  
    75  	meta, archive, err := backups.Get(args.ID)
    76  	if err != nil {
    77  		return "", err
    78  	}
    79  	defer archive.Close()
    80  
    81  	err = h.sendFile(archive, meta.Checksum(), resp)
    82  	return args.ID, err
    83  }
    84  
    85  func (h *backupHandler) read(req *http.Request, expectedType string) ([]byte, error) {
    86  	defer req.Body.Close()
    87  
    88  	ctype := req.Header.Get("Content-Type")
    89  	if ctype != expectedType {
    90  		return nil, errors.Errorf("expected Content-Type %q, got %q", expectedType, ctype)
    91  	}
    92  
    93  	body, err := io.ReadAll(req.Body)
    94  	if err != nil {
    95  		return nil, errors.Annotate(err, "while reading request body")
    96  	}
    97  
    98  	return body, nil
    99  }
   100  
   101  func (h *backupHandler) parseGETArgs(req *http.Request) (*params.BackupsDownloadArgs, error) {
   102  	body, err := h.read(req, params.ContentTypeJSON)
   103  	if err != nil {
   104  		return nil, errors.Trace(err)
   105  	}
   106  
   107  	var args params.BackupsDownloadArgs
   108  	if err := json.Unmarshal(body, &args); err != nil {
   109  		return nil, errors.Annotate(err, "while de-serializing args")
   110  	}
   111  
   112  	return &args, nil
   113  }
   114  
   115  func (h *backupHandler) sendFile(file io.Reader, checksum string, resp http.ResponseWriter) error {
   116  	// We don't set the Content-Length header, leaving it at -1.
   117  	resp.Header().Set("Content-Type", params.ContentTypeRaw)
   118  	resp.Header().Set("Digest", params.EncodeChecksum(checksum))
   119  	resp.WriteHeader(http.StatusOK)
   120  	if _, err := io.Copy(resp, file); err != nil {
   121  		return errors.Annotate(err, "while streaming archive")
   122  	}
   123  	return nil
   124  }
   125  
   126  // sendError sends a JSON-encoded error response.
   127  // Note the difference from the error response sent by
   128  // the sendError function - the error is encoded directly
   129  // rather than in the Error field.
   130  func (h *backupHandler) sendError(w http.ResponseWriter, err error) {
   131  	err, status := apiservererrors.ServerErrorAndStatus(err)
   132  	if err := sendStatusAndJSON(w, status, err); err != nil {
   133  		logger.Errorf("%v", err)
   134  	}
   135  }