github.com/nais/knorten@v0.0.0-20240104110906-55926958e361/pkg/helm/application.go (about)

     1  package helm
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"log"
     8  	"os"
     9  	"strconv"
    10  	"strings"
    11  	"time"
    12  
    13  	"gopkg.in/yaml.v3"
    14  	"helm.sh/helm/v3/pkg/action"
    15  	"helm.sh/helm/v3/pkg/chart"
    16  	"helm.sh/helm/v3/pkg/cli"
    17  	"helm.sh/helm/v3/pkg/release"
    18  	"k8s.io/utils/strings/slices"
    19  
    20  	"github.com/nais/knorten/pkg/database"
    21  	"github.com/nais/knorten/pkg/database/gensql"
    22  	"github.com/nais/knorten/pkg/logger"
    23  )
    24  
const (
	// timeout is assigned as the Timeout of the helm Install and Upgrade
	// actions created in installOrUpgrade.
	timeout = 30 * time.Minute
)

// HelmEventData describes one helm operation (install/upgrade, uninstall or
// rollback) on a team's release.
type HelmEventData struct {
	// TeamID is the team whose stored values are merged into the chart and for
	// which rollback events are registered.
	TeamID string

	// Namespace is the Kubernetes namespace the helm actions are run in.
	Namespace string

	// ReleaseName is the helm release that is installed, upgraded, uninstalled
	// or rolled back.
	ReleaseName string

	// ChartType selects which global/team values are loaded and which kind of
	// rollback event is registered when an upgrade fails.
	ChartType gensql.ChartType

	// ChartRepo, ChartName and ChartVersion identify the chart passed to
	// FetchChart.
	ChartRepo    string
	ChartName    string
	ChartVersion string
}
    38  
    39  type Client struct {
    40  	dryRun bool
    41  	repo   *database.Repo
    42  }
    43  
    44  func NewClient(dryRun bool, repo *database.Repo) Client {
    45  	return Client{
    46  		dryRun: dryRun,
    47  		repo:   repo,
    48  	}
    49  }
    50  
    51  func (c Client) InstallOrUpgrade(ctx context.Context, helmEvent HelmEventData, logger logger.Logger) error {
    52  	logger.Infof("Installing or upgrading %v", helmEvent.ChartType)
    53  	rollback, err := c.installOrUpgrade(ctx, helmEvent, logger)
    54  	if rollback {
    55  		switch helmEvent.ChartType {
    56  		case gensql.ChartTypeJupyterhub:
    57  			if err := c.repo.RegisterHelmRollbackJupyterEvent(context.Background(), helmEvent.TeamID, helmEvent); err != nil {
    58  				logger.WithError(err).Error("registering helm rollback jupyter event")
    59  			}
    60  		case gensql.ChartTypeAirflow:
    61  			if err := c.repo.RegisterHelmRollbackAirflowEvent(context.Background(), helmEvent.TeamID, helmEvent); err != nil {
    62  				logger.WithError(err).Error("registering helm rollback airflow event")
    63  			}
    64  		}
    65  	}
    66  	if err != nil {
    67  		logger.Infof("Installing or upgrading %v failed", helmEvent.ChartType)
    68  		return err
    69  	}
    70  
    71  	logger.Infof("Successfully installed or upgraded %v", helmEvent.ChartType)
    72  	return nil
    73  }
    74  
    75  func (c Client) installOrUpgrade(ctx context.Context, helmEvent HelmEventData, logger logger.Logger) (bool, error) {
    76  	helmChart, err := c.createChartWithValues(ctx, helmEvent)
    77  	if err != nil {
    78  		logger.WithError(err).Error("getting chart values")
    79  		return false, err
    80  	}
    81  
    82  	if c.dryRun {
    83  		out, err := yaml.Marshal(helmChart.Values)
    84  		if err != nil {
    85  			logger.WithError(err).Error("marshalling team values")
    86  			return false, err
    87  		}
    88  
    89  		if err = os.WriteFile(fmt.Sprintf("charts/%v-%v.yaml", helmEvent.ChartType, time.Now().Format("2006.01.02-15:04")), out, 0o644); err != nil {
    90  			logger.WithError(err).Error("writing values to file")
    91  			return true, err
    92  		}
    93  
    94  		return false, nil
    95  	}
    96  
    97  	settings := cli.New()
    98  	settings.SetNamespace(helmEvent.Namespace)
    99  	actionConfig := new(action.Configuration)
   100  	if err := actionConfig.Init(settings.RESTClientGetter(), settings.Namespace(), "secret", log.Printf); err != nil {
   101  		logger.WithError(err).Error("action config init")
   102  		return false, err
   103  	}
   104  
   105  	exists, err := releaseExists(actionConfig, helmEvent.ReleaseName)
   106  	if err != nil {
   107  		logger.WithError(err).Error("checking if release exists")
   108  		return false, err
   109  	}
   110  
   111  	if exists {
   112  		upgradeClient := action.NewUpgrade(actionConfig)
   113  		upgradeClient.Namespace = helmEvent.Namespace
   114  		upgradeClient.Timeout = timeout
   115  
   116  		_, err = upgradeClient.RunWithContext(ctx, helmEvent.ReleaseName, helmChart, helmChart.Values)
   117  		if err != nil {
   118  			logger.WithError(err).Error("helm upgrade")
   119  			return true, err
   120  		}
   121  	} else {
   122  		installClient := action.NewInstall(actionConfig)
   123  		installClient.Namespace = helmEvent.Namespace
   124  		installClient.ReleaseName = helmEvent.ReleaseName
   125  		installClient.Timeout = timeout
   126  
   127  		_, err = installClient.RunWithContext(ctx, helmChart, helmChart.Values)
   128  		if err != nil {
   129  			logger.WithError(err).Error("helm install")
   130  			return false, err
   131  		}
   132  	}
   133  
   134  	return false, nil
   135  }
   136  
   137  func (c Client) Uninstall(ctx context.Context, helmEvent HelmEventData, logger logger.Logger) bool {
   138  	logger.Infof("Uninstalling %v", helmEvent.ChartType)
   139  	if err := c.uninstall(ctx, helmEvent, logger); err != nil {
   140  		logger.Infof("Uninstalling %v failed", helmEvent.ChartType)
   141  		return true
   142  	}
   143  
   144  	logger.Infof("Successfully uninstalled %v", helmEvent.ChartType)
   145  	return false
   146  }
   147  
   148  func (c Client) uninstall(ctx context.Context, helmEvent HelmEventData, logger logger.Logger) error {
   149  	if c.dryRun {
   150  		return nil
   151  	}
   152  
   153  	settings := cli.New()
   154  	settings.SetNamespace(helmEvent.Namespace)
   155  	actionConfig := new(action.Configuration)
   156  	if err := actionConfig.Init(settings.RESTClientGetter(), settings.Namespace(), "secret", log.Printf); err != nil {
   157  		logger.WithError(err).Errorf("creating action config for helm uninstall: release %v, team %v", helmEvent.TeamID, helmEvent.TeamID)
   158  		return err
   159  	}
   160  
   161  	exists, err := releaseExists(actionConfig, helmEvent.ReleaseName)
   162  	if err != nil {
   163  		logger.WithError(err).Errorf("checking if release exists for helm uninstall: release %v, team %v", helmEvent.TeamID, helmEvent.TeamID)
   164  		return err
   165  	}
   166  
   167  	if !exists {
   168  		return nil
   169  	}
   170  
   171  	uninstallClient := action.NewUninstall(actionConfig)
   172  	_, err = uninstallClient.Run(helmEvent.ReleaseName)
   173  	if err != nil {
   174  		logger.WithError(err).Errorf("helm uninstall: release %v, team %v", helmEvent.TeamID, helmEvent.TeamID)
   175  		return err
   176  	}
   177  
   178  	return nil
   179  }
   180  
   181  func (c Client) Rollback(ctx context.Context, helmEvent HelmEventData, logger logger.Logger) (bool, error) {
   182  	logger.Infof("Rolling back %v", helmEvent.ChartType)
   183  	retry, err := c.rollback(ctx, helmEvent, logger)
   184  	if retry || err != nil {
   185  		logger.Infof("Rolling back %v failed", helmEvent.ChartType)
   186  		return retry, err
   187  	}
   188  
   189  	logger.Infof("Successfully rolled back %v", helmEvent.ChartType)
   190  	return false, nil
   191  }
   192  
   193  func (c Client) rollback(ctx context.Context, helmEvent HelmEventData, logger logger.Logger) (bool, error) {
   194  	if c.dryRun {
   195  		return false, nil
   196  	}
   197  
   198  	settings := cli.New()
   199  	settings.SetNamespace(helmEvent.Namespace)
   200  	actionConfig := new(action.Configuration)
   201  	if err := actionConfig.Init(settings.RESTClientGetter(), settings.Namespace(), "secret", log.Printf); err != nil {
   202  		logger.WithError(err).Error("action config init")
   203  		return true, nil
   204  	}
   205  
   206  	version, err := lastSuccessfulHelmRelease(helmEvent.ReleaseName, actionConfig)
   207  	if err != nil {
   208  		logger.WithError(err).Errorf("unable to rollback chart %v for team %v", helmEvent.ChartName, helmEvent.TeamID)
   209  		return false, err
   210  	}
   211  
   212  	rollbackClient := action.NewRollback(actionConfig)
   213  	rollbackClient.Version = version
   214  	if err := rollbackClient.Run(helmEvent.ReleaseName); err != nil {
   215  		logger.WithError(err).Errorf("rolling back release %v for team %v to version %v", helmEvent.ReleaseName, helmEvent.TeamID, version)
   216  		return true, nil
   217  	}
   218  
   219  	return false, nil
   220  }
   221  
   222  func lastSuccessfulHelmRelease(releaseName string, actionConfig *action.Configuration) (int, error) {
   223  	historyClient := action.NewHistory(actionConfig)
   224  
   225  	releases, err := historyClient.Run(releaseName)
   226  	if err != nil {
   227  		return 0, err
   228  	}
   229  
   230  	validStatuses := []string{release.StatusDeployed.String(), release.StatusSuperseded.String()}
   231  	for i := len(releases) - 1; i >= 0; i-- {
   232  		if slices.Contains(validStatuses, releases[i].Info.Status.String()) {
   233  			return releases[i].Version, nil
   234  		}
   235  	}
   236  
   237  	return 0, fmt.Errorf("no previous successful helm releases for %v", releaseName)
   238  }
   239  
   240  func (c Client) createChartWithValues(ctx context.Context, helmEvent HelmEventData) (*chart.Chart, error) {
   241  	helmChart, err := FetchChart(helmEvent.ChartRepo, helmEvent.ChartName, helmEvent.ChartVersion)
   242  	if err != nil {
   243  		return nil, err
   244  	}
   245  
   246  	err = c.mergeValues(ctx, helmEvent.ChartType, helmEvent.TeamID, helmChart.Values)
   247  	if err != nil {
   248  		return nil, err
   249  	}
   250  
   251  	return helmChart, nil
   252  }
   253  
   254  func (c Client) mergeValues(ctx context.Context, chartType gensql.ChartType, teamID string, defaultValues map[string]any) error {
   255  	values, err := c.globalValues(ctx, chartType)
   256  	if err != nil {
   257  		return err
   258  	}
   259  
   260  	err = c.enrichWithTeamValues(ctx, chartType, teamID, values)
   261  	if err != nil {
   262  		return err
   263  	}
   264  
   265  	switch chartType {
   266  	case gensql.ChartTypeJupyterhub:
   267  		if err := c.concatenateImageProfiles(ctx, teamID, values); err != nil {
   268  			return err
   269  		}
   270  	case gensql.ChartTypeAirflow:
   271  		knauditInitContainer, err := c.createKnauditInitContainer(ctx)
   272  		if err != nil {
   273  			return err
   274  		}
   275  		mergeMaps(values, knauditInitContainer)
   276  	}
   277  
   278  	mergeMaps(defaultValues, values)
   279  	return nil
   280  }
   281  
   282  func (c Client) globalValues(ctx context.Context, chartType gensql.ChartType) (map[string]any, error) {
   283  	dbValues, err := c.repo.GlobalValuesGet(ctx, chartType)
   284  	if err != nil {
   285  		return map[string]any{}, err
   286  	}
   287  
   288  	values := map[string]any{}
   289  	for _, v := range dbValues {
   290  		if v.Encrypted {
   291  			v.Value, err = c.repo.DecryptValue(v.Value)
   292  			if err != nil {
   293  				return nil, err
   294  			}
   295  		}
   296  
   297  		keys := keySplitHandleEscape(v.Key)
   298  		value, err := ParseValue(v.Value)
   299  		if err != nil {
   300  			return nil, err
   301  		}
   302  		SetChartValue(keys, value, values)
   303  	}
   304  
   305  	return values, nil
   306  }
   307  
   308  func (c Client) enrichWithTeamValues(ctx context.Context, chartType gensql.ChartType, teamID string, values map[string]any) error {
   309  	dbValues, err := c.repo.TeamValuesGet(ctx, chartType, teamID)
   310  	if err != nil {
   311  		return err
   312  	}
   313  
   314  	for _, v := range dbValues {
   315  		if slices.Contains([]string{"fernetKey", "webserverSecretKey"}, v.Key) {
   316  			continue
   317  		}
   318  
   319  		_, err = parseTeamValue(v.Key, v.Value, values)
   320  		if err != nil {
   321  			return err
   322  		}
   323  	}
   324  
   325  	return nil
   326  }
   327  
   328  func parseKey(key string) (string, []string) {
   329  	opts := strings.Split(key, ",")
   330  	return opts[0], opts[1:]
   331  }
   332  
   333  func parseTeamValue(key string, value any, values map[string]any) (any, error) {
   334  	key, opts := parseKey(key)
   335  	if slices.Contains(opts, "omit") {
   336  		return nil, nil
   337  	}
   338  
   339  	keys := keySplitHandleEscape(key)
   340  	value, err := ParseValue(value)
   341  	if err != nil {
   342  		return nil, err
   343  	}
   344  	SetChartValue(keys, value, values)
   345  
   346  	return values, nil
   347  }
   348  
   349  func mergeMaps(base, custom map[string]any) map[string]any {
   350  	for k, v := range custom {
   351  		if _, ok := v.(map[string]any); ok {
   352  			if _, ok := base[k].(map[string]any); !ok {
   353  				base[k] = map[string]any{}
   354  			}
   355  			base[k] = mergeMaps(base[k].(map[string]any), v.(map[string]any))
   356  			continue
   357  		}
   358  		base[k] = v
   359  	}
   360  	return base
   361  }
   362  
   363  func releaseExists(actionConfig *action.Configuration, releaseName string) (bool, error) {
   364  	listClient := action.NewList(actionConfig)
   365  	listClient.Deployed = true
   366  	results, err := listClient.Run()
   367  	if err != nil {
   368  		return false, err
   369  	}
   370  
   371  	for _, r := range results {
   372  		if r.Name == releaseName {
   373  			return true, nil
   374  		}
   375  	}
   376  
   377  	return false, nil
   378  }
   379  
   380  func keySplitHandleEscape(key string) []string {
   381  	escape := false
   382  	keys := strings.FieldsFunc(key, func(r rune) bool {
   383  		if r == '\\' {
   384  			escape = true
   385  		} else if escape {
   386  			escape = false
   387  			return false
   388  		}
   389  		return r == '.'
   390  	})
   391  
   392  	var keysWithoutEscape []string
   393  	for _, k := range keys {
   394  		keysWithoutEscape = append(keysWithoutEscape, strings.ReplaceAll(k, "\\", ""))
   395  	}
   396  
   397  	return keysWithoutEscape
   398  }
   399  
   400  func SetChartValue(keys []string, value any, chart map[string]any) {
   401  	key := keys[0]
   402  	if len(keys) > 1 {
   403  		if _, ok := chart[key].(map[string]any); !ok {
   404  			chart[key] = map[string]any{}
   405  		}
   406  		SetChartValue(keys[1:], value, chart[key].(map[string]any))
   407  		return
   408  	}
   409  
   410  	chart[key] = value
   411  }
   412  
   413  func ParseValue(value any) (any, error) {
   414  	var err error
   415  
   416  	switch v := value.(type) {
   417  	case string:
   418  		value, err = ParseString(v)
   419  		if err != nil {
   420  			return nil, fmt.Errorf("failed parsing value %v: %v", v, err)
   421  		}
   422  	default:
   423  		value = v
   424  	}
   425  
   426  	return value, nil
   427  }
   428  
   429  func ParseString(value any) (any, error) {
   430  	valueString := value.(string)
   431  
   432  	if d, err := strconv.ParseBool(valueString); err == nil {
   433  		return d, nil
   434  	} else if d, err := strconv.ParseInt(valueString, 10, 64); err == nil {
   435  		return d, nil
   436  	} else if d, err := strconv.ParseFloat(valueString, 64); err == nil {
   437  		return d, nil
   438  	} else if strings.HasPrefix(value.(string), "[") || strings.HasPrefix(value.(string), "{") {
   439  		var d any
   440  		if err := json.Unmarshal([]byte(valueString), &d); err != nil {
   441  			return nil, err
   442  		}
   443  		return d, nil
   444  	}
   445  
   446  	return removeQuotations(valueString), nil
   447  }
   448  
   449  func removeQuotations(s string) string {
   450  	s = strings.TrimPrefix(s, "\"")
   451  	return strings.TrimSuffix(s, "\"")
   452  }