github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/client/allocrunner/taskrunner/artifact_hook.go (about)

     1  package taskrunner
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sync"
     7  
     8  	log "github.com/hashicorp/go-hclog"
     9  	"github.com/hashicorp/nomad/client/allocrunner/interfaces"
    10  	ti "github.com/hashicorp/nomad/client/allocrunner/taskrunner/interfaces"
    11  	ci "github.com/hashicorp/nomad/client/interfaces"
    12  	"github.com/hashicorp/nomad/nomad/structs"
    13  )
    14  
    15  // artifactHook downloads artifacts for a task.
    16  type artifactHook struct {
    17  	eventEmitter ti.EventEmitter
    18  	logger       log.Logger
    19  	getter       ci.ArtifactGetter
    20  }
    21  
    22  func newArtifactHook(e ti.EventEmitter, getter ci.ArtifactGetter, logger log.Logger) *artifactHook {
    23  	h := &artifactHook{
    24  		eventEmitter: e,
    25  		getter:       getter,
    26  	}
    27  	h.logger = logger.Named(h.Name())
    28  	return h
    29  }
    30  
    31  func (h *artifactHook) doWork(req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse, jobs chan *structs.TaskArtifact, errorChannel chan error, wg *sync.WaitGroup, responseStateMutex *sync.Mutex) {
    32  	defer wg.Done()
    33  	for artifact := range jobs {
    34  		aid := artifact.Hash()
    35  		if req.PreviousState[aid] != "" {
    36  			h.logger.Trace("skipping already downloaded artifact", "artifact", artifact.GetterSource)
    37  			responseStateMutex.Lock()
    38  			resp.State[aid] = req.PreviousState[aid]
    39  			responseStateMutex.Unlock()
    40  			continue
    41  		}
    42  
    43  		h.logger.Debug("downloading artifact", "artifact", artifact.GetterSource, "aid", aid)
    44  
    45  		if err := h.getter.Get(req.TaskEnv, artifact); err != nil {
    46  			wrapped := structs.NewRecoverableError(
    47  				fmt.Errorf("failed to download artifact %q: %v", artifact.GetterSource, err),
    48  				true,
    49  			)
    50  			herr := NewHookError(wrapped, structs.NewTaskEvent(structs.TaskArtifactDownloadFailed).SetDownloadError(wrapped))
    51  
    52  			errorChannel <- herr
    53  			continue
    54  		}
    55  
    56  		// Mark artifact as downloaded to avoid re-downloading due to
    57  		// retries caused by subsequent artifacts failing. Any
    58  		// non-empty value works.
    59  		responseStateMutex.Lock()
    60  		resp.State[aid] = "1"
    61  		responseStateMutex.Unlock()
    62  	}
    63  }
    64  
    65  func (*artifactHook) Name() string {
    66  	// Copied in client/state when upgrading from <0.9 schemas, so if you
    67  	// change it here you also must change it there.
    68  	return "artifacts"
    69  }
    70  
    71  func (h *artifactHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse) error {
    72  	if len(req.Task.Artifacts) == 0 {
    73  		resp.Done = true
    74  		return nil
    75  	}
    76  
    77  	// Initialize hook state to store download progress
    78  	resp.State = make(map[string]string, len(req.Task.Artifacts))
    79  
    80  	// responseStateMutex is a lock used to guard against concurrent writes to the above resp.State map
    81  	responseStateMutex := &sync.Mutex{}
    82  
    83  	h.eventEmitter.EmitEvent(structs.NewTaskEvent(structs.TaskDownloadingArtifacts))
    84  
    85  	// maxConcurrency denotes the number of workers that will download artifacts in parallel
    86  	maxConcurrency := 3
    87  
    88  	// jobsChannel is a buffered channel which will have all the artifacts that needs to be processed
    89  	jobsChannel := make(chan *structs.TaskArtifact, maxConcurrency)
    90  
    91  	// errorChannel is also a buffered channel that will be used to signal errors
    92  	errorChannel := make(chan error, maxConcurrency)
    93  
    94  	// create workers and process artifacts
    95  	go func() {
    96  		defer close(errorChannel)
    97  		var wg sync.WaitGroup
    98  		for i := 0; i < maxConcurrency; i++ {
    99  			wg.Add(1)
   100  			go h.doWork(req, resp, jobsChannel, errorChannel, &wg, responseStateMutex)
   101  		}
   102  		wg.Wait()
   103  	}()
   104  
   105  	// Push all artifact requests to job channel
   106  	go func() {
   107  		defer close(jobsChannel)
   108  		for _, artifact := range req.Task.Artifacts {
   109  			jobsChannel <- artifact
   110  		}
   111  	}()
   112  
   113  	// Iterate over the errorChannel and if there is an error, store it to a variable for future return
   114  	var err error
   115  	for e := range errorChannel {
   116  		err = e
   117  	}
   118  
   119  	// once error channel is closed, we can check and return the error
   120  	if err != nil {
   121  		return err
   122  	}
   123  
   124  	resp.Done = true
   125  	return nil
   126  }