github.com/navikt/knorten@v0.0.0-20240419132333-1333f46ed8b6/pkg/api/chart.go (about)

     1  package api
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"encoding/json"
     7  	"errors"
     8  	"fmt"
     9  	"net/http"
    10  	"strconv"
    11  	"strings"
    12  
    13  	"github.com/navikt/knorten/pkg/api/middlewares"
    14  
    15  	"github.com/gin-contrib/sessions"
    16  	"github.com/gin-gonic/gin"
    17  	"github.com/gin-gonic/gin/binding"
    18  	"github.com/go-playground/validator/v10"
    19  	"github.com/navikt/knorten/pkg/chart"
    20  	"github.com/navikt/knorten/pkg/database/gensql"
    21  )
    22  
    23  type jupyterForm struct {
    24  	CPULimit      string   `form:"cpulimit" binding:"validCPUSpec"`
    25  	CPURequest    string   `form:"cpurequest" binding:"validCPUSpec"`
    26  	MemoryLimit   string   `form:"memorylimit" binding:"validMemorySpec"`
    27  	MemoryRequest string   `form:"memoryrequest" binding:"validMemorySpec"`
    28  	ImageName     string   `form:"imagename"`
    29  	ImageTag      string   `form:"imagetag"`
    30  	CullTimeout   string   `form:"culltimeout"`
    31  	PYPIAccess    string   `form:"pypiaccess"`
    32  	Allowlist     []string `form:"allowlist[]"`
    33  }
    34  
    35  func (v jupyterForm) MemoryLimitWithoutUnit() string {
    36  	if v.MemoryLimit == "" {
    37  		return ""
    38  	}
    39  
    40  	return v.MemoryLimit[:len(v.MemoryLimit)-1]
    41  }
    42  
    43  func (v jupyterForm) MemoryRequestWithoutUnit() string {
    44  	if v.MemoryRequest == "" {
    45  		return ""
    46  	}
    47  
    48  	return v.MemoryRequest[:len(v.MemoryRequest)-1]
    49  }
    50  
// airflowForm holds the values submitted from, and rendered into, the Airflow
// chart form.
//
//   - DagRepo is required, must start with "navikt/", and is checked by the
//     custom validAirflowRepo validator.
//   - DagRepoBranch is checked by validRepoBranch; newChart/editChart fall back
//     to "main" when it is empty.
//   - AirflowImage is checked by validAirflowImage; newChart/editChart split it
//     into image and tag on ':'.
//   - ApiAccess is the raw form value; only "on" is treated as enabled.
//
// The custom validators are registered in setupChartRoutes.
type airflowForm struct {
	DagRepo       string `form:"dagrepo" binding:"required,startswith=navikt/,validAirflowRepo"`
	DagRepoBranch string `form:"dagrepobranch" binding:"validRepoBranch"`
	AirflowImage  string `form:"airflowimage" binding:"validAirflowImage"`
	ApiAccess     string `form:"apiaccess"`
}
    57  
    58  func getChartType(chartType string) gensql.ChartType {
    59  	switch chartType {
    60  	case string(gensql.ChartTypeJupyterhub):
    61  		return gensql.ChartTypeJupyterhub
    62  	case string(gensql.ChartTypeAirflow):
    63  		return gensql.ChartTypeAirflow
    64  	default:
    65  		return ""
    66  	}
    67  }
    68  
    69  func descriptiveMessageForChartError(fieldError validator.FieldError) string {
    70  	switch fieldError.Tag() {
    71  	case "required":
    72  		return fmt.Sprintf("%v er et påkrevd felt", fieldError.Field())
    73  	case "startswith":
    74  		return fmt.Sprintf("%v må starte med 'navikt/'", fieldError.Field())
    75  	default:
    76  		return fieldError.Error()
    77  	}
    78  }
    79  
    80  func (c *client) setupChartRoutes() {
    81  	if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
    82  		err := v.RegisterValidation("validAirflowRepo", chart.ValidateAirflowRepo)
    83  		if err != nil {
    84  			c.log.WithError(err).Error("can't register validator")
    85  			return
    86  		}
    87  	}
    88  
    89  	if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
    90  		err := v.RegisterValidation("validRepoBranch", chart.ValidateRepoBranch)
    91  		if err != nil {
    92  			c.log.WithError(err).Error("can't register validator")
    93  			return
    94  		}
    95  	}
    96  
    97  	if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
    98  		err := v.RegisterValidation("validAirflowImage", chart.ValidateAirflowImage)
    99  		if err != nil {
   100  			c.log.WithError(err).Error("can't register validator")
   101  			return
   102  		}
   103  	}
   104  
   105  	if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
   106  		err := v.RegisterValidation("validCPUSpec", chart.ValidateCPUSpec)
   107  		if err != nil {
   108  			c.log.WithError(err).Error("can't register validator")
   109  			return
   110  		}
   111  	}
   112  
   113  	if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
   114  		err := v.RegisterValidation("validMemorySpec", chart.ValidateMemorySpec)
   115  		if err != nil {
   116  			c.log.WithError(err).Error("can't register validator")
   117  			return
   118  		}
   119  	}
   120  
   121  	c.router.GET("/team/:slug/:chart/new", func(ctx *gin.Context) {
   122  		slug := ctx.Param("slug")
   123  		chartType := getChartType(ctx.Param("chart"))
   124  
   125  		var form any
   126  		switch chartType {
   127  		case gensql.ChartTypeJupyterhub:
   128  			form = jupyterForm{}
   129  		case gensql.ChartTypeAirflow:
   130  			form = airflowForm{}
   131  		default:
   132  			ctx.JSON(http.StatusBadRequest, map[string]string{
   133  				"status":  strconv.Itoa(http.StatusBadRequest),
   134  				"message": fmt.Sprintf("Chart type %v is not supported", chartType),
   135  			})
   136  			return
   137  		}
   138  
   139  		session := sessions.Default(ctx)
   140  		flashes := session.Flashes()
   141  		err := session.Save()
   142  		if err != nil {
   143  			c.log.WithField("team", slug).WithField("chart", chartType).WithError(err).Error("problem saving session")
   144  			ctx.JSON(http.StatusInternalServerError, map[string]string{
   145  				"status":  strconv.Itoa(http.StatusInternalServerError),
   146  				"message": "Internal server error",
   147  			})
   148  			return
   149  		}
   150  
   151  		ctx.HTML(http.StatusOK, fmt.Sprintf("charts/%v", chartType), gin.H{
   152  			"team":     slug,
   153  			"form":     form,
   154  			"errors":   flashes,
   155  			"loggedIn": ctx.GetBool(middlewares.LoggedInKey),
   156  			"isAdmin":  ctx.GetBool(middlewares.AdminKey),
   157  		})
   158  	})
   159  
   160  	c.router.POST("/team/:slug/:chart/new", func(ctx *gin.Context) {
   161  		slug := ctx.Param("slug")
   162  		chartType := getChartType(ctx.Param("chart"))
   163  		log := c.log.WithField("team", slug).WithField("chart", chartType)
   164  
   165  		err := c.newChart(ctx, slug, chartType)
   166  		if err != nil {
   167  			session := sessions.Default(ctx)
   168  			var validationErrorse validator.ValidationErrors
   169  			if errors.As(err, &validationErrorse) {
   170  				for _, fieldError := range validationErrorse {
   171  					log.WithError(err).Infof("field error: %v", fieldError)
   172  					session.AddFlash(descriptiveMessageForChartError(fieldError))
   173  				}
   174  			} else {
   175  				log.WithError(err).Info("non-field error")
   176  				session.AddFlash(err.Error())
   177  			}
   178  
   179  			err := session.Save()
   180  			if err != nil {
   181  				log.WithError(err).Error("problem saving session")
   182  				ctx.Redirect(http.StatusSeeOther, fmt.Sprintf("/team/%v/%v/new", slug, chartType))
   183  				return
   184  			}
   185  
   186  			ctx.Redirect(http.StatusSeeOther, fmt.Sprintf("/team/%v/%v/new", slug, chartType))
   187  			return
   188  		}
   189  
   190  		ctx.Redirect(http.StatusSeeOther, "/oversikt")
   191  	})
   192  
   193  	c.router.GET("/team/:slug/:chart/edit", func(ctx *gin.Context) {
   194  		teamSlug := ctx.Param("slug")
   195  		chartType := getChartType(ctx.Param("chart"))
   196  		log := c.log.WithField("team", teamSlug).WithField("chart", chartType)
   197  
   198  		session := sessions.Default(ctx)
   199  
   200  		form, err := c.getEditChart(ctx, teamSlug, chartType)
   201  		if err != nil {
   202  			var validationErrorse validator.ValidationErrors
   203  			if errors.As(err, &validationErrorse) {
   204  				for _, fieldError := range validationErrorse {
   205  					log.WithError(err).Infof("field error: %v", fieldError)
   206  					session.AddFlash(descriptiveMessageForChartError(fieldError))
   207  				}
   208  			} else {
   209  				log.WithError(err).Info("non-field error")
   210  				session.AddFlash(err.Error())
   211  			}
   212  
   213  			err := session.Save()
   214  			if err != nil {
   215  				log.WithError(err).Error("problem saving session")
   216  				ctx.Redirect(http.StatusSeeOther, "/oversikt")
   217  				return
   218  			}
   219  			ctx.Redirect(http.StatusSeeOther, "/oversikt")
   220  			return
   221  		}
   222  
   223  		flashes := session.Flashes()
   224  		err = session.Save()
   225  		if err != nil {
   226  			log.WithError(err).Error("problem saving session")
   227  			return
   228  		}
   229  
   230  		ctx.HTML(http.StatusOK, fmt.Sprintf("charts/%v", chartType), gin.H{
   231  			"team":     teamSlug,
   232  			"values":   form,
   233  			"errors":   flashes,
   234  			"loggedIn": ctx.GetBool(middlewares.LoggedInKey),
   235  			"isAdmin":  ctx.GetBool(middlewares.AdminKey),
   236  		})
   237  	})
   238  
   239  	c.router.POST("/team/:slug/:chart/edit", func(ctx *gin.Context) {
   240  		teamSlug := ctx.Param("slug")
   241  		chartType := getChartType(ctx.Param("chart"))
   242  		log := c.log.WithField("team", teamSlug).WithField("chart", chartType)
   243  
   244  		err := c.editChart(ctx, teamSlug, chartType)
   245  		if err != nil {
   246  			session := sessions.Default(ctx)
   247  			var validationErrorse validator.ValidationErrors
   248  			if errors.As(err, &validationErrorse) {
   249  				for _, fieldError := range validationErrorse {
   250  					log.WithError(err).Infof("field error: %v", fieldError)
   251  					session.AddFlash(descriptiveMessageForChartError(fieldError))
   252  				}
   253  			} else {
   254  				log.WithError(err).Info("non-field error")
   255  				session.AddFlash(err.Error())
   256  			}
   257  
   258  			err := session.Save()
   259  			if err != nil {
   260  				log.WithError(err).Error("problem saving session")
   261  				ctx.Redirect(http.StatusSeeOther, fmt.Sprintf("/team/%v/%v/edit", teamSlug, chartType))
   262  				return
   263  			}
   264  
   265  			ctx.Redirect(http.StatusSeeOther, fmt.Sprintf("/team/%v/%v/edit", teamSlug, chartType))
   266  			return
   267  		}
   268  
   269  		ctx.Redirect(http.StatusSeeOther, "/oversikt")
   270  	})
   271  
   272  	c.router.POST("/team/:slug/:chart/delete", func(ctx *gin.Context) {
   273  		teamSlug := ctx.Param("slug")
   274  		chartTypeString := ctx.Param("chart")
   275  		log := c.log.WithField("team", teamSlug).WithField("chart", chartTypeString)
   276  
   277  		err := c.deleteChart(ctx, teamSlug, chartTypeString)
   278  		if err != nil {
   279  			log.WithError(err).Errorf("problem deleting chart %v for team %v", chartTypeString, teamSlug)
   280  			session := sessions.Default(ctx)
   281  			session.AddFlash(err.Error())
   282  			err := session.Save()
   283  			if err != nil {
   284  				log.WithError(err).Error("problem saving session")
   285  			}
   286  		}
   287  
   288  		ctx.Redirect(http.StatusSeeOther, "/oversikt")
   289  	})
   290  }
   291  
   292  func (c *client) getExistingAllowlist(ctx context.Context, teamID string) ([]string, error) {
   293  	extraAnnotations, err := c.repo.TeamValueGet(ctx, "singleuser.extraAnnotations", teamID)
   294  	if err != nil {
   295  		if errors.Is(err, sql.ErrNoRows) {
   296  			return []string{}, nil
   297  		}
   298  		return nil, err
   299  	}
   300  
   301  	var annotations map[string]string
   302  	if err := json.Unmarshal([]byte(extraAnnotations.Value), &annotations); err != nil {
   303  		return nil, err
   304  	}
   305  
   306  	for k, v := range annotations {
   307  		if k == "allowlist" {
   308  			return strings.Split(v, ","), nil
   309  		}
   310  	}
   311  
   312  	return []string{}, nil
   313  }
   314  
   315  func (c *client) newChart(ctx *gin.Context, teamSlug string, chartType gensql.ChartType) error {
   316  	team, err := c.repo.TeamBySlugGet(ctx, teamSlug)
   317  	if err != nil {
   318  		return err
   319  	}
   320  
   321  	switch chartType {
   322  	case gensql.ChartTypeJupyterhub:
   323  		var form jupyterForm
   324  		err := ctx.ShouldBindWith(&form, binding.Form)
   325  		if err != nil {
   326  			return err
   327  		}
   328  
   329  		cullTimeout, err := strconv.ParseUint(form.CullTimeout, 10, 64)
   330  		if err != nil {
   331  			return err
   332  		}
   333  
   334  		userIdents, err := c.azureClient.ConvertEmailsToIdents(team.Users)
   335  		if err != nil {
   336  			return err
   337  		}
   338  
   339  		cpuLimit, err := parseCPU(form.CPULimit)
   340  		if err != nil {
   341  			return err
   342  		}
   343  
   344  		cpuRequest, err := parseCPU(form.CPURequest)
   345  		if err != nil {
   346  			return err
   347  		}
   348  
   349  		memoryLimit, err := parseMemory(form.MemoryLimit)
   350  		if err != nil {
   351  			return err
   352  		}
   353  
   354  		memoryRequest, err := parseMemory(form.MemoryRequest)
   355  		if err != nil {
   356  			return err
   357  		}
   358  
   359  		values := chart.JupyterConfigurableValues{
   360  			TeamID:        team.ID,
   361  			UserIdents:    userIdents,
   362  			CPULimit:      cpuLimit,
   363  			CPURequest:    cpuRequest,
   364  			MemoryLimit:   memoryLimit,
   365  			MemoryRequest: memoryRequest,
   366  			ImageName:     form.ImageName,
   367  			ImageTag:      form.ImageTag,
   368  			CullTimeout:   strconv.FormatUint(cullTimeout, 10),
   369  			AllowList:     removeEmptySliceElements(form.Allowlist),
   370  			PYPIAccess:    form.PYPIAccess == "on",
   371  		}
   372  
   373  		return c.repo.RegisterCreateJupyterEvent(ctx, team.ID, values)
   374  	case gensql.ChartTypeAirflow:
   375  		var form airflowForm
   376  		err := ctx.ShouldBindWith(&form, binding.Form)
   377  		if err != nil {
   378  			return err
   379  		}
   380  
   381  		dagRepoBranch := form.DagRepoBranch
   382  		if dagRepoBranch == "" {
   383  			dagRepoBranch = "main"
   384  		}
   385  
   386  		airflowImage := ""
   387  		airflowTag := ""
   388  		if form.AirflowImage != "" {
   389  			imageParts := strings.Split(form.AirflowImage, ":")
   390  			airflowImage = imageParts[0]
   391  			airflowTag = imageParts[1]
   392  		}
   393  
   394  		values := chart.AirflowConfigurableValues{
   395  			TeamID:        team.ID,
   396  			DagRepo:       form.DagRepo,
   397  			DagRepoBranch: dagRepoBranch,
   398  			ApiAccess:     form.ApiAccess == "on",
   399  			AirflowImage:  airflowImage,
   400  			AirflowTag:    airflowTag,
   401  		}
   402  
   403  		return c.repo.RegisterCreateAirflowEvent(ctx, team.ID, values)
   404  	}
   405  
   406  	return fmt.Errorf("chart type %v is not supported", chartType)
   407  }
   408  
   409  func (c *client) getEditChart(ctx *gin.Context, teamSlug string, chartType gensql.ChartType) (any, error) {
   410  	team, err := c.repo.TeamBySlugGet(ctx, teamSlug)
   411  	if err != nil {
   412  		return nil, err
   413  	}
   414  
   415  	var chartObjects any
   416  	switch chartType {
   417  	case gensql.ChartTypeJupyterhub:
   418  		chartObjects = &chart.JupyterConfigurableValues{}
   419  	case gensql.ChartTypeAirflow:
   420  		chartObjects = &chart.AirflowConfigurableValues{}
   421  	default:
   422  		return nil, fmt.Errorf("chart type %v is not supported", chartType)
   423  	}
   424  
   425  	err = c.repo.TeamConfigurableValuesGet(ctx, chartType, team.ID, chartObjects)
   426  	if err != nil {
   427  		return nil, err
   428  	}
   429  
   430  	var form any
   431  	switch chartType {
   432  	case gensql.ChartTypeJupyterhub:
   433  		jupyterhubValues := chartObjects.(*chart.JupyterConfigurableValues)
   434  		allowlist, err := c.getExistingAllowlist(ctx, team.ID)
   435  		if err != nil && !errors.Is(err, sql.ErrNoRows) {
   436  			return nil, err
   437  		}
   438  
   439  		pypiAccessTeamValue, err := c.repo.TeamValueGet(ctx, chart.TeamValueKeyPYPIAccess, team.ID)
   440  		if err != nil && !errors.Is(err, sql.ErrNoRows) {
   441  			return nil, err
   442  		}
   443  
   444  		pypiAccess := "off"
   445  		if pypiAccessTeamValue.Value == "true" {
   446  			pypiAccess = "on"
   447  		}
   448  
   449  		form = jupyterForm{
   450  			CPULimit:      jupyterhubValues.CPULimit,
   451  			CPURequest:    jupyterhubValues.CPURequest,
   452  			MemoryLimit:   jupyterhubValues.MemoryLimit,
   453  			MemoryRequest: jupyterhubValues.MemoryRequest,
   454  			ImageName:     jupyterhubValues.ImageName,
   455  			ImageTag:      jupyterhubValues.ImageTag,
   456  			CullTimeout:   jupyterhubValues.CullTimeout,
   457  			PYPIAccess:    pypiAccess,
   458  			Allowlist:     allowlist,
   459  		}
   460  	case gensql.ChartTypeAirflow:
   461  		airflowValues := chartObjects.(*chart.AirflowConfigurableValues)
   462  		apiAccessTeamValue, err := c.repo.TeamValueGet(ctx, chart.TeamValueKeyApiAccess, team.ID)
   463  		if err != nil && !errors.Is(err, sql.ErrNoRows) {
   464  			return nil, err
   465  		}
   466  
   467  		apiAccess := ""
   468  		if apiAccessTeamValue.Value == "true" {
   469  			apiAccess = "on"
   470  		}
   471  
   472  		airflowImage := ""
   473  		if airflowValues.AirflowImage != "" && airflowValues.AirflowTag != "" {
   474  			airflowImage = fmt.Sprintf("%v:%v", airflowValues.AirflowImage, airflowValues.AirflowTag)
   475  		}
   476  
   477  		form = airflowForm{
   478  			DagRepo:       airflowValues.DagRepo,
   479  			DagRepoBranch: airflowValues.DagRepoBranch,
   480  			ApiAccess:     apiAccess,
   481  			AirflowImage:  airflowImage,
   482  		}
   483  	}
   484  
   485  	return form, nil
   486  }
   487  
   488  func (c *client) editChart(ctx *gin.Context, teamSlug string, chartType gensql.ChartType) error {
   489  	team, err := c.repo.TeamBySlugGet(ctx, teamSlug)
   490  	if err != nil {
   491  		return err
   492  	}
   493  
   494  	switch chartType {
   495  	case gensql.ChartTypeJupyterhub:
   496  		var form jupyterForm
   497  		err := ctx.ShouldBindWith(&form, binding.Form)
   498  		if err != nil {
   499  			return err
   500  		}
   501  
   502  		userIdents, err := c.azureClient.ConvertEmailsToIdents(team.Users)
   503  		if err != nil {
   504  			return err
   505  		}
   506  
   507  		cpuLimit, err := parseCPU(form.CPULimit)
   508  		if err != nil {
   509  			return err
   510  		}
   511  
   512  		cpuRequest, err := parseCPU(form.CPURequest)
   513  		if err != nil {
   514  			return err
   515  		}
   516  
   517  		memoryLimit, err := parseMemory(form.MemoryLimit)
   518  		if err != nil {
   519  			return err
   520  		}
   521  
   522  		memoryRequest, err := parseMemory(form.MemoryRequest)
   523  		if err != nil {
   524  			return err
   525  		}
   526  
   527  		values := chart.JupyterConfigurableValues{
   528  			TeamID:        team.ID,
   529  			UserIdents:    userIdents,
   530  			CPULimit:      cpuLimit,
   531  			CPURequest:    cpuRequest,
   532  			MemoryLimit:   memoryLimit,
   533  			MemoryRequest: memoryRequest,
   534  			ImageName:     form.ImageName,
   535  			ImageTag:      form.ImageTag,
   536  			CullTimeout:   form.CullTimeout,
   537  			PYPIAccess:    form.PYPIAccess == "on",
   538  			AllowList:     removeEmptySliceElements(form.Allowlist),
   539  		}
   540  
   541  		return c.repo.RegisterUpdateJupyterEvent(ctx, team.ID, values)
   542  	case gensql.ChartTypeAirflow:
   543  		var form airflowForm
   544  		err := ctx.ShouldBindWith(&form, binding.Form)
   545  		if err != nil {
   546  			return err
   547  		}
   548  
   549  		dagRepoBranch := form.DagRepoBranch
   550  		if dagRepoBranch == "" {
   551  			dagRepoBranch = "main"
   552  		}
   553  
   554  		airflowImage := ""
   555  		airflowTag := ""
   556  		if form.AirflowImage != "" {
   557  			imageParts := strings.Split(form.AirflowImage, ":")
   558  			airflowImage = imageParts[0]
   559  			airflowTag = imageParts[1]
   560  		}
   561  
   562  		values := chart.AirflowConfigurableValues{
   563  			TeamID:        team.ID,
   564  			DagRepo:       form.DagRepo,
   565  			DagRepoBranch: dagRepoBranch,
   566  			ApiAccess:     form.ApiAccess == "on",
   567  			AirflowImage:  airflowImage,
   568  			AirflowTag:    airflowTag,
   569  		}
   570  
   571  		return c.repo.RegisterUpdateAirflowEvent(ctx, team.ID, values)
   572  	}
   573  
   574  	return fmt.Errorf("chart type %v is not supported", chartType)
   575  }
   576  
   577  func (c *client) deleteChart(ctx *gin.Context, teamSlug, chartTypeString string) error {
   578  	team, err := c.repo.TeamBySlugGet(ctx, teamSlug)
   579  	if err != nil {
   580  		return err
   581  	}
   582  
   583  	switch getChartType(chartTypeString) {
   584  	case gensql.ChartTypeJupyterhub:
   585  		return c.repo.RegisterDeleteJupyterEvent(ctx, team.ID)
   586  	case gensql.ChartTypeAirflow:
   587  		return c.repo.RegisterDeleteAirflowEvent(ctx, team.ID)
   588  	}
   589  
   590  	return fmt.Errorf("chart type %v is not supported", chartTypeString)
   591  }
   592  
   593  func parseCPU(cpu string) (string, error) {
   594  	floatVal, err := strconv.ParseFloat(cpu, 64)
   595  	if err != nil {
   596  		return "", err
   597  	}
   598  
   599  	return fmt.Sprintf("%.1f", floatVal), nil
   600  }
   601  
   602  func parseMemory(memory string) (string, error) {
   603  	if strings.HasSuffix(memory, "G") {
   604  		return memory, nil
   605  	}
   606  	_, err := strconv.ParseFloat(memory, 64)
   607  	if err != nil {
   608  		return "", err
   609  	}
   610  	return memory + "G", nil
   611  }