github.com/vescale/zgraph@v0.0.0-20230410094002-959c02d50f95/compiler/macro_expansion.go (about) 1 // Copyright 2022 zGraph Authors. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package compiler 16 17 import ( 18 "github.com/vescale/zgraph/parser/ast" 19 "github.com/vescale/zgraph/parser/opcode" 20 ) 21 22 // MacroExpansion is used to expand the PathPatternMacros. 23 // 24 // PATH has_parent AS () -[:has_father|has_mother]-> (:Person) 25 // SELECT ancestor.name 26 // FROM MATCH (p1:Person) -/:has_parent+/-> (ancestor) 27 // , MATCH (p2:Person) -/:has_parent+/-> (ancestor) 28 // WHERE p1.name = 'Mario' 29 // AND p2.name = 'Luigi' 30 // 31 // The MacroExpansion will replace the `has_parent` macro. 32 type MacroExpansion struct { 33 macros []*ast.PathPatternMacro 34 mapping map[string]*ast.PathPatternMacro 35 wheres map[ast.ExprNode]struct{} 36 } 37 38 func NewMacroExpansion() *MacroExpansion { 39 return &MacroExpansion{ 40 mapping: map[string]*ast.PathPatternMacro{}, 41 wheres: map[ast.ExprNode]struct{}{}, 42 } 43 } 44 45 func (m *MacroExpansion) Enter(n ast.Node) (node ast.Node, skipChildren bool) { 46 switch stmt := n.(type) { 47 case *ast.InsertStmt: 48 m.macros = stmt.PathPatternMacros 49 case *ast.UpdateStmt: 50 m.macros = stmt.PathPatternMacros 51 case *ast.DeleteStmt: 52 m.macros = stmt.PathPatternMacros 53 case *ast.SelectStmt: 54 m.macros = stmt.PathPatternMacros 55 case *ast.MatchClauseList: 56 return m.macroExpansion(n.(*ast.MatchClauseList)) 57 default: 58 return n, true 59 } 60 61 // We skip its children if the statement doesn't have macro definitions. 62 return n, len(m.macros) == 0 63 } 64 65 func (m *MacroExpansion) macroExpansion(matchList *ast.MatchClauseList) (node ast.Node, skipChildren bool) { 66 if len(m.macros) != len(m.mapping) { 67 for _, macro := range m.macros { 68 m.mapping[macro.Name.L] = macro 69 } 70 } 71 72 detected := map[int] /*matchIndex*/ map[int] /*pathIndex*/ []int{} 73 for matchIndex, matchClause := range matchList.Matches { 74 for pathIndex, path := range matchClause.Paths { 75 for connIndex, conn := range path.Connections { 76 // Ref: https://pgql-lang.org/spec/1.5/#path-pattern-macros 77 // One or more “path pattern macros” may be declared at the beginning of the query. 78 // These macros allow for expressing complex regular expressions. PGQL 1.5 allows 79 // macros only for reachability, not for (top-k) shortest path. 80 reachabilityPathExpr, ok := conn.(*ast.ReachabilityPathExpr) 81 if !ok { 82 continue 83 } 84 85 var found bool 86 for _, label := range reachabilityPathExpr.Labels { 87 _, ok = m.mapping[label.L] 88 found = ok || found 89 } 90 if found { 91 pathGroup, ok := detected[matchIndex] 92 if !ok { 93 pathGroup = map[int][]int{} 94 detected[matchIndex] = pathGroup 95 } 96 pathGroup[pathIndex] = append(pathGroup[pathIndex], connIndex) 97 } 98 } 99 } 100 } 101 102 if len(detected) == 0 { 103 return matchList, true 104 } 105 106 // Shallow copy the match clause list. 107 newMatchList := &ast.MatchClauseList{} 108 *newMatchList = *matchList 109 newMatchList.Matches = make([]*ast.MatchClause, 0, len(matchList.Matches)) 110 newMatchList.Matches = append(newMatchList.Matches, matchList.Matches...) 111 112 for matchIndex, pathGroup := range detected { 113 oldMatch := matchList.Matches[matchIndex] 114 newMatch := &ast.MatchClause{} 115 *newMatch = *oldMatch 116 newMatch.Paths = make([]*ast.PathPattern, 0, len(oldMatch.Paths)) 117 newMatch.Paths = append(newMatch.Paths, oldMatch.Paths...) 118 newMatchList.Matches[matchIndex] = newMatch 119 for pathIndex, connGroup := range pathGroup { 120 oldPath := oldMatch.Paths[pathIndex] 121 newPath := &ast.PathPattern{} 122 *newPath = *oldPath 123 newPath.Connections = make([]ast.VertexPairConnection, 0, len(oldPath.Connections)) 124 newPath.Connections = append(newPath.Connections, oldPath.Connections...) 125 newMatch.Paths[pathIndex] = newPath 126 for _, connIndex := range connGroup { 127 oldConn := oldPath.Connections[connIndex].(*ast.ReachabilityPathExpr) 128 newConn := &ast.ReachabilityPathExpr{} 129 *newConn = *oldConn 130 newConn.Macros = map[string]*ast.PathPattern{} 131 for _, label := range newConn.Labels { 132 macro, found := m.mapping[label.L] 133 if !found { 134 continue 135 } 136 newConn.Macros[label.L] = macro.Path 137 if macro.Where != nil { 138 m.wheres[macro.Where] = struct{}{} 139 } 140 } 141 newPath.Connections[connIndex] = newConn 142 } 143 } 144 } 145 146 return newMatchList, true 147 } 148 149 func (m *MacroExpansion) Leave(n ast.Node) (node ast.Node, ok bool) { 150 if len(m.wheres) == 0 { 151 return n, true 152 } 153 154 var cnf ast.ExprNode 155 for expr := range m.wheres { 156 if cnf == nil { 157 cnf = expr 158 continue 159 } 160 cnf = &ast.BinaryExpr{ 161 Op: opcode.LogicAnd, 162 L: cnf, 163 R: expr, 164 } 165 } 166 167 // Attach where expressions. 168 switch stmt := n.(type) { 169 case *ast.InsertStmt: 170 newInsert := &ast.InsertStmt{} 171 *newInsert = *stmt 172 if newInsert.Where != nil { 173 newInsert.Where = &ast.BinaryExpr{ 174 Op: opcode.LogicAnd, 175 L: newInsert.Where, 176 R: cnf, 177 } 178 } else { 179 newInsert.Where = cnf 180 } 181 n = newInsert 182 case *ast.UpdateStmt: 183 newUpdate := &ast.UpdateStmt{} 184 *newUpdate = *stmt 185 if newUpdate.Where != nil { 186 newUpdate.Where = &ast.BinaryExpr{ 187 Op: opcode.LogicAnd, 188 L: newUpdate.Where, 189 R: cnf, 190 } 191 } else { 192 newUpdate.Where = cnf 193 } 194 n = newUpdate 195 case *ast.DeleteStmt: 196 newDelete := &ast.DeleteStmt{} 197 *newDelete = *stmt 198 if newDelete.Where != nil { 199 newDelete.Where = &ast.BinaryExpr{ 200 Op: opcode.LogicAnd, 201 L: newDelete.Where, 202 R: cnf, 203 } 204 } else { 205 newDelete.Where = cnf 206 } 207 n = newDelete 208 case *ast.SelectStmt: 209 newSelect := &ast.SelectStmt{} 210 *newSelect = *stmt 211 if newSelect.Where != nil { 212 newSelect.Where = &ast.BinaryExpr{ 213 Op: opcode.LogicAnd, 214 L: newSelect.Where, 215 R: cnf, 216 } 217 } else { 218 newSelect.Where = cnf 219 } 220 n = newSelect 221 default: 222 return n, true 223 } 224 225 return n, true 226 }