go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/projects/nodes/static/_nextjs/src/store/store.ts (about)

     1  /**
     2   * Copyright (c) 2024 - Present. Will Charczuk. All rights reserved.
     3   * Use of this source code is governed by a MIT license that can be found in the LICENSE file at the root of the repository.
     4   */
     5  import { createWithEqualityFn } from 'zustand/traditional';
     6  import {
     7    Connection,
     8    Edge,
     9    EdgeChange,
    10    Node,
    11    NodeChange,
    12    addEdge,
    13    applyNodeChanges,
    14    applyEdgeChanges,
    15    Position,
    16    MarkerType,
    17    getConnectedEdges,
    18    NodeRemoveChange,
    19    XYPosition,
    20  } from 'reactflow';
    21  import * as api from '../api/nodes';
    22  import * as graphApi from '../api/graphs';
    23  import * as workflowApi from '../api/workflow';
    24  import { NodeData } from './nodeData';
    25  import { ApiEffect } from './apiEffect';
    26  import { nodeTypes, valueArrayTypes } from '../refdata/nodeTypes';
    27  
/**
 * StoreState is the shape of the store created by `useStateStore`: the data for the
 * graph being edited plus the actions that read and change it.
 */
export type StoreState = {
  // true while `onRefresh` is fetching; the store also starts out with this set to true.
  isLoading: boolean;
  // the graph currently loaded in the store, or null (`onRefresh` stores null when the fetch yields null).
  graph: graphApi.Graph | null;
  // the graphs returned by `graphApi.getGraphs()`.
  graphs: graphApi.Graph[];
  // the reactflow nodes, one per api node, built by `flowNodeFromApiNode`.
  nodes: Node<NodeData>[];
  // the api nodes selected for the watched list (see `filterWatchedNodes` inside the store).
  watchedNodes: api.Node[];
  // the reactflow edges, built from the api edges.
  edges: Edge[];

  // applies each change through the API and then to the store; resolves to one result per change.
  onApiEffect: (changes: ApiEffect[]) => Promise<any[]>;

  // reactflow change handlers; these only touch the in-memory store.
  onNodesChange: (changes: NodeChange[]) => void;
  onEdgesChange: (changes: EdgeChange[]) => void;
  // links the two nodes of a connection (unlinking an existing edge first for 'single' inputs).
  onConnect: (connection: Connection) => Promise<void>;
  // deletes the given edge through the API and removes it from the store.
  onDisconnect: (edge: Edge) => Promise<void>;

  // runs `change` against a shallow copy of the data of the node with the given id and stores the result.
  applyNodeDataChangeToStore: (id: string, change: (n0d: NodeData) => NodeData) => void;

  // reloads the graph with the given id, along with the graph list, nodes and edges.
  onRefresh: (graphId: string) => Promise<void>;
  // runs a stabilization workflow for the current graph and resolves to its final status.
  onStabilize: () => Promise<workflowApi.WorkflowRunStatus>;
};
    48  
    49  const useStateStore = createWithEqualityFn<StoreState>((set, get) => {
    50    // NOTE(wc): a word on change management!
    51    // react will only propagate a change to a state object if the _reference_ changes, that is
    52    // you have to create an entirely new object with the desired changes for the change to
    53    // propagate through react elements.
    54    //
    55    // e.g.: if you have a state reference `foo`:
    56    //  `foo.bar = "baz"` will not propagate!, you have to do `foo = {...foo, bar: "baz"}`
    57  
    58    // applyNodeDataChangeToStore will apply a change to a given node's data and update the reference for that node's data.
    59    const applyNodeDataChangeToStore = (id: string, change: (n0d: NodeData) => NodeData) => {
    60      set({
    61        nodes: get().nodes.map(n => {
    62          if (n.id === id) {
    63            const n0d = Object.assign({}, n.data)
    64            const n1d = change(n0d)
    65            n.data = n1d
    66          }
    67          return n
    68        })
    69      })
    70    }
    71  
    72    const filterWatchedNodes = (nodes: api.Node[]): api.Node[] => {
    73      const disalowedValueTypes = new Set([...valueArrayTypes, 'table']);
    74      const allowedNodeTypes = new Set(['var', 'observer']);
    75      const flowNodeIsDisallowed = (n: api.Node): boolean => {
    76        const node_type = n.metadata.node_type;
    77        const input_type = n.metadata.input_type || '';
    78        const output_type = n.metadata.output_type || '';
    79        const nodeWatched = n.metadata.watched === undefined ? false : n.metadata.watched;
    80        const nodeTypeIsAllowed = allowedNodeTypes.has(node_type);
    81        const nodeValueTypeIsAllowed = ((node_type === 'observer' && !disalowedValueTypes.has(input_type)) || (node_type === 'var' && !disalowedValueTypes.has(output_type)))
    82        return nodeTypeIsAllowed && nodeValueTypeIsAllowed && nodeWatched;
    83      }
    84      return nodes.filter(flowNodeIsDisallowed).sort((a, b) => a.label.localeCompare(b.label))
    85    }
    86  
    87    const graphId = () => {
    88      return get().graph?.id || '';
    89    }
    90  
    91    // onApiEffect will make changes to both the underlying data store via the API _and_ apply those changes to
    92    // the client memory mapped store. you should use this directly (instead of the reactflow methods) where you can!
    93    const onApiEffect = async (changes: ApiEffect[]) => {
    94      return Promise.all(changes.map(async c => {
    95        switch (c.type) {
    96          case 'add-node': {
    97            const nodeId = await api.postNode(graphId(), c.node);
    98            const node: api.Node = { ...Object.assign({}, c.node), id: nodeId }
    99            const { onApiEffect, graph } = get()
   100            const nodeType = nodeTypes[c.node.metadata.node_type];
   101            if (nodeType.canSetValue) {
   102              if (c.value !== undefined) {
   103                await api.putNodeValue(graphId(), nodeId, c.value)
   104              }
   105            }
   106            set({
   107              nodes: [...get().nodes, flowNodeFromApiNode(node, { onApiEffect, graph })],
   108            })
   109            return nodeId
   110          }
   111          case 'duplicate-node': {
   112            const nodeType = nodeTypes[c.node.metadata.node_type];
   113            let value = null;
   114            if (nodeType.canSetValue) {
   115              value = await api.getNodeValue(graphId(), c.node.id)
   116            }
   117            const [nodeId] = await onApiEffect([{
   118              type: 'add-node',
   119              node: {
   120                ...c.node,
   121                label: c.node.label + '-copy',
   122                metadata: {
   123                  ...c.node.metadata,
   124                  position_x: c.node.metadata.position_x + 350,
   125                  position_y: c.node.metadata.position_y,
   126                }
   127              },
   128              value: value
   129            }])
   130            return nodeId
   131          }
   132          case 'set-stale':
   133            return api.postNodeStale(graphId(), c.node_id)
   134          case 'unlink-nodes':
   135            await api.deleteEdge(graphId(), c)
   136            const edgeId = formatEdgeId(c)
   137            set({
   138              edges: get().edges.filter((e) => e.id !== edgeId)
   139            })
   140            return
   141          case 'link-nodes':
   142            await api.postEdge(graphId(), c)
   143            set({
   144              edges: addEdge(flowEdgeFromApiEdge(c), get().edges),
   145            });
   146            return
   147          case 'remove-node':
   148            await api.deleteNode(graphId(), c.node_id)
   149            const affectedNode = get().nodes.find(fn => fn.id === c.node_id)
   150            if (!affectedNode) {
   151              throw new Error(`remove-node; affected node not found: ${c.node_id}`)
   152            }
   153            const connectedEdges = getConnectedEdges([affectedNode], get().edges);
   154            set({
   155              nodes: get().nodes.filter(n => n.id !== c.node_id),
   156              edges: get().edges.filter(e => !connectedEdges.includes(e)),
   157              watchedNodes: [...get().watchedNodes.filter(n => n.id !== c.node_id)],
   158            })
   159            return
   160          case 'update-label':
   161            await api.patchNode(graphId(), c.node_id, { label: c.label });
   162            applyNodeDataChangeToStore(c.node_id, (n0d) => {
   163              n0d.node.label = c.label
   164              return n0d
   165            })
   166            return
   167          case 'update-expression':
   168            await api.patchNode(graphId(), c.node_id, { expression: c.expression });
   169            applyNodeDataChangeToStore(c.node_id, (n0d) => {
   170              n0d.node.metadata.expression = c.expression
   171              return n0d
   172            })
   173            return
   174          case 'update-value': {
   175            await api.putNodeValue(graphId(), c.node_id, c.value);
   176            applyNodeDataChangeToStore(c.node_id, (n0d) => {
   177              n0d.refreshedAt = new Date()
   178              return n0d
   179            })
   180            return
   181          }
   182          case 'update-position':
   183            await api.patchNode(graphId(), c.node_id, { position_x: c.position.x, position_y: c.position.y })
   184            applyNodeDataChangeToStore(c.node_id, (n0d) => {
   185              n0d.node.metadata.position_x = c.position.x;
   186              n0d.node.metadata.position_y = c.position.y;
   187              return n0d
   188            })
   189            return
   190          case 'update-size':
   191            await api.patchNode(graphId(), c.node_id, { display_height: c.height, display_width: c.width })
   192            applyNodeDataChangeToStore(c.node_id, (n0d) => {
   193              n0d.node.metadata.height = c.height;
   194              n0d.node.metadata.width = c.width;
   195              return n0d
   196            })
   197            return
   198          case 'update-collapsed-all':
   199            await api.patchNodes(graphId(), { collapsed: c.collapsed });
   200            set({
   201              nodes: get().nodes.map(n => {
   202                const n0d = Object.assign({}, n.data)
   203                n0d.node.metadata.collapsed = c.collapsed
   204                n.data = n0d
   205                return n
   206              })
   207            })
   208            return
   209          case 'update-collapsed':
   210            await api.patchNode(graphId(), c.node_id, { collapsed: c.collapsed });
   211            applyNodeDataChangeToStore(c.node_id, (n0d) => {
   212              n0d.node.metadata.collapsed = c.collapsed
   213              return n0d
   214            })
   215            return
   216          case 'update-watched':
   217            await api.patchNode(graphId(), c.node_id, { watched: c.watched });
   218            applyNodeDataChangeToStore(c.node_id, (n0d) => {
   219              n0d.node.metadata.watched = c.watched
   220              return n0d
   221            })
   222            if (c.watched) {
   223              const newWatched = get().nodes.find(n => n.id === c.node_id)?.data.node
   224              if (newWatched === undefined) {
   225                return
   226              }
   227              set({
   228                watchedNodes: [...get().watchedNodes, newWatched],
   229              })
   230            } else {
   231              set({
   232                watchedNodes: [...get().watchedNodes.filter(n => n.id !== c.node_id)],
   233              })
   234            }
   235  
   236            return
   237          case 'update-watched-collapsed':
   238            await api.patchNode(graphId(), c.node_id, { watched_collapsed: c.watched_collapsed });
   239            applyNodeDataChangeToStore(c.node_id, (n0d) => {
   240              n0d.node.metadata.watched_collapsed = c.watched_collapsed
   241              return n0d
   242            })
   243            return
   244          case 'update-graph-viewport': {
   245            await graphApi.patchGraph(c.graph_id, {
   246              "viewport_x": c.viewport.x,
   247              "viewport_y": c.viewport.y,
   248              "viewport_zoom": c.viewport.zoom,
   249            });
   250            const { graph, graphs } = get();
   251            if (graph !== null && graph.id === c.graph_id) {
   252              set({
   253                graph: {
   254                  ...graph,
   255                  metadata: {
   256                    ...graph?.metadata,
   257                    viewport_x: c.viewport.x,
   258                    viewport_y: c.viewport.y,
   259                    viewport_zoom: c.viewport.zoom,
   260                  },
   261                },
   262                graphs: [...graphs.map(g => {
   263                  if (g.id === c.graph_id) {
   264                    const newObj = Object.assign({}, g)
   265                    newObj.metadata.viewport_x = c.viewport.x
   266                    newObj.metadata.viewport_y = c.viewport.y
   267                    newObj.metadata.viewport_zoom = c.viewport.zoom
   268                    return newObj
   269                  }
   270                  return g
   271                })],
   272              })
   273            }
   274            return
   275          }
   276          case 'update-graph-label': {
   277            await graphApi.patchGraph(c.graph_id, {
   278              "label": c.label,
   279            });
   280            const { graph, graphs } = get();
   281            if (graph !== null && graph.id === c.graph_id) {
   282              set({
   283                graph: {
   284                  ...graph,
   285                  label: c.label,
   286                },
   287                graphs: [...graphs.map(g => {
   288                  if (g.id === c.graph_id) {
   289                    const newObj = Object.assign({}, g)
   290                    newObj.label = c.label
   291                    return newObj
   292                  }
   293                  return g
   294                })],
   295              })
   296            }
   297            return
   298          }
   299          case 'delete-graph': {
   300            await graphApi.deleteGraph(c.graph_id);
   301            const [graph, graphs] = await Promise.all([
   302              graphApi.getGraphActive(),
   303              graphApi.getGraphs(),
   304            ]);
   305            set({
   306              graph: graph,
   307              graphs: graphs,
   308            })
   309          }
   310        }
   311      }))
   312    }
   313  
   314    const onNodesChange = async (changes: NodeChange[]) => {
   315      set({
   316        nodes: applyNodeChanges(changes, get().nodes),
   317        edges: changes.filter(c => c.type == 'remove').reduce((acc, change) => {
   318          const removedNode = change as NodeRemoveChange
   319          const affectedNode = get().nodes.find(fn => fn.id === removedNode.id)
   320          if (affectedNode) {
   321            const connectedEdges = getConnectedEdges([affectedNode], get().edges);
   322            return acc.filter((edge) => !connectedEdges.includes(edge));
   323          }
   324          return acc;
   325        }, get().edges)
   326      });
   327    }
   328  
   329    const onEdgesChange = (changes: EdgeChange[]) => {
   330      set({
   331        edges: applyEdgeChanges(changes, get().edges),
   332      });
   333    }
   334  
   335    const onConnect = async (connection: Connection) => {
   336      const targetNode = get().nodes.find(n => n.id === connection.target);
   337      if (!targetNode) {
   338        return
   339      }
   340      const sourceNode = get().nodes.find(n => n.id === connection.source);
   341      if (!sourceNode) {
   342        return
   343      }
   344  
   345      const targetNodeType = nodeTypes[targetNode.data.node.metadata.node_type];
   346      const targetNodeHandle = targetNodeType.inputs.find(i => i.id === connection.targetHandle)
   347      if (!targetNodeHandle) {
   348        return
   349      }
   350  
   351      if (targetNodeHandle.multiplicity === 'single') {
   352        // we have to _unlink_ any existing inputs here first!
   353        // find the existing edge.
   354        const existingEdge = get().edges.find(e => e.target === connection.target && e.targetHandle === connection.targetHandle)
   355        if (existingEdge !== undefined) {
   356          await onApiEffect([{
   357            type: 'unlink-nodes',
   358            parent_id: existingEdge.source,
   359            child_id: existingEdge.target,
   360            child_input_name: existingEdge.targetHandle === undefined ? null : existingEdge.targetHandle,
   361          }])
   362        }
   363      }
   364      await onApiEffect([{
   365        type: 'link-nodes',
   366        parent_id: connection.source || '',
   367        child_id: connection.target || '',
   368        child_input_name: connection.targetHandle === undefined ? null : connection.targetHandle,
   369      }])
   370    }
   371  
   372    const onDisconnect = async (edge: Edge) => {
   373      await api.deleteEdge(graphId(), { child_id: edge.target || '', parent_id: edge.source || '', child_input_name: edge.targetHandle === undefined ? null : edge.targetHandle })
   374      set({
   375        edges: get().edges.filter((e) => e.id !== edge.id)
   376      })
   377    }
   378  
   379    const onRefresh = async (graphId: string) => {
   380      set({ isLoading: true })
   381      const graph = await graphApi.getGraph(graphId);
   382      if (graph === null) {
   383        set({
   384          isLoading: false,
   385          graph: graph,
   386        })
   387        return
   388      }
   389      const [graphs, apiNodes, apiEdges] = await Promise.all([
   390        graphApi.getGraphs(),
   391        api.getNodes(graph.id),
   392        api.getEdges(graph.id),
   393      ])
   394      const flowNodes = apiNodes.map(apiNode => flowNodeFromApiNode(apiNode, { onApiEffect, graph }));
   395      const watchedNodes = filterWatchedNodes(apiNodes);
   396      set({
   397        isLoading: false,
   398        graph: graph,
   399        graphs: graphs,
   400        nodes: flowNodes,
   401        watchedNodes: watchedNodes,
   402        edges: flowEdgesFromApiEdges(apiEdges),
   403      })
   404    }
   405  
   406    const onStabilize = async () => {
   407      const workflowRun = await workflowApi.stabilize(graphId(), true);
   408      const workflowRunResult = await workflowApi.pollUntilFinished(workflowRun.workflow_id, workflowRun.run_id)
   409  
   410      // get the graph again to fetch the upated stabilization number
   411      const graph = await graphApi.getGraph(graphId());
   412      set({
   413        nodes: [...get().nodes.map(n => {
   414          const dataCopy = Object.assign({}, n.data)
   415          dataCopy.refreshedAt = new Date()
   416          dataCopy.stabilizationNum = graph.stabilization_num
   417          n.data = dataCopy
   418          return n
   419        })],
   420        graph,
   421      })
   422      if (workflowRunResult.failure !== null) {
   423        throw new Error(workflowRunResult.failure?.cause.message)
   424      }
   425      return workflowRunResult
   426    }
   427  
  // the initial contents of the store: `isLoading` starts as true, `graph` starts as
  // `graphApi.emptyGraph` and every collection starts empty. the remaining entries expose the
  // actions defined above under the names declared in `StoreState`.
  return {
    isLoading: true,
    graph: graphApi.emptyGraph,
    graphs: [],
    nodes: [],
    edges: [],
    watchedNodes: [],
    onApiEffect,
    onNodesChange,
    onEdgesChange,
    onConnect,
    onDisconnect,
    onRefresh,
    onStabilize,
    applyNodeDataChangeToStore,
  }
   444  }, Object.is);
   445  
// nodeDefaults holds the properties spread into every flow node built by `flowNodeFromApiNode`:
// the source position is `Position.Right` and the target position is `Position.Left`.
const nodeDefaults = {
  sourcePosition: Position.Right,
  targetPosition: Position.Left,
}
   450  
   451  export function apiNodeFromFlowNodeData(data: NodeData, pos: XYPosition): api.Node {
   452    return {
   453      ...data.node,
   454      metadata: {
   455        ...data.node?.metadata,
   456        position_x: pos.x,
   457        position_y: pos.y,
   458      }
   459    }
   460  }
   461  
   462  export const flowNodeFromApiNode = (node: api.Node, state: Pick<StoreState, 'onApiEffect' | 'graph'>): Node<NodeData> => {
   463    const resizable = node.metadata.node_type === "observer" && node.metadata.input_type === "svg"
   464    return {
   465      ...nodeDefaults,
   466      id: node.id,
   467      position: {
   468        x: node.metadata.position_x,
   469        y: node.metadata.position_y,
   470      },
   471      parentNode: node.metadata.parent_node_id,
   472      type: 'nodeCard',
   473      resizing: resizable,
   474      dragHandle: '.bp5-section-header',
   475      style: {
   476        ...(resizable ? { height: node.metadata.height || 0, width: node.metadata.width } : {})
   477      },
   478      data: {
   479        node: node,
   480        refreshedAt: new Date(),
   481        stabilizationNum: state.graph?.stabilization_num,
   482        getValue: async () => {
   483          return api.getNodeValue(state.graph?.id || '', node.id)
   484        },
   485        onCollapse: async (collapsed: boolean) => {
   486          await state.onApiEffect([{
   487            type: 'update-collapsed',
   488            node_id: node.id,
   489            collapsed: collapsed,
   490          }])
   491        },
   492      },
   493    }
   494  }
   495  
// defaultEdgeOptions are the presentation options spread into every flow edge built by
// `flowEdgeFromApiEdge`: an animated 'smoothstep' edge ending in a closed arrow marker.
export const defaultEdgeOptions = {
  type: 'smoothstep',
  animated: true,
  markerEnd: {
    type: MarkerType.ArrowClosed,
  },
}
   503  
   504  export const flowEdgesFromApiEdges = (edges: api.Edge[]): Edge[] => {
   505    const output: Edge[] = []
   506    for (const e of edges) {
   507      output.push(flowEdgeFromApiEdge(e))
   508    }
   509    return output
   510  }
   511  
   512  export const formatEdgeId = (edge: api.Edge): string => {
   513    return `${edge.parent_id}->${edge.child_id}.${edge.child_input_name}`
   514  }
   515  
   516  export const flowEdgeFromApiEdge = (edge: api.Edge): Edge => {
   517    return {
   518      id: formatEdgeId(edge),
   519      source: edge.parent_id,
   520      target: edge.child_id,
   521      targetHandle: edge.child_input_name,
   522      className: `edge`,
   523      ...defaultEdgeOptions,
   524    }
   525  }
   526  
   527  export default useStateStore;