github.com/filecoin-project/bacalhau@v0.3.23-0.20230228154132-45c989550ace/pkg/storage/url/urldownload/storage.go (about)

     1  package urldownload
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"net/http"
     8  	"net/url"
     9  	"os"
    10  	"path"
    11  	"path/filepath"
    12  	"strings"
    13  	"time"
    14  
    15  	"github.com/filecoin-project/bacalhau/pkg/config"
    16  	"github.com/filecoin-project/bacalhau/pkg/model"
    17  	"github.com/filecoin-project/bacalhau/pkg/storage"
    18  	"github.com/filecoin-project/bacalhau/pkg/system"
    19  	"github.com/filecoin-project/bacalhau/pkg/util/closer"
    20  	"github.com/google/uuid"
    21  	"github.com/hashicorp/go-retryablehttp"
    22  	"github.com/rs/zerolog"
    23  	"github.com/rs/zerolog/log"
    24  	"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
    25  	semconv "go.opentelemetry.io/otel/semconv/v1.17.0"
    26  	"go.opentelemetry.io/otel/trace"
    27  )
    28  
    29  // a storage driver runs the downloads content
    30  // from a public URL source and copies it to
    31  // a local directory in preparation for
    32  // a job to run - it will remove the folder/file once complete
    33  
    34  type StorageProvider struct {
    35  	localDir string
    36  	client   *retryablehttp.Client
    37  }
    38  
    39  func NewStorage(cm *system.CleanupManager) (*StorageProvider, error) {
    40  	// TODO: consolidate the various config inputs into one package otherwise they are scattered across the codebase
    41  	dir, err := os.MkdirTemp(config.GetStoragePath(), "bacalhau-url")
    42  	if err != nil {
    43  		return nil, err
    44  	}
    45  
    46  	cm.RegisterCallback(func() error {
    47  		if err := os.RemoveAll(dir); err != nil {
    48  			return fmt.Errorf("unable to remove storage folder: %w", err)
    49  		}
    50  		return nil
    51  	})
    52  
    53  	log.Debug().Str("dir", dir).Msg("URL download driver created with output dir")
    54  
    55  	return newStorage(dir), nil
    56  }
    57  
    58  func newStorage(dir string) *StorageProvider {
    59  	client := retryablehttp.NewClient()
    60  	client.HTTPClient = &http.Client{
    61  		Timeout: config.GetDownloadURLRequestTimeout(),
    62  		Transport: otelhttp.NewTransport(nil, otelhttp.WithSpanNameFormatter(func(operation string, r *http.Request) string {
    63  			return fmt.Sprintf("%s %s", r.Method, r.URL.Path)
    64  		}), otelhttp.WithSpanOptions(trace.WithAttributes(semconv.PeerService("url-download")))),
    65  	}
    66  	client.RetryMax = config.GetDownloadURLRequestRetries()
    67  	client.RetryWaitMax = time.Second * 1
    68  	client.Logger = retryLogger{}
    69  	client.CheckRetry = func(ctx context.Context, resp *http.Response, err error) (bool, error) {
    70  		if err := ctx.Err(); err != nil { //nolint:govet
    71  			return false, err
    72  		}
    73  		if err == nil {
    74  			// Existing behavior around retrying is to retry on _all_ non 2xx status codes. This includes codes that would have no
    75  			// realistic hope of succeeding like `Unauthorized` or `Gone`
    76  			if resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusBadRequest {
    77  				return false, nil
    78  			}
    79  			return true, nil
    80  		}
    81  
    82  		return retryablehttp.DefaultRetryPolicy(ctx, resp, err)
    83  	}
    84  
    85  	return &StorageProvider{
    86  		localDir: dir,
    87  		client:   client,
    88  	}
    89  }
    90  
    91  func (sp *StorageProvider) IsInstalled(context.Context) (bool, error) {
    92  	return true, nil
    93  }
    94  
    95  func (sp *StorageProvider) HasStorageLocally(context.Context, model.StorageSpec) (bool, error) {
    96  	return false, nil
    97  }
    98  
    99  func (sp *StorageProvider) GetVolumeSize(context.Context, model.StorageSpec) (uint64, error) {
   100  	// Could do a HEAD request and check Content-Length, but in some cases that's not guaranteed to be the real end file size
   101  	return 0, nil
   102  }
   103  
   104  // PrepareStorage will download the file from the URL
   105  func (sp *StorageProvider) PrepareStorage(ctx context.Context, storageSpec model.StorageSpec) (storage.StorageVolume, error) {
   106  	u, err := IsURLSupported(storageSpec.URL)
   107  	if err != nil {
   108  		return storage.StorageVolume{}, err
   109  	}
   110  
   111  	outputPath, err := os.MkdirTemp(sp.localDir, "*")
   112  	if err != nil {
   113  		return storage.StorageVolume{}, err
   114  	}
   115  
   116  	req, err := retryablehttp.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
   117  	if err != nil {
   118  		return storage.StorageVolume{}, err
   119  	}
   120  	res, err := sp.client.Do(req) //nolint:bodyclose // this is being closed - golangci-lint is wrong again
   121  	if err != nil {
   122  		return storage.StorageVolume{}, fmt.Errorf("failed to begin download from url %s: %w", u, err)
   123  	}
   124  	defer closer.DrainAndCloseWithLogOnError(ctx, "response", res.Body)
   125  
   126  	if res.StatusCode < http.StatusOK || res.StatusCode >= http.StatusMultipleChoices {
   127  		return storage.StorageVolume{}, fmt.Errorf("non-200 response from URL (%s): %s", storageSpec.URL, res.Status)
   128  	}
   129  
   130  	baseName := path.Base(res.Request.URL.Path)
   131  	var fileName string
   132  	if baseName == "." || baseName == "/" {
   133  		// There is no filename in the URL, so we need to a temp one
   134  		fileName = uuid.UUID.String(uuid.New())
   135  	} else {
   136  		fileName = baseName
   137  	}
   138  
   139  	filePath := filepath.Join(outputPath, fileName)
   140  	w, err := os.Create(filePath)
   141  	if err != nil {
   142  		return storage.StorageVolume{}, fmt.Errorf("failed to create file %s: %s", filePath, err)
   143  	}
   144  
   145  	defer closer.CloseWithLogOnError("file", w)
   146  
   147  	// stream the body to the client without fully loading it into memory
   148  	if _, err := io.Copy(w, res.Body); err != nil {
   149  		return storage.StorageVolume{}, fmt.Errorf("failed to write to file %s: %s", filePath, err)
   150  	}
   151  
   152  	if err := w.Sync(); err != nil {
   153  		return storage.StorageVolume{}, fmt.Errorf("failed to sync file %s: %w", filePath, err)
   154  	}
   155  
   156  	targetPath := filepath.Join(storageSpec.Path, fileName)
   157  
   158  	log.Ctx(ctx).Debug().
   159  		Stringer("url", u).
   160  		Stringer("final-url", res.Request.URL).
   161  		Str("file", filePath).
   162  		Str("targetFile", targetPath).
   163  		Msg("Downloaded file")
   164  
   165  	volume := storage.StorageVolume{
   166  		Type:   storage.StorageVolumeConnectorBind,
   167  		Source: filePath,   // The source is the full path to the file
   168  		Target: targetPath, // So we should alter the target to include the file name
   169  	}
   170  
   171  	return volume, nil
   172  }
   173  
   174  func (sp *StorageProvider) CleanupStorage(
   175  	ctx context.Context,
   176  	_ model.StorageSpec,
   177  	volume storage.StorageVolume,
   178  ) error {
   179  	pathToCleanup := filepath.Dir(volume.Source)
   180  	log.Ctx(ctx).Debug().Str("Path", pathToCleanup).Msg("Cleaning up")
   181  	return os.RemoveAll(pathToCleanup)
   182  }
   183  
   184  func (sp *StorageProvider) Upload(context.Context, string) (model.StorageSpec, error) {
   185  	// we don't "upload" anything to a URL
   186  	return model.StorageSpec{}, fmt.Errorf("not implemented")
   187  }
   188  
   189  func (sp *StorageProvider) Explode(_ context.Context, spec model.StorageSpec) ([]model.StorageSpec, error) {
   190  	// for the url download - explode will always result in a single item
   191  	// mounted at the path specified in the spec
   192  	return []model.StorageSpec{
   193  		{
   194  			Name:          spec.Name,
   195  			StorageSource: model.StorageSourceURLDownload,
   196  			Path:          spec.Path,
   197  			URL:           spec.URL,
   198  		},
   199  	}, nil
   200  }
   201  
   202  func IsURLSupported(rawURL string) (*url.URL, error) {
   203  	rawURL = strings.Trim(rawURL, " '\"")
   204  	u, err := url.Parse(rawURL)
   205  	if err != nil {
   206  		return nil, fmt.Errorf("invalid URL: %s", err)
   207  	}
   208  	if (u.Scheme != "http") && (u.Scheme != "https") {
   209  		return nil, fmt.Errorf("URLs must begin with 'http' or 'https'. The submitted one began with %s", u.Scheme)
   210  	}
   211  
   212  	basePath := path.Base(u.Path)
   213  
   214  	// Need to check for both because a bare host
   215  	// Like http://localhost/ gets converted to "." by path.Base
   216  	if basePath == "" || u.Path == "" {
   217  		return nil, fmt.Errorf("URL must end with a file name")
   218  	}
   219  
   220  	return u, nil
   221  }
   222  
   223  var _ storage.Storage = (*StorageProvider)(nil)
   224  
   225  var _ retryablehttp.LeveledLogger = retryLogger{}
   226  
   227  // This logger needs to change to fetch the logger from the context once
   228  // https://github.com/hashicorp/go-retryablehttp/issues/182 is implemented and released.
   229  type retryLogger struct {
   230  }
   231  
   232  func (r retryLogger) Error(msg string, keysAndValues ...interface{}) {
   233  	parseKeysAndValues(log.Error(), keysAndValues...).Msg(msg)
   234  }
   235  
   236  func (r retryLogger) Info(msg string, keysAndValues ...interface{}) {
   237  	parseKeysAndValues(log.Info(), keysAndValues...).Msg(msg)
   238  }
   239  
   240  func (r retryLogger) Debug(msg string, keysAndValues ...interface{}) {
   241  	parseKeysAndValues(log.Debug(), keysAndValues...).Msg(msg)
   242  }
   243  
   244  func (r retryLogger) Warn(msg string, keysAndValues ...interface{}) {
   245  	parseKeysAndValues(log.Warn(), keysAndValues...).Msg(msg)
   246  }
   247  
   248  func parseKeysAndValues(e *zerolog.Event, keysAndValues ...interface{}) *zerolog.Event {
   249  	for i := 0; i < len(keysAndValues); i = i + 2 {
   250  		name := keysAndValues[i].(string)
   251  		value := keysAndValues[i+1]
   252  		if v, ok := value.(string); ok {
   253  			e = e.Str(name, v)
   254  		} else if v, ok := value.(error); ok {
   255  			e = e.AnErr(name, v)
   256  		} else if v, ok := value.(fmt.Stringer); ok {
   257  			e = e.Stringer(name, v)
   258  		} else if v, ok := value.(int); ok {
   259  			e = e.Int(name, v)
   260  		} else {
   261  			e = e.Interface(name, value)
   262  		}
   263  	}
   264  	return e
   265  }