code.gitea.io/gitea@v1.19.3/modules/lfs/transferadapter.go (about)

     1  // Copyright 2021 The Gitea Authors. All rights reserved.
     2  // SPDX-License-Identifier: MIT
     3  
     4  package lfs
     5  
     6  import (
     7  	"bytes"
     8  	"context"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"net/http"
    13  
    14  	"code.gitea.io/gitea/modules/json"
    15  	"code.gitea.io/gitea/modules/log"
    16  )
    17  
    18  // TransferAdapter represents an adapter for downloading/uploading LFS objects
    19  type TransferAdapter interface {
    20  	Name() string
    21  	Download(ctx context.Context, l *Link) (io.ReadCloser, error)
    22  	Upload(ctx context.Context, l *Link, p Pointer, r io.Reader) error
    23  	Verify(ctx context.Context, l *Link, p Pointer) error
    24  }
    25  
    26  // BasicTransferAdapter implements the "basic" adapter
    27  type BasicTransferAdapter struct {
    28  	client *http.Client
    29  }
    30  
    31  // Name returns the name of the adapter
    32  func (a *BasicTransferAdapter) Name() string {
    33  	return "basic"
    34  }
    35  
    36  // Download reads the download location and downloads the data
    37  func (a *BasicTransferAdapter) Download(ctx context.Context, l *Link) (io.ReadCloser, error) {
    38  	resp, err := a.performRequest(ctx, "GET", l, nil, nil)
    39  	if err != nil {
    40  		return nil, err
    41  	}
    42  	return resp.Body, nil
    43  }
    44  
    45  // Upload sends the content to the LFS server
    46  func (a *BasicTransferAdapter) Upload(ctx context.Context, l *Link, p Pointer, r io.Reader) error {
    47  	_, err := a.performRequest(ctx, "PUT", l, r, func(req *http.Request) {
    48  		if len(req.Header.Get("Content-Type")) == 0 {
    49  			req.Header.Set("Content-Type", "application/octet-stream")
    50  		}
    51  
    52  		if req.Header.Get("Transfer-Encoding") == "chunked" {
    53  			req.TransferEncoding = []string{"chunked"}
    54  		}
    55  
    56  		req.ContentLength = p.Size
    57  	})
    58  	if err != nil {
    59  		return err
    60  	}
    61  	return nil
    62  }
    63  
    64  // Verify calls the verify handler on the LFS server
    65  func (a *BasicTransferAdapter) Verify(ctx context.Context, l *Link, p Pointer) error {
    66  	b, err := json.Marshal(p)
    67  	if err != nil {
    68  		log.Error("Error encoding json: %v", err)
    69  		return err
    70  	}
    71  
    72  	_, err = a.performRequest(ctx, "POST", l, bytes.NewReader(b), func(req *http.Request) {
    73  		req.Header.Set("Content-Type", MediaType)
    74  	})
    75  	if err != nil {
    76  		return err
    77  	}
    78  	return nil
    79  }
    80  
    81  func (a *BasicTransferAdapter) performRequest(ctx context.Context, method string, l *Link, body io.Reader, callback func(*http.Request)) (*http.Response, error) {
    82  	log.Trace("Calling: %s %s", method, l.Href)
    83  
    84  	req, err := http.NewRequestWithContext(ctx, method, l.Href, body)
    85  	if err != nil {
    86  		log.Error("Error creating request: %v", err)
    87  		return nil, err
    88  	}
    89  	for key, value := range l.Header {
    90  		req.Header.Set(key, value)
    91  	}
    92  	req.Header.Set("Accept", MediaType)
    93  
    94  	if callback != nil {
    95  		callback(req)
    96  	}
    97  
    98  	res, err := a.client.Do(req)
    99  	if err != nil {
   100  		select {
   101  		case <-ctx.Done():
   102  			return res, ctx.Err()
   103  		default:
   104  		}
   105  		log.Error("Error while processing request: %v", err)
   106  		return res, err
   107  	}
   108  
   109  	if res.StatusCode != http.StatusOK {
   110  		return res, handleErrorResponse(res)
   111  	}
   112  
   113  	return res, nil
   114  }
   115  
   116  func handleErrorResponse(resp *http.Response) error {
   117  	defer resp.Body.Close()
   118  
   119  	er, err := decodeResponseError(resp.Body)
   120  	if err != nil {
   121  		return fmt.Errorf("Request failed with status %s", resp.Status)
   122  	}
   123  	log.Trace("ErrorRespone: %v", er)
   124  	return errors.New(er.Message)
   125  }
   126  
   127  func decodeResponseError(r io.Reader) (ErrorResponse, error) {
   128  	var er ErrorResponse
   129  	err := json.NewDecoder(r).Decode(&er)
   130  	if err != nil {
   131  		log.Error("Error decoding json: %v", err)
   132  	}
   133  	return er, err
   134  }