github.com/gitbundle/modules@v0.0.0-20231025071548-85b91c5c3b01/lfs/transferadapter.go (about)

     1  // Copyright 2023 The GitBundle Inc. All rights reserved.
     2  // Copyright 2017 The Gitea Authors. All rights reserved.
     3  // Use of this source code is governed by a MIT-style
     4  // license that can be found in the LICENSE file.
     5  
     6  package lfs
     7  
     8  import (
     9  	"bytes"
    10  	"context"
    11  	"errors"
    12  	"fmt"
    13  	"io"
    14  	"net/http"
    15  
    16  	"github.com/gitbundle/modules/json"
    17  	"github.com/gitbundle/modules/log"
    18  )
    19  
    20  // TransferAdapter represents an adapter for downloading/uploading LFS objects
    21  type TransferAdapter interface {
    22  	Name() string
    23  	Download(ctx context.Context, l *Link) (io.ReadCloser, error)
    24  	Upload(ctx context.Context, l *Link, p Pointer, r io.Reader) error
    25  	Verify(ctx context.Context, l *Link, p Pointer) error
    26  }
    27  
    28  // BasicTransferAdapter implements the "basic" adapter
    29  type BasicTransferAdapter struct {
    30  	client *http.Client
    31  }
    32  
    33  // Name returns the name of the adapter
    34  func (a *BasicTransferAdapter) Name() string {
    35  	return "basic"
    36  }
    37  
    38  // Download reads the download location and downloads the data
    39  func (a *BasicTransferAdapter) Download(ctx context.Context, l *Link) (io.ReadCloser, error) {
    40  	resp, err := a.performRequest(ctx, "GET", l, nil, nil)
    41  	if err != nil {
    42  		return nil, err
    43  	}
    44  	return resp.Body, nil
    45  }
    46  
    47  // Upload sends the content to the LFS server
    48  func (a *BasicTransferAdapter) Upload(ctx context.Context, l *Link, p Pointer, r io.Reader) error {
    49  	_, err := a.performRequest(ctx, "PUT", l, r, func(req *http.Request) {
    50  		if len(req.Header.Get("Content-Type")) == 0 {
    51  			req.Header.Set("Content-Type", "application/octet-stream")
    52  		}
    53  
    54  		if req.Header.Get("Transfer-Encoding") == "chunked" {
    55  			req.TransferEncoding = []string{"chunked"}
    56  		}
    57  
    58  		req.ContentLength = p.Size
    59  	})
    60  	if err != nil {
    61  		return err
    62  	}
    63  	return nil
    64  }
    65  
    66  // Verify calls the verify handler on the LFS server
    67  func (a *BasicTransferAdapter) Verify(ctx context.Context, l *Link, p Pointer) error {
    68  	b, err := json.Marshal(p)
    69  	if err != nil {
    70  		log.Error("Error encoding json: %v", err)
    71  		return err
    72  	}
    73  
    74  	_, err = a.performRequest(ctx, "POST", l, bytes.NewReader(b), func(req *http.Request) {
    75  		req.Header.Set("Content-Type", MediaType)
    76  	})
    77  	if err != nil {
    78  		return err
    79  	}
    80  	return nil
    81  }
    82  
    83  func (a *BasicTransferAdapter) performRequest(ctx context.Context, method string, l *Link, body io.Reader, callback func(*http.Request)) (*http.Response, error) {
    84  	log.Trace("Calling: %s %s", method, l.Href)
    85  
    86  	req, err := http.NewRequestWithContext(ctx, method, l.Href, body)
    87  	if err != nil {
    88  		log.Error("Error creating request: %v", err)
    89  		return nil, err
    90  	}
    91  	for key, value := range l.Header {
    92  		req.Header.Set(key, value)
    93  	}
    94  	req.Header.Set("Accept", MediaType)
    95  
    96  	if callback != nil {
    97  		callback(req)
    98  	}
    99  
   100  	res, err := a.client.Do(req)
   101  	if err != nil {
   102  		select {
   103  		case <-ctx.Done():
   104  			return res, ctx.Err()
   105  		default:
   106  		}
   107  		log.Error("Error while processing request: %v", err)
   108  		return res, err
   109  	}
   110  
   111  	if res.StatusCode != http.StatusOK {
   112  		return res, handleErrorResponse(res)
   113  	}
   114  
   115  	return res, nil
   116  }
   117  
   118  func handleErrorResponse(resp *http.Response) error {
   119  	defer resp.Body.Close()
   120  
   121  	er, err := decodeResponseError(resp.Body)
   122  	if err != nil {
   123  		return fmt.Errorf("Request failed with status %s", resp.Status)
   124  	}
   125  	log.Trace("ErrorRespone: %v", er)
   126  	return errors.New(er.Message)
   127  }
   128  
   129  func decodeResponseError(r io.Reader) (ErrorResponse, error) {
   130  	var er ErrorResponse
   131  	err := json.NewDecoder(r).Decode(&er)
   132  	if err != nil {
   133  		log.Error("Error decoding json: %v", err)
   134  	}
   135  	return er, err
   136  }