github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/utils/remotesrv/http.go (about)

     1  // Copyright 2019 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package main
    16  
    17  import (
    18  	"bytes"
    19  	"crypto/md5"
    20  	"errors"
    21  	"fmt"
    22  	"io"
    23  	"io/ioutil"
    24  	"net/http"
    25  	"os"
    26  	"path/filepath"
    27  	"strconv"
    28  	"strings"
    29  
    30  	remotesapi "github.com/dolthub/dolt/go/gen/proto/dolt/services/remotesapi/v1alpha1"
    31  
    32  	"github.com/dolthub/dolt/go/libraries/utils/iohelp"
    33  	"github.com/dolthub/dolt/go/store/hash"
    34  )
    35  
    36  var expectedFiles = make(map[string]*remotesapi.TableFileDetails)
    37  
    38  func ServeHTTP(respWr http.ResponseWriter, req *http.Request) {
    39  	logger := getReqLogger("HTTP_"+req.Method, req.RequestURI)
    40  	defer func() { logger("finished") }()
    41  
    42  	path := strings.TrimLeft(req.URL.Path, "/")
    43  	tokens := strings.Split(path, "/")
    44  
    45  	if len(tokens) != 3 {
    46  		logger(fmt.Sprintf("response to: %v method: %v http response code: %v", req.RequestURI, req.Method, http.StatusNotFound))
    47  		respWr.WriteHeader(http.StatusNotFound)
    48  	}
    49  
    50  	org := tokens[0]
    51  	repo := tokens[1]
    52  	hashStr := tokens[2]
    53  
    54  	statusCode := http.StatusMethodNotAllowed
    55  	switch req.Method {
    56  	case http.MethodGet:
    57  		rangeStr := req.Header.Get("Range")
    58  
    59  		if rangeStr == "" {
    60  			statusCode = readFile(logger, org, repo, hashStr, respWr)
    61  		} else {
    62  			statusCode = readChunk(logger, org, repo, hashStr, rangeStr, respWr)
    63  		}
    64  
    65  	case http.MethodPost, http.MethodPut:
    66  		statusCode = writeTableFile(logger, org, repo, hashStr, req)
    67  	}
    68  
    69  	if statusCode != -1 {
    70  		respWr.WriteHeader(statusCode)
    71  	}
    72  }
    73  
    74  func writeTableFile(logger func(string), org, repo, fileId string, request *http.Request) int {
    75  	_, ok := hash.MaybeParse(fileId)
    76  
    77  	if !ok {
    78  		logger(fileId + " is not a valid hash")
    79  		return http.StatusBadRequest
    80  	}
    81  
    82  	tfd, ok := expectedFiles[fileId]
    83  
    84  	if !ok {
    85  		return http.StatusBadRequest
    86  	}
    87  
    88  	logger(fileId + " is valid")
    89  	data, err := ioutil.ReadAll(request.Body)
    90  
    91  	if tfd.ContentLength != 0 && tfd.ContentLength != uint64(len(data)) {
    92  		return http.StatusBadRequest
    93  	}
    94  
    95  	if len(tfd.ContentHash) > 0 {
    96  		actualMD5Bytes := md5.Sum(data)
    97  		if !bytes.Equal(tfd.ContentHash, actualMD5Bytes[:]) {
    98  			return http.StatusBadRequest
    99  		}
   100  	}
   101  
   102  	if err != nil {
   103  		logger("failed to read body " + err.Error())
   104  		return http.StatusInternalServerError
   105  	}
   106  
   107  	err = writeLocal(logger, org, repo, fileId, data)
   108  
   109  	if err != nil {
   110  		return http.StatusInternalServerError
   111  	}
   112  
   113  	return http.StatusOK
   114  }
   115  
   116  func writeLocal(logger func(string), org, repo, fileId string, data []byte) error {
   117  	path := filepath.Join(org, repo, fileId)
   118  
   119  	err := ioutil.WriteFile(path, data, os.ModePerm)
   120  
   121  	if err != nil {
   122  		logger(fmt.Sprintf("failed to write file %s", path))
   123  		return err
   124  	}
   125  
   126  	logger("Successfully wrote object to storage")
   127  
   128  	return nil
   129  }
   130  
   131  func offsetAndLenFromRange(rngStr string) (int64, int64, error) {
   132  	if rngStr == "" {
   133  		return -1, -1, nil
   134  	}
   135  
   136  	if !strings.HasPrefix(rngStr, "bytes=") {
   137  		return -1, -1, errors.New("range string does not start with 'bytes=")
   138  	}
   139  
   140  	tokens := strings.Split(rngStr[6:], "-")
   141  
   142  	if len(tokens) != 2 {
   143  		return -1, -1, errors.New("invalid range format. should be bytes=#-#")
   144  	}
   145  
   146  	start, err := strconv.ParseUint(strings.TrimSpace(tokens[0]), 10, 64)
   147  
   148  	if err != nil {
   149  		return -1, -1, errors.New("invalid offset is not a number. should be bytes=#-#")
   150  	}
   151  
   152  	end, err := strconv.ParseUint(strings.TrimSpace(tokens[1]), 10, 64)
   153  
   154  	if err != nil {
   155  		return -1, -1, errors.New("invalid length is not a number. should be bytes=#-#")
   156  	}
   157  
   158  	return int64(start), int64(end-start) + 1, nil
   159  }
   160  
   161  func readFile(logger func(string), org, repo, fileId string, writer io.Writer) int {
   162  	path := filepath.Join(org, repo, fileId)
   163  
   164  	info, err := os.Stat(path)
   165  
   166  	if err != nil {
   167  		logger("file not found. path: " + path)
   168  		return http.StatusNotFound
   169  	}
   170  
   171  	f, err := os.Open(path)
   172  
   173  	if err != nil {
   174  		logger("failed to open file. file: " + path + " err: " + err.Error())
   175  		return http.StatusInternalServerError
   176  	}
   177  
   178  	defer func() {
   179  		err := f.Close()
   180  
   181  		if err != nil {
   182  			logger(fmt.Sprintf("Close failed. file: %s, err: %v", path, err))
   183  		} else {
   184  			logger("Close Successful")
   185  		}
   186  	}()
   187  
   188  	n, err := io.Copy(writer, f)
   189  
   190  	if err != nil {
   191  		logger("failed to write data to response. err : " + err.Error())
   192  		return -1
   193  	}
   194  
   195  	if n != info.Size() {
   196  		logger(fmt.Sprintf("failed to write entire file to response. Copied %d of %d err: %v", n, info.Size(), err))
   197  		return -1
   198  	}
   199  
   200  	return -1
   201  }
   202  
   203  func readChunk(logger func(string), org, repo, fileId, rngStr string, writer io.Writer) int {
   204  	offset, length, err := offsetAndLenFromRange(rngStr)
   205  
   206  	if err != nil {
   207  		logger(fmt.Sprintln(rngStr, "is not a valid range"))
   208  		return http.StatusBadRequest
   209  	}
   210  
   211  	data, retVal := readLocalRange(logger, org, repo, fileId, int64(offset), int64(length))
   212  
   213  	if retVal != -1 {
   214  		return retVal
   215  	}
   216  
   217  	logger(fmt.Sprintf("writing %d bytes", len(data)))
   218  	err = iohelp.WriteAll(writer, data)
   219  
   220  	if err != nil {
   221  		logger("failed to write data to response " + err.Error())
   222  		return -1
   223  	}
   224  
   225  	logger("Successfully wrote data")
   226  	return -1
   227  }
   228  
   229  func readLocalRange(logger func(string), org, repo, fileId string, offset, length int64) ([]byte, int) {
   230  	path := filepath.Join(org, repo, fileId)
   231  
   232  	logger(fmt.Sprintf("Attempting to read bytes %d to %d from %s", offset, offset+length, path))
   233  	info, err := os.Stat(path)
   234  
   235  	if err != nil {
   236  		logger(fmt.Sprintf("file %s not found", path))
   237  		return nil, http.StatusNotFound
   238  	}
   239  
   240  	logger(fmt.Sprintf("Verified file %s exists", path))
   241  
   242  	if info.Size() < int64(offset+length) {
   243  		logger(fmt.Sprintf("Attempted to read bytes %d to %d, but the file is only %d bytes in size", offset, offset+length, info.Size()))
   244  		return nil, http.StatusBadRequest
   245  	}
   246  
   247  	logger(fmt.Sprintf("Verified the file is large enough to contain the range"))
   248  	f, err := os.Open(path)
   249  
   250  	if err != nil {
   251  		logger(fmt.Sprintf("Failed to open %s: %v", path, err))
   252  		return nil, http.StatusInternalServerError
   253  	}
   254  
   255  	defer func() {
   256  		err := f.Close()
   257  
   258  		if err != nil {
   259  			logger(fmt.Sprintf("Close failed. file: %s, err: %v", path, err))
   260  		} else {
   261  			logger("Close Successful")
   262  		}
   263  	}()
   264  
   265  	logger(fmt.Sprintf("Successfully opened file"))
   266  	pos, err := f.Seek(int64(offset), 0)
   267  
   268  	if err != nil {
   269  		logger(fmt.Sprintf("Failed to seek to %d: %v", offset, err))
   270  		return nil, http.StatusInternalServerError
   271  	}
   272  
   273  	logger(fmt.Sprintf("Seek succeeded.  Current position is %d", pos))
   274  	diff := offset - pos
   275  	data, err := iohelp.ReadNBytes(f, int(diff+int64(length)))
   276  
   277  	if err != nil {
   278  		logger(fmt.Sprintf("Failed to read %d bytes: %v", diff+length, err))
   279  		return nil, http.StatusInternalServerError
   280  	}
   281  
   282  	logger(fmt.Sprintf("Successfully read %d bytes", len(data)))
   283  	return data[diff:], -1
   284  }