github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/remotesrv/http.go (about)

     1  // Copyright 2019 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package remotesrv
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"crypto/md5"
    21  	"encoding/base64"
    22  	"errors"
    23  	"fmt"
    24  	gohash "hash"
    25  	"io"
    26  	"net/http"
    27  	"os"
    28  	"path/filepath"
    29  	"strconv"
    30  	"strings"
    31  
    32  	"github.com/sirupsen/logrus"
    33  	"google.golang.org/grpc/metadata"
    34  	"google.golang.org/grpc/peer"
    35  
    36  	"github.com/dolthub/dolt/go/libraries/utils/filesys"
    37  	"github.com/dolthub/dolt/go/store/hash"
    38  	"github.com/dolthub/dolt/go/store/types"
    39  )
    40  
    41  var (
    42  	ErrReadOutOfBounds = errors.New("cannot read file for given length and " +
    43  		"offset since the read would exceed the size of the file")
    44  )
    45  
// filehandler is the http.Handler that serves table file downloads (GET) and
// accepts table file uploads (POST/PUT).
//
// NOTE: newFileHandler builds this struct with a positional literal, so the
// field order here must not change without updating it.
type filehandler struct {
	// dbCache resolves the repository that an upload is written into.
	dbCache  DBCache
	// fs converts request paths into absolute paths for downloads.
	fs       filesys.Filesys
	// readOnly makes the handler answer POST/PUT with 403 Forbidden.
	readOnly bool
	// lgr is the base logger from which per-request loggers are derived.
	lgr      *logrus.Entry
	// sealer unseals incoming request URLs before they are interpreted.
	sealer   Sealer
}
    53  
    54  func newFileHandler(lgr *logrus.Entry, dbCache DBCache, fs filesys.Filesys, readOnly bool, sealer Sealer) filehandler {
    55  	return filehandler{
    56  		dbCache,
    57  		fs,
    58  		readOnly,
    59  		lgr.WithFields(logrus.Fields{
    60  			"service": "dolt.services.remotesapi.v1alpha1.HttpFileServer",
    61  		}),
    62  		sealer,
    63  	}
    64  }
    65  
    66  func (fh filehandler) ServeHTTP(respWr http.ResponseWriter, req *http.Request) {
    67  	logger := getReqLogger(fh.lgr, req.Method+"_"+req.RequestURI)
    68  	defer func() { logger.Info("finished") }()
    69  
    70  	var err error
    71  	req.URL, err = fh.sealer.Unseal(req.URL)
    72  	if err != nil {
    73  		logger.WithError(err).Warn("could not unseal incoming request URL")
    74  		respWr.WriteHeader(http.StatusBadRequest)
    75  		return
    76  	}
    77  
    78  	logger = logger.WithField("unsealed_url", req.URL.String())
    79  
    80  	path := strings.TrimLeft(req.URL.Path, "/")
    81  
    82  	statusCode := http.StatusMethodNotAllowed
    83  	switch req.Method {
    84  	case http.MethodGet:
    85  		path = filepath.Clean(path)
    86  		if strings.HasPrefix(path, "../") || strings.Contains(path, "/../") || strings.HasSuffix(path, "/..") {
    87  			logger.Warn("bad request with .. in URL path")
    88  			respWr.WriteHeader(http.StatusBadRequest)
    89  			return
    90  		}
    91  		i := strings.LastIndex(path, "/")
    92  		if i == -1 {
    93  			logger.Warn("bad request with -1 LastIndex of '/' for path")
    94  			respWr.WriteHeader(http.StatusBadRequest)
    95  			return
    96  		}
    97  		_, ok := hash.MaybeParse(path[i+1:])
    98  		if !ok {
    99  			logger.WithField("last_path_component", path[i+1:]).Warn("bad request with unparseable last path component")
   100  			respWr.WriteHeader(http.StatusBadRequest)
   101  			return
   102  		}
   103  		abs, err := fh.fs.Abs(path)
   104  		if err != nil {
   105  			logger.WithError(err).Error("could not get absolute path")
   106  			respWr.WriteHeader(http.StatusInternalServerError)
   107  			return
   108  		}
   109  		respWr.Header().Add("Accept-Ranges", "bytes")
   110  		logger, statusCode = readTableFile(logger, abs, respWr, req.Header.Get("Range"))
   111  
   112  	case http.MethodPost, http.MethodPut:
   113  		if fh.readOnly {
   114  			respWr.WriteHeader(http.StatusForbidden)
   115  			return
   116  		}
   117  
   118  		i := strings.LastIndex(path, "/")
   119  		// a table file name is currently 32 characters, plus the '/' is 33.
   120  		if i < 0 || len(path[i:]) != 33 {
   121  			logger = logger.WithField("status", http.StatusNotFound)
   122  			respWr.WriteHeader(http.StatusNotFound)
   123  			return
   124  		}
   125  
   126  		filepath := path[:i]
   127  		file := path[i+1:]
   128  
   129  		q := req.URL.Query()
   130  		ncs := q.Get("num_chunks")
   131  		if ncs == "" {
   132  			logger = logger.WithField("status", http.StatusBadRequest)
   133  			logger.Warn("bad request: num_chunks parameter not provided")
   134  			respWr.WriteHeader(http.StatusBadRequest)
   135  			return
   136  		}
   137  		num_chunks, err := strconv.Atoi(ncs)
   138  		if err != nil {
   139  			logger = logger.WithField("status", http.StatusBadRequest)
   140  			logger.WithError(err).Warn("bad request: num_chunks parameter did not parse")
   141  			respWr.WriteHeader(http.StatusBadRequest)
   142  			return
   143  		}
   144  		cls := q.Get("content_length")
   145  		if cls == "" {
   146  			logger = logger.WithField("status", http.StatusBadRequest)
   147  			logger.Warn("bad request: content_length parameter not provided")
   148  			respWr.WriteHeader(http.StatusBadRequest)
   149  			return
   150  		}
   151  		content_length, err := strconv.Atoi(cls)
   152  		if err != nil {
   153  			logger = logger.WithField("status", http.StatusBadRequest)
   154  			logger.WithError(err).Warn("bad request: content_length parameter did not parse")
   155  			respWr.WriteHeader(http.StatusBadRequest)
   156  			return
   157  		}
   158  		chs := q.Get("content_hash")
   159  		if chs == "" {
   160  			logger = logger.WithField("status", http.StatusBadRequest)
   161  			logger.Warn("bad request: content_hash parameter not provided")
   162  			respWr.WriteHeader(http.StatusBadRequest)
   163  			return
   164  		}
   165  		content_hash, err := base64.RawURLEncoding.DecodeString(chs)
   166  		if err != nil {
   167  			logger = logger.WithField("status", http.StatusBadRequest)
   168  			logger.WithError(err).Warn("bad request: content_hash parameter did not parse")
   169  			respWr.WriteHeader(http.StatusBadRequest)
   170  			return
   171  		}
   172  
   173  		logger, statusCode = writeTableFile(req.Context(), logger, fh.dbCache, filepath, file, num_chunks, content_hash, uint64(content_length), req.Body)
   174  	}
   175  
   176  	if statusCode != -1 {
   177  		respWr.WriteHeader(statusCode)
   178  	}
   179  }
   180  
   181  func readTableFile(logger *logrus.Entry, path string, respWr http.ResponseWriter, rangeStr string) (*logrus.Entry, int) {
   182  	var r io.ReadCloser
   183  	var readSize int64
   184  	var fileErr error
   185  	{
   186  		if rangeStr == "" {
   187  			logger = logger.WithField("whole_file", true)
   188  			r, readSize, fileErr = getFileReader(path)
   189  		} else {
   190  			offset, length, headerStr, err := offsetAndLenFromRange(rangeStr)
   191  			if err != nil {
   192  				logger.Println(err.Error())
   193  				return logger, http.StatusBadRequest
   194  			}
   195  			logger = logger.WithFields(logrus.Fields{
   196  				"read_offset": offset,
   197  				"read_length": length,
   198  			})
   199  			readSize = length
   200  			var fSize int64
   201  			r, fSize, fileErr = getFileReaderAt(path, offset, length)
   202  			if fileErr == nil {
   203  				respWr.Header().Add("Content-Range", headerStr+strconv.Itoa(int(fSize)))
   204  			}
   205  		}
   206  	}
   207  	if fileErr != nil {
   208  		logger.Println(fileErr.Error())
   209  		if errors.Is(fileErr, os.ErrNotExist) {
   210  			logger = logger.WithField("status", http.StatusNotFound)
   211  			return logger, http.StatusNotFound
   212  		} else if errors.Is(fileErr, ErrReadOutOfBounds) {
   213  			logger = logger.WithField("status", http.StatusBadRequest)
   214  			logger.Warn("bad request: offset out of bounds for path")
   215  			return logger, http.StatusBadRequest
   216  		}
   217  		logger = logger.WithError(fileErr)
   218  		return logger, http.StatusInternalServerError
   219  	}
   220  	defer func() {
   221  		err := r.Close()
   222  		if err != nil {
   223  			logger.WithError(err).Warn("failed to close file")
   224  		}
   225  	}()
   226  
   227  	if rangeStr == "" {
   228  		respWr.WriteHeader(http.StatusPartialContent)
   229  	} else {
   230  		respWr.WriteHeader(http.StatusOK)
   231  	}
   232  
   233  	n, err := io.Copy(respWr, r)
   234  	if err != nil {
   235  		logger = logger.WithField("status", http.StatusInternalServerError)
   236  		logger.WithError(err).Error("error copying data to response writer")
   237  		return logger, http.StatusInternalServerError
   238  	}
   239  	if n != readSize {
   240  		logger = logger.WithField("status", http.StatusInternalServerError)
   241  		logger.WithField("copied_size", n).Error("failed to copy all bytes to response")
   242  		return logger, http.StatusInternalServerError
   243  	}
   244  
   245  	return logger, -1
   246  }
   247  
   248  type uploadreader struct {
   249  	r            io.ReadCloser
   250  	totalread    int
   251  	expectedread uint64
   252  	expectedsum  []byte
   253  	checksum     gohash.Hash
   254  }
   255  
   256  func (u *uploadreader) Read(p []byte) (n int, err error) {
   257  	n, err = u.r.Read(p)
   258  	if err == nil || err == io.EOF {
   259  		u.totalread += n
   260  		u.checksum.Write(p[:n])
   261  	}
   262  	return n, err
   263  }
   264  
   265  var errBodyLengthTFDMismatch = errors.New("body upload length did not match table file details")
   266  var errBodyHashTFDMismatch = errors.New("body upload hash did not match table file details")
   267  
   268  func (u *uploadreader) Close() error {
   269  	cerr := u.r.Close()
   270  	if cerr != nil {
   271  		return cerr
   272  	}
   273  	if u.expectedread != 0 && u.expectedread != uint64(u.totalread) {
   274  		return errBodyLengthTFDMismatch
   275  	}
   276  	sum := u.checksum.Sum(nil)
   277  	if !bytes.Equal(u.expectedsum, sum[:]) {
   278  		return errBodyHashTFDMismatch
   279  	}
   280  	return nil
   281  }
   282  
   283  func writeTableFile(ctx context.Context, logger *logrus.Entry, dbCache DBCache, path, fileId string, numChunks int, contentHash []byte, contentLength uint64, body io.ReadCloser) (*logrus.Entry, int) {
   284  	_, ok := hash.MaybeParse(fileId)
   285  	if !ok {
   286  		logger = logger.WithField("status", http.StatusBadRequest)
   287  		logger.Warnf("%s is not a valid hash", fileId)
   288  		return logger, http.StatusBadRequest
   289  	}
   290  
   291  	cs, err := dbCache.Get(ctx, path, types.Format_Default.VersionString())
   292  	if err != nil {
   293  		logger = logger.WithField("status", http.StatusInternalServerError)
   294  		logger.WithError(err).Error("failed to get repository")
   295  		return logger, http.StatusInternalServerError
   296  	}
   297  
   298  	err = cs.WriteTableFile(ctx, fileId, numChunks, contentHash, func() (io.ReadCloser, uint64, error) {
   299  		reader := body
   300  		size := contentLength
   301  		return &uploadreader{
   302  			reader,
   303  			0,
   304  			contentLength,
   305  			contentHash,
   306  			md5.New(),
   307  		}, size, nil
   308  	})
   309  
   310  	if err != nil {
   311  		if errors.Is(err, errBodyLengthTFDMismatch) {
   312  			logger = logger.WithField("status", http.StatusBadRequest)
   313  			logger.Warn("bad request: body length mismatch")
   314  			return logger, http.StatusBadRequest
   315  		}
   316  		if errors.Is(err, errBodyHashTFDMismatch) {
   317  			logger = logger.WithField("status", http.StatusBadRequest)
   318  			logger.Warn("bad request: body hash mismatch")
   319  			return logger, http.StatusBadRequest
   320  		}
   321  		logger = logger.WithField("status", http.StatusInternalServerError)
   322  		logger.WithError(err).Error("failed to write upload to table file")
   323  		return logger, http.StatusInternalServerError
   324  	}
   325  
   326  	return logger, http.StatusOK
   327  }
   328  
   329  func offsetAndLenFromRange(rngStr string) (int64, int64, string, error) {
   330  	if rngStr == "" {
   331  		return -1, -1, "", nil
   332  	}
   333  
   334  	if !strings.HasPrefix(rngStr, "bytes=") {
   335  		return -1, -1, "", errors.New("range string does not start with 'bytes=")
   336  	}
   337  
   338  	tokens := strings.Split(rngStr[6:], "-")
   339  
   340  	if len(tokens) != 2 {
   341  		return -1, -1, "", errors.New("invalid range format. should be bytes=#-#")
   342  	}
   343  
   344  	start, err := strconv.ParseUint(strings.TrimSpace(tokens[0]), 10, 64)
   345  
   346  	if err != nil {
   347  		return -1, -1, "", errors.New("invalid offset is not a number. should be bytes=#-#")
   348  	}
   349  
   350  	end, err := strconv.ParseUint(strings.TrimSpace(tokens[1]), 10, 64)
   351  
   352  	if err != nil {
   353  		return -1, -1, "", errors.New("invalid length is not a number. should be bytes=#-#")
   354  	}
   355  
   356  	return int64(start), int64(end-start) + 1, "bytes " + tokens[0] + "-" + tokens[1] + "/", nil
   357  }
   358  
   359  // getFileReader opens a file at the given path and returns an io.ReadCloser,
   360  // the corresponding file's filesize, and a http status.
   361  func getFileReader(path string) (io.ReadCloser, int64, error) {
   362  	return openFile(path)
   363  }
   364  
   365  func openFile(path string) (*os.File, int64, error) {
   366  	info, err := os.Stat(path)
   367  	if err != nil {
   368  		return nil, 0, fmt.Errorf("failed to get stats for file at path %s: %w", path, err)
   369  	}
   370  
   371  	f, err := os.Open(path)
   372  	if err != nil {
   373  		return nil, 0, fmt.Errorf("failed to open file at path %s: %w", path, err)
   374  	}
   375  
   376  	return f, info.Size(), nil
   377  }
   378  
   379  type closerReaderWrapper struct {
   380  	io.Reader
   381  	io.Closer
   382  }
   383  
   384  func getFileReaderAt(path string, offset int64, length int64) (io.ReadCloser, int64, error) {
   385  	f, fSize, err := openFile(path)
   386  	if err != nil {
   387  		return nil, 0, err
   388  	}
   389  
   390  	if fSize < int64(offset+length) {
   391  		return nil, 0, fmt.Errorf("failed to read file %s at offset %d, length %d: %w", path, offset, length, ErrReadOutOfBounds)
   392  	}
   393  
   394  	_, err = f.Seek(int64(offset), 0)
   395  	if err != nil {
   396  		return nil, 0, fmt.Errorf("failed to seek file at path %s to offset %d: %w", path, offset, err)
   397  	}
   398  
   399  	r := closerReaderWrapper{io.LimitReader(f, length), f}
   400  	return r, fSize, nil
   401  }
   402  
   403  // ExtractBasicAuthCreds extracts the username and password from the incoming request. It returns RequestCredentials
   404  // populated with necessary information to authenticate the request. nil and an error will be returned if any error
   405  // occurs.
   406  func ExtractBasicAuthCreds(ctx context.Context) (*RequestCredentials, error) {
   407  	if md, ok := metadata.FromIncomingContext(ctx); !ok {
   408  		return nil, errors.New("no metadata in context")
   409  	} else {
   410  		var username string
   411  		var password string
   412  
   413  		auths := md.Get("authorization")
   414  		if len(auths) != 1 {
   415  			username = "root"
   416  			password = ""
   417  		} else {
   418  			auth := auths[0]
   419  			if !strings.HasPrefix(auth, "Basic ") {
   420  				return nil, fmt.Errorf("bad request: authorization header did not start with 'Basic '")
   421  			}
   422  			authTrim := strings.TrimPrefix(auth, "Basic ")
   423  			uDec, err := base64.URLEncoding.DecodeString(authTrim)
   424  			if err != nil {
   425  				return nil, fmt.Errorf("incoming request authorization header failed to decode: %v", err)
   426  			}
   427  			userPass := strings.Split(string(uDec), ":")
   428  			username = userPass[0]
   429  			password = userPass[1]
   430  		}
   431  		addr, ok := peer.FromContext(ctx)
   432  		if !ok {
   433  			return nil, errors.New("incoming request had no peer")
   434  		}
   435  
   436  		return &RequestCredentials{Username: username, Password: password, Address: addr.Addr.String()}, nil
   437  	}
   438  }