go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/projects/nodes/pkg/dbmodel/manager.go (about)

     1  /*
     2  
     3  Copyright (c) 2023 - Present. Will Charczuk. All rights reserved.
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file at the root of the repository.
     5  
     6  */
     7  
     8  package dbmodel
     9  
import (
	"context"
	"fmt"
	"sort"

	"github.com/wcharczuk/go-incr"
	"go.charczuk.com/projects/nodes/pkg/types"
	"go.charczuk.com/sdk/collections"
	"go.charczuk.com/sdk/db"
	"go.charczuk.com/sdk/db/dbutil"
	"go.charczuk.com/sdk/iter"
	"go.charczuk.com/sdk/uuid"
)
    22  
    23  type Manager struct {
    24  	dbutil.BaseManager
    25  }
    26  
    27  var graphMeta = db.TypeMetaFor(Graph{})
    28  var graphsForUserQuery = fmt.Sprintf(`SELECT %s FROM %s WHERE user_id = $1`, db.ColumnNamesCSV(graphMeta.Columns()), Graph{}.TableName())
    29  var nodesMeta = db.TypeMetaFor(Node{})
    30  var nodesQuery = fmt.Sprintf(`SELECT %s FROM %s WHERE graph_id = $1`, db.ColumnNamesCSV(nodesMeta.Columns()), Node{}.TableName())
    31  var nodeValuesMeta = db.TypeMetaFor(NodeValue{})
    32  var nodeValuesQuery = fmt.Sprintf(`SELECT %s FROM %s WHERE graph_id = $1`, db.ColumnNamesCSV(nodeValuesMeta.Columns()), NodeValue{}.TableName())
    33  var nodeValuesManyQuery = fmt.Sprintf(`SELECT %s FROM %s WHERE graph_id = $1 AND node_id = ANY($2)`, db.ColumnNamesCSV(nodeValuesMeta.Columns()), NodeValue{}.TableName())
    34  var edgesMeta = db.TypeMetaFor(Edge{})
    35  var edgesQuery = fmt.Sprintf(`SELECT %s FROM %s WHERE graph_id = $1`, db.ColumnNamesCSV(edgesMeta.Columns()), Edge{}.TableName())
    36  var graphRecomputeHeapMeta = db.TypeMetaFor(GraphRecomputeHeap{})
    37  var graphRecomputeHeapQuery = fmt.Sprintf(`SELECT %s FROM %s WHERE graph_id = $1`, db.ColumnNamesCSV(graphRecomputeHeapMeta.Columns()), GraphRecomputeHeap{}.TableName())
    38  var setViewportExec = fmt.Sprintf(`UPDATE %s SET viewport_x = $1, viewport_y = $2, viewport_zoom = $3, updated_utc = current_timestamp WHERE id = $4`, Graph{}.TableName())
    39  var graphLogsMeta = db.TypeMetaFor(GraphLogs{})
    40  var graphLogsLatestQuery = fmt.Sprintf(`SELECT %s FROM %s WHERE graph_id = $1 ORDER BY stabilization_num DESC LIMIT 1`, db.ColumnNamesCSV(graphLogsMeta.Columns()), GraphLogs{}.TableName())
    41  var setNodeSetAtExec = fmt.Sprintf(`UPDATE %s SET set_at = (SELECT stabilization_num FROM graph WHERE id = node.graph_id LIMIT 1) WHERE id = $1`, Node{}.TableName())
    42  var upsertRecomputeHeapExec = fmt.Sprintf(`INSERT INTO %s (graph_id, user_id, node_id) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING`, GraphRecomputeHeap{}.TableName())
    43  var clearRecomputeHeapExec = fmt.Sprintf(`DELETE FROM %s WHERE graph_id = $1`, GraphRecomputeHeap{}.TableName())
    44  var insertRecomputeHeapExec = fmt.Sprintf(`INSERT INTO %s (graph_id, user_id, node_id) SELECT $1, $2, unnest FROM unnest($3::uuid[])`, GraphRecomputeHeap{}.TableName())
    45  var touchGraphExec = fmt.Sprintf(`UPDATE %s SET updated_utc = current_timestamp WHERE id = $1`, Graph{}.TableName())
    46  var graphActiveForUserQuery = fmt.Sprintf(`SELECT %s FROM %s WHERE user_id = $1 ORDER BY updated_utc DESC LIMIT 1`, db.ColumnNamesCSV(graphMeta.Columns()), Graph{}.TableName())
    47  
    48  // neverPatch are database column names (and by chance, json field names, but we care about column names) that users
    49  // can _never_ mutate with patch endpoints for both graphs and nodes.
    50  var neverPatch = collections.NewSet[string]([]string{"id", "graph_id", "node_id", "user_id", "created_utc"})
    51  
    52  func (m Manager) Deserialize(ctx context.Context, graphID uuid.UUID, skipValues, skipRecomputeHeap bool) (output *types.GraphFull, err error) {
    53  	var g Graph
    54  	var ok bool
    55  	if ok, err = m.Invoke(ctx).Get(&g, graphID); err != nil {
    56  		return
    57  	} else if !ok {
    58  		return
    59  	}
    60  	output = new(types.GraphFull)
    61  	output.Graph = TypeGraphFromGraph(g)
    62  
    63  	var nodes []Node
    64  	nodes, err = m.Nodes(ctx, g.ID)
    65  	if err != nil {
    66  		return
    67  	}
    68  	for _, n := range nodes {
    69  		output.Nodes = append(output.Nodes, TypeNodeFromNode(n))
    70  	}
    71  
    72  	var edges []Edge
    73  	edges, err = m.Edges(ctx, g.ID)
    74  	if err != nil {
    75  		return
    76  	}
    77  	for _, e := range edges {
    78  		output.Edges = append(output.Edges, types.Edge{
    79  			ParentID:       incr.Identifier(e.ParentID),
    80  			ChildID:        incr.Identifier(e.ChildID),
    81  			ChildInputName: e.ChildInputName,
    82  		})
    83  	}
    84  
    85  	if !skipValues {
    86  		var values []NodeValue
    87  		values, err = m.NodeValues(ctx, g.ID)
    88  		if err != nil {
    89  			return
    90  		}
    91  		for _, v := range values {
    92  			var parsedValue any
    93  			parsedValue, err = v.ParsedValue()
    94  			if err != nil {
    95  				return
    96  			}
    97  			output.Values = append(output.Values, types.NodeValue{
    98  				ID:        incr.Identifier(v.NodeID),
    99  				ValueType: v.ValueType,
   100  				Value:     parsedValue,
   101  			})
   102  		}
   103  	}
   104  	if !skipRecomputeHeap {
   105  		var recomputeHeap []GraphRecomputeHeap
   106  		recomputeHeap, err = m.RecomputeHeap(ctx, graphID)
   107  		if err != nil {
   108  			return
   109  		}
   110  		output.RecomputeHeap = iter.Apply(recomputeHeap, func(n GraphRecomputeHeap) incr.Identifier { return incr.Identifier(n.NodeID) })
   111  	}
   112  	return
   113  }
   114  
   115  func (m Manager) GraphsForUser(ctx context.Context, userID uuid.UUID) (output []Graph, err error) {
   116  	err = m.Invoke(ctx, db.OptLabel("graphs_for_user")).Query(graphsForUserQuery, userID).OutMany(&output)
   117  	return
   118  }
   119  
   120  func (m Manager) GraphActiveForUser(ctx context.Context, userID uuid.UUID) (output Graph, found bool, err error) {
   121  	found, err = m.Invoke(ctx, db.OptLabel("graph_active")).Query(graphActiveForUserQuery, userID).Out(&output)
   122  	return
   123  }
   124  
   125  func (m Manager) PatchGraph(ctx context.Context, graphID uuid.UUID, ps types.PatchSet) (found bool, err error) {
   126  	var params []any
   127  	var statement = fmt.Sprintf("UPDATE %s SET updated_utc = current_timestamp", Graph{}.TableName())
   128  	for key, value := range ps {
   129  		if !graphMeta.HasColumn(key) {
   130  			continue
   131  		}
   132  		if neverPatch.Has(key) {
   133  			continue
   134  		}
   135  		params = append(params, value)
   136  		statement = statement + fmt.Sprintf("\n, %s = $%d", key, len(params))
   137  	}
   138  	if len(params) == 0 {
   139  		return false, nil
   140  	}
   141  	params = append(params, graphID)
   142  	statement = statement + fmt.Sprintf("\nWHERE id = $%d", len(params))
   143  	return db.ExecAffectedAny(m.Invoke(ctx, db.OptLabel("patch_graph")).Exec(statement, params...))
   144  }
   145  
   146  func (m Manager) Nodes(ctx context.Context, graphID uuid.UUID) (output []Node, err error) {
   147  	err = m.Invoke(ctx, db.OptLabel("nodes")).Query(nodesQuery, graphID).OutMany(&output)
   148  	return
   149  }
   150  
   151  func (m Manager) NodeValues(ctx context.Context, graphID uuid.UUID, nodeIDs ...uuid.UUID) (output []NodeValue, err error) {
   152  	if len(nodeIDs) > 0 {
   153  		err = m.Invoke(ctx, db.OptLabel("graph_node_values_by_id")).Query(nodeValuesManyQuery, graphID, nodeIDs).OutMany(&output)
   154  	} else {
   155  		err = m.Invoke(ctx, db.OptLabel("graph_node_values_all")).Query(nodeValuesQuery, graphID).OutMany(&output)
   156  	}
   157  	return
   158  }
   159  
   160  func (m Manager) Edges(ctx context.Context, graphID uuid.UUID) (output []Edge, err error) {
   161  	err = m.Invoke(ctx, db.OptLabel("graph_edges")).Query(edgesQuery, graphID).OutMany(&output)
   162  	return
   163  }
   164  
   165  func (m Manager) RecomputeHeap(ctx context.Context, graphID uuid.UUID) (output []GraphRecomputeHeap, err error) {
   166  	err = m.Invoke(ctx, db.OptLabel("graph_recompute_heap")).Query(graphRecomputeHeapQuery, graphID).OutMany(&output)
   167  	return
   168  }
   169  
   170  func (m Manager) Graph(ctx context.Context, id uuid.UUID) (output Graph, found bool, err error) {
   171  	found, err = m.Invoke(ctx).Get(&output, id)
   172  	return
   173  }
   174  
   175  func (m Manager) Node(ctx context.Context, id uuid.UUID) (output Node, found bool, err error) {
   176  	found, err = m.Invoke(ctx).Get(&output, id)
   177  	return
   178  }
   179  
   180  func (m Manager) NodeValue(ctx context.Context, graphID, nodeID uuid.UUID) (output NodeValue, found bool, err error) {
   181  	found, err = m.Invoke(ctx).Get(&output, graphID, nodeID)
   182  	return
   183  }
   184  
   185  func (m Manager) SetNodeValues(ctx context.Context, graphID, userID uuid.UUID, values map[uuid.UUID]any) error {
   186  	if len(values) == 0 {
   187  		return nil
   188  	}
   189  	statement := `INSERT INTO node_value (graph_id, user_id, node_id, value_type, value) VALUES `
   190  	params := []any{graphID, userID}
   191  	var first = true
   192  	for key, value := range values {
   193  		if first {
   194  			statement = statement + fmt.Sprintf("($1, $2, $%d, $%d, $%d)", len(params)+1, len(params)+2, len(params)+3)
   195  			first = false
   196  		} else {
   197  			statement = statement + fmt.Sprintf(", ($1, $2, $%d, $%d, $%d)", len(params)+1, len(params)+2, len(params)+3)
   198  			first = false
   199  		}
   200  		params = append(params, key, DetectValueType(value), db.JSON(value))
   201  	}
   202  	statement = statement + ` ON CONFLICT (graph_id, node_id) DO UPDATE SET value_type = excluded.value_type, value = excluded.value`
   203  	return db.ExecErr(m.Invoke(ctx, db.OptLabel("set_node_values")).Exec(statement, params...))
   204  }
   205  
   206  func (m Manager) DeleteGraph(ctx context.Context, graphID uuid.UUID) (found bool, err error) {
   207  	if _, err = m.Invoke(ctx, db.OptLabel("delete_graph_edges")).Exec(`DELETE FROM edge WHERE graph_id = $1`, graphID); err != nil {
   208  		return
   209  	}
   210  	if _, err = m.Invoke(ctx, db.OptLabel("delete_graph_node_values")).Exec(`DELETE FROM node_value WHERE graph_id = $1`, graphID); err != nil {
   211  		return
   212  	}
   213  	if _, err = m.Invoke(ctx, db.OptLabel("delete_graph_recompute_heap")).Exec(`DELETE FROM graph_recompute_heap WHERE graph_id = $1`, graphID); err != nil {
   214  		return
   215  	}
   216  	if _, err = m.Invoke(ctx, db.OptLabel("delete_graph_nodes")).Exec(`DELETE FROM node WHERE graph_id = $1`, graphID); err != nil {
   217  		return
   218  	}
   219  	if _, err = m.Invoke(ctx, db.OptLabel("delete_graph_logs")).Exec(`DELETE FROM graph_log WHERE graph_id = $1`, graphID); err != nil {
   220  		return
   221  	}
   222  	return db.ExecAffectedAny(m.Invoke(ctx, db.OptLabel("delete_graph_graph")).Exec(`DELETE FROM graph WHERE id = $1`, graphID))
   223  }
   224  
   225  func (m Manager) DeleteNode(ctx context.Context, nodeID uuid.UUID) (found bool, err error) {
   226  	if _, err = m.Invoke(ctx, db.OptLabel("delete_node_edges")).Exec(`DELETE FROM edge WHERE child_id = $1 OR parent_id = $1`, nodeID); err != nil {
   227  		return
   228  	}
   229  	if _, err = m.Invoke(ctx, db.OptLabel("delete_node_recompute_heap")).Exec(`DELETE FROM graph_recompute_heap WHERE node_id = $1`, nodeID); err != nil {
   230  		return
   231  	}
   232  	if _, err = m.Invoke(ctx, db.OptLabel("delete_node_values")).Exec(`DELETE FROM node_value WHERE node_id = $1 `, nodeID); err != nil {
   233  		return
   234  	}
   235  	return db.ExecAffectedAny(m.Invoke(ctx, db.OptLabel("delete_node_node")).Exec(`DELETE FROM node WHERE id = $1`, nodeID))
   236  }
   237  
   238  func (m Manager) DeleteEdge(ctx context.Context, parentID, childID uuid.UUID, childInputName string) (err error) {
   239  	if _, err = m.Invoke(ctx, db.OptLabel("delete_edge")).Exec(`DELETE FROM edge WHERE parent_id = $1 AND child_id = $2 AND child_input_name = $3`, parentID, childID, childInputName); err != nil {
   240  		return
   241  	}
   242  	return
   243  }
   244  
   245  func (m Manager) SetViewport(ctx context.Context, graphID uuid.UUID, viewport types.Viewport) (err error) {
   246  	_, err = m.Invoke(ctx, db.OptLabel("set_viewport")).Exec(setViewportExec, viewport.X, viewport.Y, viewport.Zoom, graphID)
   247  	return
   248  }
   249  
   250  // TouchGraph sets the updated time for the graph in the database.
   251  //
   252  // The updated time is used to order the graph in lists, so this will push
   253  // the graph to be the most recently updated graph in those lists.
   254  func (m Manager) TouchGraph(ctx context.Context, graphID uuid.UUID) (err error) {
   255  	_, err = m.Invoke(ctx, db.OptLabel("touch_graph")).Exec(touchGraphExec, graphID)
   256  	return
   257  }
   258  
   259  func (m Manager) GraphLogsLatest(ctx context.Context, graphID uuid.UUID) (output GraphLogs, found bool, err error) {
   260  	found, err = m.Invoke(ctx, db.OptLabel("graph_logs_latest")).Query(graphLogsLatestQuery, graphID).Out(&output)
   261  	return
   262  }
   263  
   264  func (m Manager) PatchNodes(ctx context.Context, graphID uuid.UUID, ps types.PatchSet) error {
   265  	var params []any
   266  	var statement = fmt.Sprintf("UPDATE %s SET", Node{}.TableName())
   267  	var first = true
   268  	for key, value := range ps {
   269  		if !nodesMeta.HasColumn(key) {
   270  			continue
   271  		}
   272  		if neverPatch.Has(key) {
   273  			continue
   274  		}
   275  		params = append(params, value)
   276  		if first {
   277  			statement = statement + fmt.Sprintf("\n%s = $%d", key, len(params))
   278  			first = false
   279  		} else {
   280  			statement = statement + fmt.Sprintf("\n, %s = $%d", key, len(params))
   281  		}
   282  	}
   283  	if len(params) == 0 {
   284  		return nil
   285  	}
   286  	params = append(params, graphID)
   287  	statement = statement + fmt.Sprintf("\nWHERE graph_id = $%d", len(params))
   288  	return db.ExecErr(m.Invoke(ctx, db.OptLabel("patch_nodes")).Exec(statement, params...))
   289  }
   290  
   291  func (m Manager) PatchNode(ctx context.Context, nodeID uuid.UUID, ps types.PatchSet) (bool, error) {
   292  	var params []any
   293  	var statement = fmt.Sprintf("UPDATE %s SET", Node{}.TableName())
   294  	var first = true
   295  	for key, value := range ps {
   296  		if !nodesMeta.HasColumn(key) {
   297  			continue
   298  		}
   299  		if neverPatch.Has(key) {
   300  			continue
   301  		}
   302  		params = append(params, value)
   303  		if first {
   304  			statement = statement + fmt.Sprintf("\n%s = $%d", key, len(params))
   305  			first = false
   306  		} else {
   307  			statement = statement + fmt.Sprintf("\n, %s = $%d", key, len(params))
   308  		}
   309  	}
   310  	if len(params) == 0 {
   311  		return false, nil
   312  	}
   313  	params = append(params, nodeID)
   314  	statement = statement + fmt.Sprintf("\nWHERE id = $%d", len(params))
   315  	return db.ExecAffectedAny(m.Invoke(ctx, db.OptLabel("patch_node")).Exec(statement, params...))
   316  }
   317  
   318  func (m Manager) MarkNodeStale(ctx context.Context, graphID, userID, nodeID uuid.UUID) (found bool, err error) {
   319  	found, err = db.ExecAffectedAny(m.Invoke(ctx, db.OptLabel("mark_node_stale")).Exec(setNodeSetAtExec, nodeID))
   320  	if err != nil {
   321  		return
   322  	}
   323  	if !found {
   324  		return
   325  	}
   326  	err = db.ExecErr(m.Invoke(ctx, db.OptLabel("upsert_recompute_heap")).Exec(upsertRecomputeHeapExec, graphID, userID, nodeID))
   327  	return
   328  }
   329  
   330  func (m Manager) UpdateGraphPostStabilization(ctx context.Context, graphID uuid.UUID, stabilizationNum uint64) error {
   331  	return db.ExecErr(m.Invoke(ctx, db.OptLabel("update_graph_post_stabilization")).Exec(`UPDATE graph SET stabilization_num = $2 WHERE id = $1`, graphID, stabilizationNum))
   332  }
   333  
   334  func (m Manager) SetRecomputeHeap(ctx context.Context, graphID, userID uuid.UUID, nodeIDs ...uuid.UUID) (err error) {
   335  	err = db.ExecErr(m.Invoke(ctx, db.OptLabel("clear_recompute_heap")).Exec(clearRecomputeHeapExec, graphID))
   336  	if err != nil {
   337  		return
   338  	}
   339  	if len(nodeIDs) > 0 {
   340  		err = db.ExecErr(m.Invoke(ctx, db.OptLabel("insert_recompute_heap")).Exec(insertRecomputeHeapExec, graphID, userID, nodeIDs))
   341  	}
   342  	return
   343  }
   344  
   345  func (m Manager) SetNodeHeights(ctx context.Context, graphID uuid.UUID, nodeHeights map[uuid.UUID]int) (err error) {
   346  	if len(nodeHeights) == 0 {
   347  		return nil
   348  	}
   349  
   350  	sql := `UPDATE node as n SET height = c.height FROM ( values `
   351  	var params []any
   352  	var first = true
   353  	for nodeID, height := range nodeHeights {
   354  		var comma string
   355  		if first {
   356  			first = false
   357  		} else {
   358  			comma = ","
   359  		}
   360  		sql = sql + fmt.Sprintf("\n%s($%d::uuid, $%d::bigint)", comma, len(params)+1, len(params)+2)
   361  		params = append(params, nodeID, int64(height))
   362  	}
   363  	sql = sql + ") as c (node_id, height) WHERE c.node_id = n.id"
   364  	err = db.ExecErr(m.Invoke(ctx, db.OptLabel("set_node_heights")).Exec(sql, params...))
   365  	return
   366  }
   367  
   368  func (m Manager) SetNodeMetadata(ctx context.Context, graphID uuid.UUID, nodeMetadata map[uuid.UUID]Node) (err error) {
   369  	if len(nodeMetadata) == 0 {
   370  		return nil
   371  	}
   372  
   373  	sql := `UPDATE node as n SET set_at = c.set_at, changed_at = c.changed_at, recomputed_at = c.recomputed_at FROM ( values `
   374  	var params []any
   375  	var first = true
   376  	for nodeID, meta := range nodeMetadata {
   377  		var comma string
   378  		if first {
   379  			first = false
   380  		} else {
   381  			comma = ","
   382  		}
   383  		sql = sql + fmt.Sprintf("\n%s($%d::uuid, $%d::bigint, $%d::bigint, $%d::bigint)", comma, len(params)+1, len(params)+2, len(params)+3, len(params)+4)
   384  		params = append(params, nodeID, meta.SetAt, meta.ChangedAt, meta.RecomputedAt)
   385  	}
   386  	sql = sql + ") as c (node_id, set_at, changed_at, recomputed_at) WHERE c.node_id = n.id"
   387  	err = db.ExecErr(m.Invoke(ctx, db.OptLabel("set_node_metadata")).Exec(sql, params...))
   388  	return
   389  }