github.com/nais/knorten@v0.0.0-20240104110906-55926958e361/pkg/api/chart.go (about)

     1  package api
     2  
import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"math"
	"net/http"
	"strconv"
	"strings"

	"github.com/gin-contrib/sessions"
	"github.com/gin-gonic/gin"
	"github.com/gin-gonic/gin/binding"
	"github.com/go-playground/validator/v10"
	"github.com/nais/knorten/pkg/chart"
	"github.com/nais/knorten/pkg/database/gensql"
)
    20  
    21  type jupyterForm struct {
    22  	CPU         string   `form:"cpu"`
    23  	Memory      string   `form:"memory"`
    24  	ImageName   string   `form:"imagename"`
    25  	ImageTag    string   `form:"imagetag"`
    26  	CullTimeout string   `form:"culltimeout"`
    27  	Allowlist   []string `form:"allowlist[]"`
    28  }
    29  
    30  func (v jupyterForm) MemoryWithoutUnit() string {
    31  	if v.Memory == "" {
    32  		return ""
    33  	}
    34  
    35  	return v.Memory[:len(v.Memory)-1]
    36  }
    37  
// airflowForm holds the fields submitted from (and rendered into) the Airflow chart form.
// The binding tags are enforced by gin's validator; the valid* tags refer to custom
// validations registered in setupChartRoutes.
type airflowForm struct {
	// DagRepo is the DAG repository; required, and must start with "navikt/".
	DagRepo       string `form:"dagrepo" binding:"required,startswith=navikt/,validAirflowRepo"`
	// DagRepoBranch is the branch to use; newChart/editChart fall back to "main" when it is empty.
	DagRepoBranch string `form:"dagrepobranch" binding:"validRepoBranch"`
	// AirflowImage is an optional "<image>:<tag>" reference; empty means no custom image.
	AirflowImage  string `form:"airflowimage" binding:"validAirflowImage"`
	// ApiAccess is the raw checkbox value; "on" means API access is enabled.
	ApiAccess     string `form:"apiaccess"`
}
    44  
    45  func getChartType(chartType string) gensql.ChartType {
    46  	switch chartType {
    47  	case string(gensql.ChartTypeJupyterhub):
    48  		return gensql.ChartTypeJupyterhub
    49  	case string(gensql.ChartTypeAirflow):
    50  		return gensql.ChartTypeAirflow
    51  	default:
    52  		return ""
    53  	}
    54  }
    55  
    56  func descriptiveMessageForChartError(fieldError validator.FieldError) string {
    57  	switch fieldError.Tag() {
    58  	case "required":
    59  		return fmt.Sprintf("%v er et påkrevd felt", fieldError.Field())
    60  	case "startswith":
    61  		return fmt.Sprintf("%v må starte med 'navikt/'", fieldError.Field())
    62  	default:
    63  		return fieldError.Error()
    64  	}
    65  }
    66  
    67  func (c *client) setupChartRoutes() {
    68  	if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
    69  		err := v.RegisterValidation("validAirflowRepo", chart.ValidateAirflowRepo)
    70  		if err != nil {
    71  			c.log.WithError(err).Error("can't register validator")
    72  			return
    73  		}
    74  	}
    75  
    76  	if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
    77  		err := v.RegisterValidation("validRepoBranch", chart.ValidateRepoBranch)
    78  		if err != nil {
    79  			c.log.WithError(err).Error("can't register validator")
    80  			return
    81  		}
    82  	}
    83  
    84  	if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
    85  		err := v.RegisterValidation("validAirflowImage", chart.ValidateAirflowImage)
    86  		if err != nil {
    87  			c.log.WithError(err).Error("can't register validator")
    88  			return
    89  		}
    90  	}
    91  
    92  	c.router.GET("/team/:slug/:chart/new", func(ctx *gin.Context) {
    93  		slug := ctx.Param("slug")
    94  		chartType := getChartType(ctx.Param("chart"))
    95  
    96  		var form any
    97  		switch chartType {
    98  		case gensql.ChartTypeJupyterhub:
    99  			form = jupyterForm{}
   100  		case gensql.ChartTypeAirflow:
   101  			form = airflowForm{}
   102  		default:
   103  			ctx.JSON(http.StatusBadRequest, map[string]string{
   104  				"status":  strconv.Itoa(http.StatusBadRequest),
   105  				"message": fmt.Sprintf("Chart type %v is not supported", chartType),
   106  			})
   107  			return
   108  		}
   109  
   110  		session := sessions.Default(ctx)
   111  		flashes := session.Flashes()
   112  		err := session.Save()
   113  		if err != nil {
   114  			c.log.WithField("team", slug).WithField("chart", chartType).WithError(err).Error("problem saving session")
   115  			ctx.JSON(http.StatusInternalServerError, map[string]string{
   116  				"status":  strconv.Itoa(http.StatusInternalServerError),
   117  				"message": "Internal server error",
   118  			})
   119  			return
   120  		}
   121  
   122  		c.htmlResponseWrapper(ctx, http.StatusOK, fmt.Sprintf("charts/%v", chartType), gin.H{
   123  			"team":   slug,
   124  			"form":   form,
   125  			"errors": flashes,
   126  		})
   127  	})
   128  
   129  	c.router.POST("/team/:slug/:chart/new", func(ctx *gin.Context) {
   130  		slug := ctx.Param("slug")
   131  		chartType := getChartType(ctx.Param("chart"))
   132  		log := c.log.WithField("team", slug).WithField("chart", chartType)
   133  
   134  		err := c.newChart(ctx, slug, chartType)
   135  		if err != nil {
   136  			session := sessions.Default(ctx)
   137  			var validationErrorse validator.ValidationErrors
   138  			if errors.As(err, &validationErrorse) {
   139  				for _, fieldError := range validationErrorse {
   140  					log.WithError(err).Infof("field error: %v", fieldError)
   141  					session.AddFlash(descriptiveMessageForChartError(fieldError))
   142  				}
   143  			} else {
   144  				log.WithError(err).Info("non-field error")
   145  				session.AddFlash(err.Error())
   146  			}
   147  
   148  			err := session.Save()
   149  			if err != nil {
   150  				log.WithError(err).Error("problem saving session")
   151  				ctx.Redirect(http.StatusSeeOther, fmt.Sprintf("/team/%v/%v/new", slug, chartType))
   152  				return
   153  			}
   154  
   155  			ctx.Redirect(http.StatusSeeOther, fmt.Sprintf("/team/%v/%v/new", slug, chartType))
   156  			return
   157  		}
   158  
   159  		ctx.Redirect(http.StatusSeeOther, "/oversikt")
   160  	})
   161  
   162  	c.router.GET("/team/:slug/:chart/edit", func(ctx *gin.Context) {
   163  		teamSlug := ctx.Param("slug")
   164  		chartType := getChartType(ctx.Param("chart"))
   165  		log := c.log.WithField("team", teamSlug).WithField("chart", chartType)
   166  
   167  		session := sessions.Default(ctx)
   168  
   169  		form, err := c.getEditChart(ctx, teamSlug, chartType)
   170  		if err != nil {
   171  			var validationErrorse validator.ValidationErrors
   172  			if errors.As(err, &validationErrorse) {
   173  				for _, fieldError := range validationErrorse {
   174  					log.WithError(err).Infof("field error: %v", fieldError)
   175  					session.AddFlash(descriptiveMessageForChartError(fieldError))
   176  				}
   177  			} else {
   178  				log.WithError(err).Info("non-field error")
   179  				session.AddFlash(err.Error())
   180  			}
   181  
   182  			err := session.Save()
   183  			if err != nil {
   184  				log.WithError(err).Error("problem saving session")
   185  				ctx.Redirect(http.StatusSeeOther, "/oversikt")
   186  				return
   187  			}
   188  			ctx.Redirect(http.StatusSeeOther, "/oversikt")
   189  			return
   190  		}
   191  
   192  		flashes := session.Flashes()
   193  		err = session.Save()
   194  		if err != nil {
   195  			log.WithError(err).Error("problem saving session")
   196  			return
   197  		}
   198  
   199  		c.htmlResponseWrapper(ctx, http.StatusOK, fmt.Sprintf("charts/%v", chartType), gin.H{
   200  			"team":   teamSlug,
   201  			"values": form,
   202  			"errors": flashes,
   203  		})
   204  	})
   205  
   206  	c.router.POST("/team/:slug/:chart/edit", func(ctx *gin.Context) {
   207  		teamSlug := ctx.Param("slug")
   208  		chartType := getChartType(ctx.Param("chart"))
   209  		log := c.log.WithField("team", teamSlug).WithField("chart", chartType)
   210  
   211  		err := c.editChart(ctx, teamSlug, chartType)
   212  		if err != nil {
   213  			session := sessions.Default(ctx)
   214  			var validationErrorse validator.ValidationErrors
   215  			if errors.As(err, &validationErrorse) {
   216  				for _, fieldError := range validationErrorse {
   217  					log.WithError(err).Infof("field error: %v", fieldError)
   218  					session.AddFlash(descriptiveMessageForChartError(fieldError))
   219  				}
   220  			} else {
   221  				log.WithError(err).Info("non-field error")
   222  				session.AddFlash(err.Error())
   223  			}
   224  
   225  			err := session.Save()
   226  			if err != nil {
   227  				log.WithError(err).Error("problem saving session")
   228  				ctx.Redirect(http.StatusSeeOther, fmt.Sprintf("/team/%v/%v/edit", teamSlug, chartType))
   229  				return
   230  			}
   231  
   232  			ctx.Redirect(http.StatusSeeOther, fmt.Sprintf("/team/%v/%v/edit", teamSlug, chartType))
   233  			return
   234  		}
   235  
   236  		ctx.Redirect(http.StatusSeeOther, "/oversikt")
   237  	})
   238  
   239  	c.router.POST("/team/:slug/:chart/delete", func(ctx *gin.Context) {
   240  		teamSlug := ctx.Param("slug")
   241  		chartTypeString := ctx.Param("chart")
   242  		log := c.log.WithField("team", teamSlug).WithField("chart", chartTypeString)
   243  
   244  		err := c.deleteChart(ctx, teamSlug, chartTypeString)
   245  		if err != nil {
   246  			log.WithError(err).Errorf("problem deleting chart %v for team %v", chartTypeString, teamSlug)
   247  			session := sessions.Default(ctx)
   248  			session.AddFlash(err.Error())
   249  			err := session.Save()
   250  			if err != nil {
   251  				log.WithError(err).Error("problem saving session")
   252  			}
   253  		}
   254  
   255  		ctx.Redirect(http.StatusSeeOther, "/oversikt")
   256  	})
   257  }
   258  
   259  func (c *client) getExistingAllowlist(ctx context.Context, teamID string) ([]string, error) {
   260  	extraAnnotations, err := c.repo.TeamValueGet(ctx, "singleuser.extraAnnotations", teamID)
   261  	if err != nil {
   262  		if errors.Is(err, sql.ErrNoRows) {
   263  			return []string{}, nil
   264  		}
   265  		return nil, err
   266  	}
   267  
   268  	var annotations map[string]string
   269  	if err := json.Unmarshal([]byte(extraAnnotations.Value), &annotations); err != nil {
   270  		return nil, err
   271  	}
   272  
   273  	for k, v := range annotations {
   274  		if k == "allowlist" {
   275  			return strings.Split(v, ","), nil
   276  		}
   277  	}
   278  
   279  	return []string{}, nil
   280  }
   281  
   282  func (c *client) newChart(ctx *gin.Context, teamSlug string, chartType gensql.ChartType) error {
   283  	team, err := c.repo.TeamBySlugGet(ctx, teamSlug)
   284  	if err != nil {
   285  		return err
   286  	}
   287  
   288  	switch chartType {
   289  	case gensql.ChartTypeJupyterhub:
   290  		var form jupyterForm
   291  		err := ctx.ShouldBindWith(&form, binding.Form)
   292  		if err != nil {
   293  			return err
   294  		}
   295  
   296  		cullTimeout, err := strconv.ParseUint(form.CullTimeout, 10, 64)
   297  		if err != nil {
   298  			return err
   299  		}
   300  
   301  		userIdents, err := c.azureClient.ConvertEmailsToIdents(team.Users)
   302  		if err != nil {
   303  			return err
   304  		}
   305  
   306  		cpu, err := parseCPU(form.CPU)
   307  		if err != nil {
   308  			return err
   309  		}
   310  
   311  		memory, err := parseMemory(form.Memory)
   312  		if err != nil {
   313  			return err
   314  		}
   315  
   316  		values := chart.JupyterConfigurableValues{
   317  			TeamID:      team.ID,
   318  			UserIdents:  userIdents,
   319  			CPU:         cpu,
   320  			Memory:      memory,
   321  			ImageName:   form.ImageName,
   322  			ImageTag:    form.ImageTag,
   323  			CullTimeout: strconv.FormatUint(cullTimeout, 10),
   324  			AllowList:   removeEmptySliceElements(form.Allowlist),
   325  		}
   326  
   327  		return c.repo.RegisterCreateJupyterEvent(ctx, team.ID, values)
   328  	case gensql.ChartTypeAirflow:
   329  		var form airflowForm
   330  		err := ctx.ShouldBindWith(&form, binding.Form)
   331  		if err != nil {
   332  			return err
   333  		}
   334  
   335  		dagRepoBranch := form.DagRepoBranch
   336  		if dagRepoBranch == "" {
   337  			dagRepoBranch = "main"
   338  		}
   339  
   340  		airflowImage := ""
   341  		airflowTag := ""
   342  		if form.AirflowImage != "" {
   343  			imageParts := strings.Split(form.AirflowImage, ":")
   344  			airflowImage = imageParts[0]
   345  			airflowTag = imageParts[1]
   346  		}
   347  
   348  		values := chart.AirflowConfigurableValues{
   349  			TeamID:        team.ID,
   350  			DagRepo:       form.DagRepo,
   351  			DagRepoBranch: dagRepoBranch,
   352  			ApiAccess:     form.ApiAccess == "on",
   353  			AirflowImage:  airflowImage,
   354  			AirflowTag:    airflowTag,
   355  		}
   356  
   357  		return c.repo.RegisterCreateAirflowEvent(ctx, team.ID, values)
   358  	}
   359  
   360  	return fmt.Errorf("chart type %v is not supported", chartType)
   361  }
   362  
   363  func (c *client) getEditChart(ctx *gin.Context, teamSlug string, chartType gensql.ChartType) (any, error) {
   364  	team, err := c.repo.TeamBySlugGet(ctx, teamSlug)
   365  	if err != nil {
   366  		return nil, err
   367  	}
   368  
   369  	var chartObjects any
   370  	switch chartType {
   371  	case gensql.ChartTypeJupyterhub:
   372  		chartObjects = &chart.JupyterConfigurableValues{}
   373  	case gensql.ChartTypeAirflow:
   374  		chartObjects = &chart.AirflowConfigurableValues{}
   375  	default:
   376  		return nil, fmt.Errorf("chart type %v is not supported", chartType)
   377  	}
   378  
   379  	err = c.repo.TeamConfigurableValuesGet(ctx, chartType, team.ID, chartObjects)
   380  	if err != nil {
   381  		return nil, err
   382  	}
   383  
   384  	var form any
   385  	switch chartType {
   386  	case gensql.ChartTypeJupyterhub:
   387  		jupyterhubValues := chartObjects.(*chart.JupyterConfigurableValues)
   388  		allowlist, err := c.getExistingAllowlist(ctx, team.ID)
   389  		if err != nil && !errors.Is(err, sql.ErrNoRows) {
   390  			return nil, err
   391  		}
   392  
   393  		form = jupyterForm{
   394  			CPU:         jupyterhubValues.CPU,
   395  			Memory:      jupyterhubValues.Memory,
   396  			ImageName:   jupyterhubValues.ImageName,
   397  			ImageTag:    jupyterhubValues.ImageTag,
   398  			CullTimeout: jupyterhubValues.CullTimeout,
   399  			Allowlist:   allowlist,
   400  		}
   401  	case gensql.ChartTypeAirflow:
   402  		airflowValues := chartObjects.(*chart.AirflowConfigurableValues)
   403  		apiAccessTeamValue, err := c.repo.TeamValueGet(ctx, chart.TeamValueKeyApiAccess, team.ID)
   404  		if err != nil && !errors.Is(err, sql.ErrNoRows) {
   405  			return nil, err
   406  		}
   407  
   408  		apiAccess := ""
   409  		if apiAccessTeamValue.Value == "true" {
   410  			apiAccess = "on"
   411  		}
   412  
   413  		airflowImage := ""
   414  		if airflowValues.AirflowImage != "" && airflowValues.AirflowTag != "" {
   415  			airflowImage = fmt.Sprintf("%v:%v", airflowValues.AirflowImage, airflowValues.AirflowTag)
   416  		}
   417  
   418  		form = airflowForm{
   419  			DagRepo:       airflowValues.DagRepo,
   420  			DagRepoBranch: airflowValues.DagRepoBranch,
   421  			ApiAccess:     apiAccess,
   422  			AirflowImage:  airflowImage,
   423  		}
   424  	}
   425  
   426  	return form, nil
   427  }
   428  
   429  func (c *client) editChart(ctx *gin.Context, teamSlug string, chartType gensql.ChartType) error {
   430  	team, err := c.repo.TeamBySlugGet(ctx, teamSlug)
   431  	if err != nil {
   432  		return err
   433  	}
   434  
   435  	switch chartType {
   436  	case gensql.ChartTypeJupyterhub:
   437  		var form jupyterForm
   438  		err := ctx.ShouldBindWith(&form, binding.Form)
   439  		if err != nil {
   440  			return err
   441  		}
   442  
   443  		userIdents, err := c.azureClient.ConvertEmailsToIdents(team.Users)
   444  		if err != nil {
   445  			return err
   446  		}
   447  
   448  		cpu, err := parseCPU(form.CPU)
   449  		if err != nil {
   450  			return err
   451  		}
   452  
   453  		memory, err := parseMemory(form.Memory)
   454  		if err != nil {
   455  			return err
   456  		}
   457  
   458  		values := chart.JupyterConfigurableValues{
   459  			TeamID:      team.ID,
   460  			UserIdents:  userIdents,
   461  			CPU:         cpu,
   462  			Memory:      memory,
   463  			ImageName:   form.ImageName,
   464  			ImageTag:    form.ImageTag,
   465  			CullTimeout: form.CullTimeout,
   466  			AllowList:   removeEmptySliceElements(form.Allowlist),
   467  		}
   468  
   469  		return c.repo.RegisterUpdateJupyterEvent(ctx, team.ID, values)
   470  	case gensql.ChartTypeAirflow:
   471  		var form airflowForm
   472  		err := ctx.ShouldBindWith(&form, binding.Form)
   473  		if err != nil {
   474  			return err
   475  		}
   476  
   477  		dagRepoBranch := form.DagRepoBranch
   478  		if dagRepoBranch == "" {
   479  			dagRepoBranch = "main"
   480  		}
   481  
   482  		airflowImage := ""
   483  		airflowTag := ""
   484  		if form.AirflowImage != "" {
   485  			imageParts := strings.Split(form.AirflowImage, ":")
   486  			airflowImage = imageParts[0]
   487  			airflowTag = imageParts[1]
   488  		}
   489  
   490  		values := chart.AirflowConfigurableValues{
   491  			TeamID:        team.ID,
   492  			DagRepo:       form.DagRepo,
   493  			DagRepoBranch: dagRepoBranch,
   494  			ApiAccess:     form.ApiAccess == "on",
   495  			AirflowImage:  airflowImage,
   496  			AirflowTag:    airflowTag,
   497  		}
   498  
   499  		return c.repo.RegisterUpdateAirflowEvent(ctx, team.ID, values)
   500  	}
   501  
   502  	return fmt.Errorf("chart type %v is not supported", chartType)
   503  }
   504  
   505  func (c *client) deleteChart(ctx *gin.Context, teamSlug, chartTypeString string) error {
   506  	team, err := c.repo.TeamBySlugGet(ctx, teamSlug)
   507  	if err != nil {
   508  		return err
   509  	}
   510  
   511  	switch getChartType(chartTypeString) {
   512  	case gensql.ChartTypeJupyterhub:
   513  		return c.repo.RegisterDeleteJupyterEvent(ctx, team.ID)
   514  	case gensql.ChartTypeAirflow:
   515  		return c.repo.RegisterDeleteAirflowEvent(ctx, team.ID)
   516  	}
   517  
   518  	return fmt.Errorf("chart type %v is not supported", chartTypeString)
   519  }
   520  
   521  func parseCPU(cpu string) (string, error) {
   522  	floatVal, err := strconv.ParseFloat(cpu, 64)
   523  	if err != nil {
   524  		return "", err
   525  	}
   526  
   527  	return fmt.Sprintf("%.1f", floatVal), nil
   528  }
   529  
   530  func parseMemory(memory string) (string, error) {
   531  	if strings.HasSuffix(memory, "G") {
   532  		return memory, nil
   533  	}
   534  	_, err := strconv.ParseFloat(memory, 64)
   535  	if err != nil {
   536  		return "", err
   537  	}
   538  	return memory + "G", nil
   539  }