github.com/iqoqo/nomad@v0.11.3-0.20200911112621-d7021c74d101/client/allocrunner/taskrunner/validate_hook.go (about) 1 package taskrunner 2 3 import ( 4 "context" 5 "fmt" 6 7 log "github.com/hashicorp/go-hclog" 8 multierror "github.com/hashicorp/go-multierror" 9 "github.com/hashicorp/nomad/client/allocrunner/interfaces" 10 "github.com/hashicorp/nomad/client/config" 11 "github.com/hashicorp/nomad/client/taskenv" 12 "github.com/hashicorp/nomad/nomad/structs" 13 ) 14 15 // validateHook validates the task is able to be run. 16 type validateHook struct { 17 config *config.Config 18 logger log.Logger 19 } 20 21 func newValidateHook(config *config.Config, logger log.Logger) *validateHook { 22 h := &validateHook{ 23 config: config, 24 } 25 h.logger = logger.Named(h.Name()) 26 return h 27 } 28 29 func (*validateHook) Name() string { 30 return "validate" 31 } 32 33 func (h *validateHook) Prestart(ctx context.Context, req *interfaces.TaskPrestartRequest, resp *interfaces.TaskPrestartResponse) error { 34 if err := validateTask(req.Task, req.TaskEnv, h.config); err != nil { 35 return err 36 } 37 38 resp.Done = true 39 return nil 40 } 41 42 func validateTask(task *structs.Task, taskEnv *taskenv.TaskEnv, conf *config.Config) error { 43 var mErr multierror.Error 44 45 // Validate the user 46 unallowedUsers := conf.ReadStringListToMapDefault("user.blacklist", config.DefaultUserBlacklist) 47 checkDrivers := conf.ReadStringListToMapDefault("user.checked_drivers", config.DefaultUserCheckedDrivers) 48 if _, driverMatch := checkDrivers[task.Driver]; driverMatch { 49 if _, unallowed := unallowedUsers[task.User]; unallowed { 50 mErr.Errors = append(mErr.Errors, fmt.Errorf("running as user %q is disallowed", task.User)) 51 } 52 } 53 54 // Validate the Service names once they're interpolated 55 for i, service := range task.Services { 56 name := taskEnv.ReplaceEnv(service.Name) 57 if err := service.ValidateName(name); err != nil { 58 mErr.Errors = append(mErr.Errors, fmt.Errorf("service (%d) failed validation: %v", i, err)) 59 } 60 } 61 62 if len(mErr.Errors) == 1 { 63 return mErr.Errors[0] 64 } 65 return mErr.ErrorOrNil() 66 }