github.com/alkemics/goflow@v0.2.1/wrappers/ctx/wrapper.go (about) 1 package ctx 2 3 import ( 4 "fmt" 5 6 "github.com/alkemics/goflow" 7 ) 8 9 const ctxNodeID = "__ctx" 10 11 type graphRenderer struct { 12 goflow.GraphRenderer 13 14 imports []goflow.Import 15 inputs []goflow.Field 16 nodes []goflow.NodeRenderer 17 } 18 19 func (g graphRenderer) Imports() []goflow.Import { return g.imports } 20 func (g graphRenderer) Inputs() []goflow.Field { return g.inputs } 21 func (g graphRenderer) Nodes() []goflow.NodeRenderer { return g.nodes } 22 23 type nodeRenderer struct { 24 goflow.NodeRenderer 25 26 inputs []goflow.Field 27 } 28 29 func (g nodeRenderer) Inputs() []goflow.Field { return g.inputs } 30 31 type ctxNode struct{} 32 33 func (n ctxNode) ID() string { return ctxNodeID } 34 func (n ctxNode) Previous() []string { return nil } 35 func (n ctxNode) Imports() []goflow.Import { return []goflow.Import{{Pkg: "context", Dir: "context"}} } 36 func (n ctxNode) Doc() string { return "" } 37 func (n ctxNode) Dependencies() []goflow.Field { return nil } 38 func (n ctxNode) Inputs() []goflow.Field { return nil } 39 func (n ctxNode) Outputs() []goflow.Field { 40 return []goflow.Field{{Name: "ctx", Type: "context.Context"}} 41 } 42 43 func (n ctxNode) Run(_, outputs []goflow.Field) (string, error) { 44 // ctx must come from the inputs 45 return fmt.Sprintf("%s = ctx", outputs[0].Name), nil 46 } 47 48 func Wrapper(_ func(interface{}) error, graph goflow.GraphRenderer) (goflow.GraphRenderer, error) { 49 nodes := graph.Nodes() 50 wrappedNodes := make([]goflow.NodeRenderer, len(nodes)) 51 addCtx := false 52 for i, node := range nodes { 53 inputs := node.Inputs() 54 if len(inputs) == 0 || inputs[0].Name != "ctx" || inputs[0].Type != "context.Context" { 55 wrappedNodes[i] = node 56 continue 57 } 58 59 // Here, we will need to add the context in there. We also need to remap the 60 // ctx input to inputs.ctx 61 addCtx = true 62 nodeInputs := make([]goflow.Field, len(inputs)) 63 copy(nodeInputs, inputs) 64 nodeInputs[0].Name = fmt.Sprintf("%s.ctx", ctxNodeID) 65 wrappedNodes[i] = nodeRenderer{ 66 NodeRenderer: node, 67 inputs: nodeInputs, 68 } 69 } 70 71 inputs := graph.Inputs() 72 imports := graph.Imports() 73 if addCtx && (len(inputs) == 0 || inputs[0].Name != "ctx" || inputs[0].Type != "context.Context") { 74 inputs = append( 75 []goflow.Field{{Name: "ctx", Type: "context.Context"}}, 76 inputs..., 77 ) 78 79 imports = append(imports, goflow.Import{Pkg: "context", Dir: "context"}) 80 wrappedNodes = append(wrappedNodes, ctxNode{}) 81 } 82 83 return graphRenderer{ 84 GraphRenderer: graph, 85 86 imports: imports, 87 inputs: inputs, 88 nodes: wrappedNodes, 89 }, nil 90 }