github.com/alkemics/goflow@v0.2.1/wrappers/ctx/wrapper.go (about)

     1  package ctx
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"github.com/alkemics/goflow"
     7  )
     8  
     9  const ctxNodeID = "__ctx"
    10  
    11  type graphRenderer struct {
    12  	goflow.GraphRenderer
    13  
    14  	imports []goflow.Import
    15  	inputs  []goflow.Field
    16  	nodes   []goflow.NodeRenderer
    17  }
    18  
    19  func (g graphRenderer) Imports() []goflow.Import     { return g.imports }
    20  func (g graphRenderer) Inputs() []goflow.Field       { return g.inputs }
    21  func (g graphRenderer) Nodes() []goflow.NodeRenderer { return g.nodes }
    22  
    23  type nodeRenderer struct {
    24  	goflow.NodeRenderer
    25  
    26  	inputs []goflow.Field
    27  }
    28  
    29  func (g nodeRenderer) Inputs() []goflow.Field { return g.inputs }
    30  
    31  type ctxNode struct{}
    32  
    33  func (n ctxNode) ID() string                   { return ctxNodeID }
    34  func (n ctxNode) Previous() []string           { return nil }
    35  func (n ctxNode) Imports() []goflow.Import     { return []goflow.Import{{Pkg: "context", Dir: "context"}} }
    36  func (n ctxNode) Doc() string                  { return "" }
    37  func (n ctxNode) Dependencies() []goflow.Field { return nil }
    38  func (n ctxNode) Inputs() []goflow.Field       { return nil }
    39  func (n ctxNode) Outputs() []goflow.Field {
    40  	return []goflow.Field{{Name: "ctx", Type: "context.Context"}}
    41  }
    42  
    43  func (n ctxNode) Run(_, outputs []goflow.Field) (string, error) {
    44  	// ctx must come from the inputs
    45  	return fmt.Sprintf("%s = ctx", outputs[0].Name), nil
    46  }
    47  
    48  func Wrapper(_ func(interface{}) error, graph goflow.GraphRenderer) (goflow.GraphRenderer, error) {
    49  	nodes := graph.Nodes()
    50  	wrappedNodes := make([]goflow.NodeRenderer, len(nodes))
    51  	addCtx := false
    52  	for i, node := range nodes {
    53  		inputs := node.Inputs()
    54  		if len(inputs) == 0 || inputs[0].Name != "ctx" || inputs[0].Type != "context.Context" {
    55  			wrappedNodes[i] = node
    56  			continue
    57  		}
    58  
    59  		// Here, we will need to add the context in there. We also need to remap the
    60  		// ctx input to inputs.ctx
    61  		addCtx = true
    62  		nodeInputs := make([]goflow.Field, len(inputs))
    63  		copy(nodeInputs, inputs)
    64  		nodeInputs[0].Name = fmt.Sprintf("%s.ctx", ctxNodeID)
    65  		wrappedNodes[i] = nodeRenderer{
    66  			NodeRenderer: node,
    67  			inputs:       nodeInputs,
    68  		}
    69  	}
    70  
    71  	inputs := graph.Inputs()
    72  	imports := graph.Imports()
    73  	if addCtx && (len(inputs) == 0 || inputs[0].Name != "ctx" || inputs[0].Type != "context.Context") {
    74  		inputs = append(
    75  			[]goflow.Field{{Name: "ctx", Type: "context.Context"}},
    76  			inputs...,
    77  		)
    78  
    79  		imports = append(imports, goflow.Import{Pkg: "context", Dir: "context"})
    80  		wrappedNodes = append(wrappedNodes, ctxNode{})
    81  	}
    82  
    83  	return graphRenderer{
    84  		GraphRenderer: graph,
    85  
    86  		imports: imports,
    87  		inputs:  inputs,
    88  		nodes:   wrappedNodes,
    89  	}, nil
    90  }