github.com/pf-qiu/concourse/v6@v6.7.3-0.20201207032516-1f455d73275f/atc/db/worker_factory.go (about)

     1  package db
     2  
     3  import (
     4  	"database/sql"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"time"
     9  
    10  	sq "github.com/Masterminds/squirrel"
    11  	"github.com/pf-qiu/concourse/v6/atc"
    12  	"github.com/lib/pq"
    13  )
    14  
    15  //go:generate counterfeiter . WorkerFactory
    16  
    17  type WorkerFactory interface {
    18  	GetWorker(name string) (Worker, bool, error)
    19  	SaveWorker(atcWorker atc.Worker, ttl time.Duration) (Worker, error)
    20  	HeartbeatWorker(worker atc.Worker, ttl time.Duration) (Worker, error)
    21  	Workers() ([]Worker, error)
    22  	VisibleWorkers([]string) ([]Worker, error)
    23  
    24  	FindWorkersForContainerByOwner(ContainerOwner) ([]Worker, error)
    25  	BuildContainersCountPerWorker() (map[string]int, error)
    26  }
    27  
    28  type workerFactory struct {
    29  	conn Conn
    30  }
    31  
    32  func NewWorkerFactory(conn Conn) WorkerFactory {
    33  	return &workerFactory{
    34  		conn: conn,
    35  	}
    36  }
    37  
// workersQuery is the base SELECT used to load workers. It is never run on
// its own: callers narrow it with Where (and FindWorkersForContainerByOwner
// adds a Join) before executing it.
//
// IMPORTANT: scanWorker reads these columns positionally. The column order
// here must match the order of its Scan destinations exactly; adding,
// removing or reordering a column requires changing both places.
//
// The LEFT JOIN keeps workers that have no owning team (w.team_id NULL). For
// those rows t.name is NULL.
var workersQuery = psql.Select(`
		w.name,
		w.version,
		w.addr,
		w.state,
		w.baggageclaim_url,
		w.certs_path,
		w.http_proxy_url,
		w.https_proxy_url,
		w.no_proxy,
		w.active_containers,
		w.active_volumes,
		w.resource_types,
		w.platform,
		w.tags,
		t.name,
		w.team_id,
		w.start_time,
		w.expires,
		w.ephemeral
	`).
	From("workers w").
	LeftJoin("teams t ON w.team_id = t.id")
    61  
    62  func (f *workerFactory) GetWorker(name string) (Worker, bool, error) {
    63  	return getWorker(f.conn, workersQuery.Where(sq.Eq{"w.name": name}))
    64  }
    65  
    66  func (f *workerFactory) VisibleWorkers(teamNames []string) ([]Worker, error) {
    67  	workersQuery := workersQuery.
    68  		Where(sq.Or{
    69  			sq.Eq{"t.name": teamNames},
    70  			sq.Eq{"w.team_id": nil},
    71  		})
    72  
    73  	workers, err := getWorkers(f.conn, workersQuery)
    74  	if err != nil {
    75  		return nil, err
    76  	}
    77  
    78  	return workers, nil
    79  }
    80  
    81  func (f *workerFactory) Workers() ([]Worker, error) {
    82  	return getWorkers(f.conn, workersQuery)
    83  }
    84  
    85  func getWorker(conn Conn, query sq.SelectBuilder) (Worker, bool, error) {
    86  	row := query.
    87  		RunWith(conn).
    88  		QueryRow()
    89  
    90  	w := &worker{conn: conn}
    91  
    92  	err := scanWorker(w, row)
    93  	if err != nil {
    94  		if err == sql.ErrNoRows {
    95  			return nil, false, nil
    96  		}
    97  		return nil, false, err
    98  	}
    99  
   100  	return w, true, nil
   101  }
   102  
   103  func getWorkers(conn Conn, query sq.SelectBuilder) ([]Worker, error) {
   104  	rows, err := query.RunWith(conn).Query()
   105  	if err != nil {
   106  		return nil, err
   107  	}
   108  	defer Close(rows)
   109  
   110  	workers := []Worker{}
   111  
   112  	for rows.Next() {
   113  		worker := &worker{conn: conn}
   114  		err := scanWorker(worker, rows)
   115  		if err != nil {
   116  			return nil, err
   117  		}
   118  
   119  		workers = append(workers, worker)
   120  	}
   121  
   122  	return workers, nil
   123  }
   124  
   125  func scanWorker(worker *worker, row scannable) error {
   126  	var (
   127  		version       sql.NullString
   128  		addStr        sql.NullString
   129  		state         string
   130  		bcURLStr      sql.NullString
   131  		certsPathStr  sql.NullString
   132  		httpProxyURL  sql.NullString
   133  		httpsProxyURL sql.NullString
   134  		noProxy       sql.NullString
   135  		resourceTypes []byte
   136  		platform      sql.NullString
   137  		tags          []byte
   138  		teamName      sql.NullString
   139  		teamID        sql.NullInt64
   140  		startTime     pq.NullTime
   141  		expiresAt     pq.NullTime
   142  		ephemeral     sql.NullBool
   143  	)
   144  
   145  	err := row.Scan(
   146  		&worker.name,
   147  		&version,
   148  		&addStr,
   149  		&state,
   150  		&bcURLStr,
   151  		&certsPathStr,
   152  		&httpProxyURL,
   153  		&httpsProxyURL,
   154  		&noProxy,
   155  		&worker.activeContainers,
   156  		&worker.activeVolumes,
   157  		&resourceTypes,
   158  		&platform,
   159  		&tags,
   160  		&teamName,
   161  		&teamID,
   162  		&startTime,
   163  		&expiresAt,
   164  		&ephemeral,
   165  	)
   166  	if err != nil {
   167  		return err
   168  	}
   169  
   170  	if version.Valid {
   171  		worker.version = &version.String
   172  	}
   173  
   174  	if addStr.Valid {
   175  		worker.gardenAddr = &addStr.String
   176  	}
   177  
   178  	if bcURLStr.Valid {
   179  		worker.baggageclaimURL = &bcURLStr.String
   180  	}
   181  
   182  	if certsPathStr.Valid {
   183  		worker.certsPath = &certsPathStr.String
   184  	}
   185  
   186  	worker.state = WorkerState(state)
   187  	worker.startTime = startTime.Time
   188  	worker.expiresAt = expiresAt.Time
   189  
   190  	if httpProxyURL.Valid {
   191  		worker.httpProxyURL = httpProxyURL.String
   192  	}
   193  
   194  	if httpsProxyURL.Valid {
   195  		worker.httpsProxyURL = httpsProxyURL.String
   196  	}
   197  
   198  	if noProxy.Valid {
   199  		worker.noProxy = noProxy.String
   200  	}
   201  
   202  	if teamName.Valid {
   203  		worker.teamName = teamName.String
   204  	}
   205  
   206  	if teamID.Valid {
   207  		worker.teamID = int(teamID.Int64)
   208  	}
   209  
   210  	if platform.Valid {
   211  		worker.platform = platform.String
   212  	}
   213  
   214  	if ephemeral.Valid {
   215  		worker.ephemeral = ephemeral.Bool
   216  	}
   217  
   218  	err = json.Unmarshal(resourceTypes, &worker.resourceTypes)
   219  	if err != nil {
   220  		return err
   221  	}
   222  
   223  	return json.Unmarshal(tags, &worker.tags)
   224  }
   225  
   226  func (f *workerFactory) HeartbeatWorker(atcWorker atc.Worker, ttl time.Duration) (Worker, error) {
   227  	// In order to be able to calculate the ttl that we return to the caller
   228  	// we must compare time.Now() to the worker.expires column
   229  	// However, workers.expires column is a "timestamp (without timezone)"
   230  	// So we format time.Now() without any timezone information and then
   231  	// parse that using the same layout to strip the timezone information
   232  
   233  	tx, err := f.conn.Begin()
   234  	if err != nil {
   235  		return nil, err
   236  	}
   237  	defer Rollback(tx)
   238  
   239  	expires := "NULL"
   240  	if ttl != 0 {
   241  		expires = fmt.Sprintf(`NOW() + '%d second'::INTERVAL`, int(ttl.Seconds()))
   242  	}
   243  
   244  	cSQL, _, err := sq.Case("state").
   245  		When("'landing'::worker_state", "'landing'::worker_state").
   246  		When("'landed'::worker_state", "'landed'::worker_state").
   247  		When("'retiring'::worker_state", "'retiring'::worker_state").
   248  		Else("'running'::worker_state").
   249  		ToSql()
   250  
   251  	if err != nil {
   252  		return nil, err
   253  	}
   254  
   255  	_, err = psql.Update("workers").
   256  		Set("expires", sq.Expr(expires)).
   257  		Set("active_containers", atcWorker.ActiveContainers).
   258  		Set("active_volumes", atcWorker.ActiveVolumes).
   259  		Set("state", sq.Expr("("+cSQL+")")).
   260  		Where(sq.Eq{"name": atcWorker.Name}).
   261  		RunWith(tx).
   262  		Exec()
   263  	if err != nil {
   264  		if err == sql.ErrNoRows {
   265  			return nil, ErrWorkerNotPresent
   266  		}
   267  		return nil, err
   268  	}
   269  
   270  	row := workersQuery.Where(sq.Eq{"w.name": atcWorker.Name}).
   271  		RunWith(tx).
   272  		QueryRow()
   273  
   274  	worker := &worker{conn: f.conn}
   275  	err = scanWorker(worker, row)
   276  	if err != nil {
   277  		if err == sql.ErrNoRows {
   278  			return nil, ErrWorkerNotPresent
   279  		}
   280  		return nil, err
   281  	}
   282  
   283  	err = tx.Commit()
   284  	if err != nil {
   285  		return nil, err
   286  	}
   287  	return worker, nil
   288  
   289  }
   290  
   291  func (f *workerFactory) SaveWorker(atcWorker atc.Worker, ttl time.Duration) (Worker, error) {
   292  	tx, err := f.conn.Begin()
   293  	if err != nil {
   294  		return nil, err
   295  	}
   296  
   297  	defer Rollback(tx)
   298  
   299  	savedWorker, err := saveWorker(tx, atcWorker, nil, ttl, f.conn)
   300  	if err != nil {
   301  		return nil, err
   302  	}
   303  
   304  	err = tx.Commit()
   305  	if err != nil {
   306  		return nil, err
   307  	}
   308  
   309  	return savedWorker, nil
   310  }
   311  
   312  func (f *workerFactory) FindWorkersForContainerByOwner(owner ContainerOwner) ([]Worker, error) {
   313  	ownerQuery, found, err := owner.Find(f.conn)
   314  	if err != nil {
   315  		return nil, err
   316  	}
   317  
   318  	if !found {
   319  		return []Worker{}, nil
   320  	}
   321  
   322  	ownerEq := sq.Eq{}
   323  	for k, v := range ownerQuery {
   324  		ownerEq["c."+k] = v
   325  	}
   326  
   327  	workers, err := getWorkers(f.conn, workersQuery.Join("containers c ON c.worker_name = w.name").Where(sq.And{
   328  		ownerEq,
   329  	}))
   330  	if err != nil {
   331  		return nil, err
   332  	}
   333  
   334  	return workers, nil
   335  }
   336  
   337  func (f *workerFactory) BuildContainersCountPerWorker() (map[string]int, error) {
   338  	rows, err := psql.Select("worker_name, COUNT(*)").
   339  		From("containers").
   340  		Where("build_id IS NOT NULL").
   341  		GroupBy("worker_name").
   342  		RunWith(f.conn).
   343  		Query()
   344  	if err != nil {
   345  		return nil, err
   346  	}
   347  
   348  	defer Close(rows)
   349  
   350  	countByWorker := make(map[string]int)
   351  
   352  	for rows.Next() {
   353  		var workerName string
   354  		var containersCount int
   355  
   356  		err = rows.Scan(&workerName, &containersCount)
   357  		if err != nil {
   358  			return nil, err
   359  		}
   360  
   361  		countByWorker[workerName] = containersCount
   362  	}
   363  
   364  	return countByWorker, nil
   365  }
   366  
// saveWorker upserts atcWorker into the workers table within tx, then
// synchronises the worker's base resource types and, if present, its certs
// path. It does not commit tx.
//
// teamID selects the owner: nil registers a global worker, non-nil a worker
// owned by that team. When a row with the same name already exists, it is
// overwritten only if its team_id matches (IS NULL for nil, = *teamID
// otherwise). If the upsert affects no row, an error is returned.
//
// ttl sets expires to NOW() + ttl, truncated to whole seconds, or to NULL
// when ttl is 0. An empty atcWorker.State is stored as running, and an empty
// atcWorker.Version is stored as NULL.
//
// The returned Worker is assembled from atcWorker rather than re-read from
// the database, so expiresAt is not populated. conn is attached to it for
// later queries.
func saveWorker(tx Tx, atcWorker atc.Worker, teamID *int, ttl time.Duration, conn Conn) (Worker, error) {
	resourceTypes, err := json.Marshal(atcWorker.ResourceTypes)
	if err != nil {
		return nil, err
	}

	tags, err := json.Marshal(atcWorker.Tags)
	if err != nil {
		return nil, err
	}

	// expires and startTime are inlined into the SQL as expressions (not bound
	// parameters). Both are formatted from integers only.
	expires := "NULL"
	if ttl != 0 {
		expires = fmt.Sprintf(`NOW() + '%d second'::INTERVAL`, int(ttl.Seconds()))
	}

	startTime := fmt.Sprintf(`to_timestamp(%d)`, atcWorker.StartTime)

	var workerState WorkerState
	if atcWorker.State != "" {
		workerState = WorkerState(atcWorker.State)
	} else {
		workerState = WorkerStateRunning
	}

	var workerVersion *string
	if atcWorker.Version != "" {
		workerVersion = &atcWorker.Version
	}

	// Bound values shared by the INSERT's VALUES list (after expires and
	// start_time) and the ON CONFLICT ... SET list. Their order must match the
	// column order in both places.
	values := []interface{}{
		atcWorker.GardenAddr,
		atcWorker.ActiveContainers,
		atcWorker.ActiveVolumes,
		resourceTypes,
		tags,
		atcWorker.Platform,
		atcWorker.BaggageclaimURL,
		atcWorker.CertsPath,
		atcWorker.HTTPProxyURL,
		atcWorker.HTTPSProxyURL,
		atcWorker.NoProxy,
		atcWorker.Name,
		workerVersion,
		string(workerState),
		teamID,
		atcWorker.Ephemeral,
	}

	// The conflict WHERE clause only takes a placeholder when a team is
	// given, so *teamID is appended as the final suffix argument in that case.
	// NOTE(review): values is a slice literal (len == cap), so this append
	// allocates a new backing array and does not alias values.
	conflictValues := values
	var matchTeamUpsert string
	if teamID == nil {
		matchTeamUpsert = "workers.team_id IS NULL"
	} else {
		matchTeamUpsert = "workers.team_id = ?"
		conflictValues = append(conflictValues, *teamID)
	}

	// When the name already exists and the WHERE condition is false (the row
	// belongs to a different owner), PostgreSQL skips the update and reports
	// zero affected rows. That case is turned into an error below.
	rows, err := psql.Insert("workers").
		Columns(
			"expires",
			"start_time",
			"addr",
			"active_containers",
			"active_volumes",
			"resource_types",
			"tags",
			"platform",
			"baggageclaim_url",
			"certs_path",
			"http_proxy_url",
			"https_proxy_url",
			"no_proxy",
			"name",
			"version",
			"state",
			"team_id",
			"ephemeral",
		).
		Values(append([]interface{}{
			sq.Expr(expires),
			sq.Expr(startTime),
		}, values...)...).
		Suffix(`
			ON CONFLICT (name) DO UPDATE SET
				expires = `+expires+`,
				start_time = `+startTime+`,
				addr = ?,
				active_containers = ?,
				active_volumes = ?,
				resource_types = ?,
				tags = ?,
				platform = ?,
				baggageclaim_url = ?,
				certs_path = ?,
				http_proxy_url = ?,
				https_proxy_url = ?,
				no_proxy = ?,
				name = ?,
				version = ?,
				state = ?,
				team_id = ?,
				ephemeral = ?
			WHERE `+matchTeamUpsert,
			conflictValues...,
		).
		RunWith(tx).
		Exec()
	if err != nil {
		return nil, err
	}

	count, err := rows.RowsAffected()
	if err != nil {
		return nil, err
	}

	if count == 0 {
		return nil, errors.New("worker already exists and is either global or owned by another team")
	}

	// A nil teamID is represented as 0 on the in-memory worker.
	var workerTeamID int
	if teamID != nil {
		workerTeamID = *teamID
	}

	// gardenAddr, baggageclaimURL and version point into this function's copy
	// of atcWorker (it is passed by value). certsPath shares the caller's
	// pointer.
	savedWorker := &worker{
		name:             atcWorker.Name,
		version:          workerVersion,
		state:            workerState,
		gardenAddr:       &atcWorker.GardenAddr,
		baggageclaimURL:  &atcWorker.BaggageclaimURL,
		certsPath:        atcWorker.CertsPath,
		httpProxyURL:     atcWorker.HTTPProxyURL,
		httpsProxyURL:    atcWorker.HTTPSProxyURL,
		noProxy:          atcWorker.NoProxy,
		activeContainers: atcWorker.ActiveContainers,
		activeVolumes:    atcWorker.ActiveVolumes,
		resourceTypes:    atcWorker.ResourceTypes,
		platform:         atcWorker.Platform,
		tags:             atcWorker.Tags,
		teamName:         atcWorker.Team,
		teamID:           workerTeamID,
		startTime:        time.Unix(atcWorker.StartTime, 0),
		ephemeral:        atcWorker.Ephemeral,
		conn:             conn,
	}

	// Ensure a worker_base_resource_types row exists for every resource type
	// the worker reports, and collect their IDs.
	workerBaseResourceTypeIDs := []int{}

	for _, resourceType := range atcWorker.ResourceTypes {
		workerResourceType := WorkerResourceType{
			Worker:  savedWorker,
			Image:   resourceType.Image,
			Version: resourceType.Version,
			BaseResourceType: &BaseResourceType{
				Name: resourceType.Type,
			},
		}

		uwrt, err := workerResourceType.FindOrCreate(tx, resourceType.UniqueVersionHistory)
		if err != nil {
			return nil, err
		}

		workerBaseResourceTypeIDs = append(workerBaseResourceTypeIDs, uwrt.ID)
	}

	// Remove this worker's resource-type rows that were not just found or
	// created, i.e. types the worker no longer reports.
	// NOTE(review): when the worker reports no resource types the ID list is
	// empty; the SQL squirrel emits for NotEq with an empty slice depends on
	// the vendored squirrel version — confirm it means "delete all".
	_, err = psql.Delete("worker_base_resource_types").
		Where(sq.Eq{
			"worker_name": atcWorker.Name,
		}).
		Where(sq.NotEq{
			"id": workerBaseResourceTypeIDs,
		}).
		RunWith(tx).
		Exec()
	if err != nil {
		return nil, err
	}

	// Record the worker's certs path only when one was reported. Existing
	// entries are not removed here when it is absent.
	if atcWorker.CertsPath != nil {
		_, err := WorkerResourceCerts{
			WorkerName: atcWorker.Name,
			CertsPath:  *atcWorker.CertsPath,
		}.FindOrCreate(tx)
		if err != nil {
			return nil, err
		}
	}

	return savedWorker, nil
}