github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/utils/remotesrv/http.go (about) 1 // Copyright 2019 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package main 16 17 import ( 18 "bytes" 19 "crypto/md5" 20 "errors" 21 "fmt" 22 "io" 23 "io/ioutil" 24 "net/http" 25 "os" 26 "path/filepath" 27 "strconv" 28 "strings" 29 30 remotesapi "github.com/dolthub/dolt/go/gen/proto/dolt/services/remotesapi/v1alpha1" 31 32 "github.com/dolthub/dolt/go/libraries/utils/iohelp" 33 "github.com/dolthub/dolt/go/store/hash" 34 ) 35 36 var expectedFiles = make(map[string]*remotesapi.TableFileDetails) 37 38 func ServeHTTP(respWr http.ResponseWriter, req *http.Request) { 39 logger := getReqLogger("HTTP_"+req.Method, req.RequestURI) 40 defer func() { logger("finished") }() 41 42 path := strings.TrimLeft(req.URL.Path, "/") 43 tokens := strings.Split(path, "/") 44 45 if len(tokens) != 3 { 46 logger(fmt.Sprintf("response to: %v method: %v http response code: %v", req.RequestURI, req.Method, http.StatusNotFound)) 47 respWr.WriteHeader(http.StatusNotFound) 48 } 49 50 org := tokens[0] 51 repo := tokens[1] 52 hashStr := tokens[2] 53 54 statusCode := http.StatusMethodNotAllowed 55 switch req.Method { 56 case http.MethodGet: 57 rangeStr := req.Header.Get("Range") 58 59 if rangeStr == "" { 60 statusCode = readFile(logger, org, repo, hashStr, respWr) 61 } else { 62 statusCode = readChunk(logger, org, repo, hashStr, rangeStr, respWr) 63 } 64 65 case http.MethodPost, http.MethodPut: 66 statusCode = writeTableFile(logger, org, repo, hashStr, req) 67 } 68 69 if statusCode != -1 { 70 respWr.WriteHeader(statusCode) 71 } 72 } 73 74 func writeTableFile(logger func(string), org, repo, fileId string, request *http.Request) int { 75 _, ok := hash.MaybeParse(fileId) 76 77 if !ok { 78 logger(fileId + " is not a valid hash") 79 return http.StatusBadRequest 80 } 81 82 tfd, ok := expectedFiles[fileId] 83 84 if !ok { 85 return http.StatusBadRequest 86 } 87 88 logger(fileId + " is valid") 89 data, err := ioutil.ReadAll(request.Body) 90 91 if tfd.ContentLength != 0 && tfd.ContentLength != uint64(len(data)) { 92 return http.StatusBadRequest 93 } 94 95 if len(tfd.ContentHash) > 0 { 96 actualMD5Bytes := md5.Sum(data) 97 if !bytes.Equal(tfd.ContentHash, actualMD5Bytes[:]) { 98 return http.StatusBadRequest 99 } 100 } 101 102 if err != nil { 103 logger("failed to read body " + err.Error()) 104 return http.StatusInternalServerError 105 } 106 107 err = writeLocal(logger, org, repo, fileId, data) 108 109 if err != nil { 110 return http.StatusInternalServerError 111 } 112 113 return http.StatusOK 114 } 115 116 func writeLocal(logger func(string), org, repo, fileId string, data []byte) error { 117 path := filepath.Join(org, repo, fileId) 118 119 err := ioutil.WriteFile(path, data, os.ModePerm) 120 121 if err != nil { 122 logger(fmt.Sprintf("failed to write file %s", path)) 123 return err 124 } 125 126 logger("Successfully wrote object to storage") 127 128 return nil 129 } 130 131 func offsetAndLenFromRange(rngStr string) (int64, int64, error) { 132 if rngStr == "" { 133 return -1, -1, nil 134 } 135 136 if !strings.HasPrefix(rngStr, "bytes=") { 137 return -1, -1, errors.New("range string does not start with 'bytes=") 138 } 139 140 tokens := strings.Split(rngStr[6:], "-") 141 142 if len(tokens) != 2 { 143 return -1, -1, errors.New("invalid range format. should be bytes=#-#") 144 } 145 146 start, err := strconv.ParseUint(strings.TrimSpace(tokens[0]), 10, 64) 147 148 if err != nil { 149 return -1, -1, errors.New("invalid offset is not a number. should be bytes=#-#") 150 } 151 152 end, err := strconv.ParseUint(strings.TrimSpace(tokens[1]), 10, 64) 153 154 if err != nil { 155 return -1, -1, errors.New("invalid length is not a number. should be bytes=#-#") 156 } 157 158 return int64(start), int64(end-start) + 1, nil 159 } 160 161 func readFile(logger func(string), org, repo, fileId string, writer io.Writer) int { 162 path := filepath.Join(org, repo, fileId) 163 164 info, err := os.Stat(path) 165 166 if err != nil { 167 logger("file not found. path: " + path) 168 return http.StatusNotFound 169 } 170 171 f, err := os.Open(path) 172 173 if err != nil { 174 logger("failed to open file. file: " + path + " err: " + err.Error()) 175 return http.StatusInternalServerError 176 } 177 178 defer func() { 179 err := f.Close() 180 181 if err != nil { 182 logger(fmt.Sprintf("Close failed. file: %s, err: %v", path, err)) 183 } else { 184 logger("Close Successful") 185 } 186 }() 187 188 n, err := io.Copy(writer, f) 189 190 if err != nil { 191 logger("failed to write data to response. err : " + err.Error()) 192 return -1 193 } 194 195 if n != info.Size() { 196 logger(fmt.Sprintf("failed to write entire file to response. Copied %d of %d err: %v", n, info.Size(), err)) 197 return -1 198 } 199 200 return -1 201 } 202 203 func readChunk(logger func(string), org, repo, fileId, rngStr string, writer io.Writer) int { 204 offset, length, err := offsetAndLenFromRange(rngStr) 205 206 if err != nil { 207 logger(fmt.Sprintln(rngStr, "is not a valid range")) 208 return http.StatusBadRequest 209 } 210 211 data, retVal := readLocalRange(logger, org, repo, fileId, int64(offset), int64(length)) 212 213 if retVal != -1 { 214 return retVal 215 } 216 217 logger(fmt.Sprintf("writing %d bytes", len(data))) 218 err = iohelp.WriteAll(writer, data) 219 220 if err != nil { 221 logger("failed to write data to response " + err.Error()) 222 return -1 223 } 224 225 logger("Successfully wrote data") 226 return -1 227 } 228 229 func readLocalRange(logger func(string), org, repo, fileId string, offset, length int64) ([]byte, int) { 230 path := filepath.Join(org, repo, fileId) 231 232 logger(fmt.Sprintf("Attempting to read bytes %d to %d from %s", offset, offset+length, path)) 233 info, err := os.Stat(path) 234 235 if err != nil { 236 logger(fmt.Sprintf("file %s not found", path)) 237 return nil, http.StatusNotFound 238 } 239 240 logger(fmt.Sprintf("Verified file %s exists", path)) 241 242 if info.Size() < int64(offset+length) { 243 logger(fmt.Sprintf("Attempted to read bytes %d to %d, but the file is only %d bytes in size", offset, offset+length, info.Size())) 244 return nil, http.StatusBadRequest 245 } 246 247 logger(fmt.Sprintf("Verified the file is large enough to contain the range")) 248 f, err := os.Open(path) 249 250 if err != nil { 251 logger(fmt.Sprintf("Failed to open %s: %v", path, err)) 252 return nil, http.StatusInternalServerError 253 } 254 255 defer func() { 256 err := f.Close() 257 258 if err != nil { 259 logger(fmt.Sprintf("Close failed. file: %s, err: %v", path, err)) 260 } else { 261 logger("Close Successful") 262 } 263 }() 264 265 logger(fmt.Sprintf("Successfully opened file")) 266 pos, err := f.Seek(int64(offset), 0) 267 268 if err != nil { 269 logger(fmt.Sprintf("Failed to seek to %d: %v", offset, err)) 270 return nil, http.StatusInternalServerError 271 } 272 273 logger(fmt.Sprintf("Seek succeeded. Current position is %d", pos)) 274 diff := offset - pos 275 data, err := iohelp.ReadNBytes(f, int(diff+int64(length))) 276 277 if err != nil { 278 logger(fmt.Sprintf("Failed to read %d bytes: %v", diff+length, err)) 279 return nil, http.StatusInternalServerError 280 } 281 282 logger(fmt.Sprintf("Successfully read %d bytes", len(data))) 283 return data[diff:], -1 284 }