github.com/vescale/zgraph@v0.0.0-20230410094002-959c02d50f95/compiler/macro_expansion.go (about)

     1  // Copyright 2022 zGraph Authors. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package compiler
    16  
    17  import (
    18  	"github.com/vescale/zgraph/parser/ast"
    19  	"github.com/vescale/zgraph/parser/opcode"
    20  )
    21  
// MacroExpansion expands references to path pattern macros. It exposes Enter
// and Leave methods and is meant to be driven over a statement's AST.
//
// Given a query such as
//
//	PATH has_parent AS () -[:has_father|has_mother]-> (:Person)
//	SELECT ancestor.name
//	  FROM MATCH (p1:Person) -/:has_parent+/-> (ancestor)
//	     , MATCH (p2:Person) -/:has_parent+/-> (ancestor)
//	 WHERE p1.name = 'Mario'
//	   AND p2.name = 'Luigi'
//
// the reachability expressions labelled `has_parent` are rewritten so that they
// carry the path pattern declared by the `has_parent` macro, and any WHERE
// expression declared on a referenced macro is AND-ed into the WHERE clause of
// the statement.
type MacroExpansion struct {
	// macros is the PathPatternMacros list of the statement entered most recently.
	macros  []*ast.PathPatternMacro
	// mapping indexes macro declarations by their Name.L for label lookups.
	mapping map[string]*ast.PathPatternMacro
	// wheres is the set of WHERE expressions of the macros that were actually
	// referenced; Leave attaches them to the enclosing statement.
	wheres  map[ast.ExprNode]struct{}
}
    37  
    38  func NewMacroExpansion() *MacroExpansion {
    39  	return &MacroExpansion{
    40  		mapping: map[string]*ast.PathPatternMacro{},
    41  		wheres:  map[ast.ExprNode]struct{}{},
    42  	}
    43  }
    44  
    45  func (m *MacroExpansion) Enter(n ast.Node) (node ast.Node, skipChildren bool) {
    46  	switch stmt := n.(type) {
    47  	case *ast.InsertStmt:
    48  		m.macros = stmt.PathPatternMacros
    49  	case *ast.UpdateStmt:
    50  		m.macros = stmt.PathPatternMacros
    51  	case *ast.DeleteStmt:
    52  		m.macros = stmt.PathPatternMacros
    53  	case *ast.SelectStmt:
    54  		m.macros = stmt.PathPatternMacros
    55  	case *ast.MatchClauseList:
    56  		return m.macroExpansion(n.(*ast.MatchClauseList))
    57  	default:
    58  		return n, true
    59  	}
    60  
    61  	// We skip its children if the statement doesn't have macro definitions.
    62  	return n, len(m.macros) == 0
    63  }
    64  
    65  func (m *MacroExpansion) macroExpansion(matchList *ast.MatchClauseList) (node ast.Node, skipChildren bool) {
    66  	if len(m.macros) != len(m.mapping) {
    67  		for _, macro := range m.macros {
    68  			m.mapping[macro.Name.L] = macro
    69  		}
    70  	}
    71  
    72  	detected := map[int] /*matchIndex*/ map[int] /*pathIndex*/ []int{}
    73  	for matchIndex, matchClause := range matchList.Matches {
    74  		for pathIndex, path := range matchClause.Paths {
    75  			for connIndex, conn := range path.Connections {
    76  				// Ref: https://pgql-lang.org/spec/1.5/#path-pattern-macros
    77  				// One or more “path pattern macros” may be declared at the beginning of the query.
    78  				// These macros allow for expressing complex regular expressions. PGQL 1.5 allows
    79  				// macros only for reachability, not for (top-k) shortest path.
    80  				reachabilityPathExpr, ok := conn.(*ast.ReachabilityPathExpr)
    81  				if !ok {
    82  					continue
    83  				}
    84  
    85  				var found bool
    86  				for _, label := range reachabilityPathExpr.Labels {
    87  					_, ok = m.mapping[label.L]
    88  					found = ok || found
    89  				}
    90  				if found {
    91  					pathGroup, ok := detected[matchIndex]
    92  					if !ok {
    93  						pathGroup = map[int][]int{}
    94  						detected[matchIndex] = pathGroup
    95  					}
    96  					pathGroup[pathIndex] = append(pathGroup[pathIndex], connIndex)
    97  				}
    98  			}
    99  		}
   100  	}
   101  
   102  	if len(detected) == 0 {
   103  		return matchList, true
   104  	}
   105  
   106  	// Shallow copy the match clause list.
   107  	newMatchList := &ast.MatchClauseList{}
   108  	*newMatchList = *matchList
   109  	newMatchList.Matches = make([]*ast.MatchClause, 0, len(matchList.Matches))
   110  	newMatchList.Matches = append(newMatchList.Matches, matchList.Matches...)
   111  
   112  	for matchIndex, pathGroup := range detected {
   113  		oldMatch := matchList.Matches[matchIndex]
   114  		newMatch := &ast.MatchClause{}
   115  		*newMatch = *oldMatch
   116  		newMatch.Paths = make([]*ast.PathPattern, 0, len(oldMatch.Paths))
   117  		newMatch.Paths = append(newMatch.Paths, oldMatch.Paths...)
   118  		newMatchList.Matches[matchIndex] = newMatch
   119  		for pathIndex, connGroup := range pathGroup {
   120  			oldPath := oldMatch.Paths[pathIndex]
   121  			newPath := &ast.PathPattern{}
   122  			*newPath = *oldPath
   123  			newPath.Connections = make([]ast.VertexPairConnection, 0, len(oldPath.Connections))
   124  			newPath.Connections = append(newPath.Connections, oldPath.Connections...)
   125  			newMatch.Paths[pathIndex] = newPath
   126  			for _, connIndex := range connGroup {
   127  				oldConn := oldPath.Connections[connIndex].(*ast.ReachabilityPathExpr)
   128  				newConn := &ast.ReachabilityPathExpr{}
   129  				*newConn = *oldConn
   130  				newConn.Macros = map[string]*ast.PathPattern{}
   131  				for _, label := range newConn.Labels {
   132  					macro, found := m.mapping[label.L]
   133  					if !found {
   134  						continue
   135  					}
   136  					newConn.Macros[label.L] = macro.Path
   137  					if macro.Where != nil {
   138  						m.wheres[macro.Where] = struct{}{}
   139  					}
   140  				}
   141  				newPath.Connections[connIndex] = newConn
   142  			}
   143  		}
   144  	}
   145  
   146  	return newMatchList, true
   147  }
   148  
   149  func (m *MacroExpansion) Leave(n ast.Node) (node ast.Node, ok bool) {
   150  	if len(m.wheres) == 0 {
   151  		return n, true
   152  	}
   153  
   154  	var cnf ast.ExprNode
   155  	for expr := range m.wheres {
   156  		if cnf == nil {
   157  			cnf = expr
   158  			continue
   159  		}
   160  		cnf = &ast.BinaryExpr{
   161  			Op: opcode.LogicAnd,
   162  			L:  cnf,
   163  			R:  expr,
   164  		}
   165  	}
   166  
   167  	// Attach where expressions.
   168  	switch stmt := n.(type) {
   169  	case *ast.InsertStmt:
   170  		newInsert := &ast.InsertStmt{}
   171  		*newInsert = *stmt
   172  		if newInsert.Where != nil {
   173  			newInsert.Where = &ast.BinaryExpr{
   174  				Op: opcode.LogicAnd,
   175  				L:  newInsert.Where,
   176  				R:  cnf,
   177  			}
   178  		} else {
   179  			newInsert.Where = cnf
   180  		}
   181  		n = newInsert
   182  	case *ast.UpdateStmt:
   183  		newUpdate := &ast.UpdateStmt{}
   184  		*newUpdate = *stmt
   185  		if newUpdate.Where != nil {
   186  			newUpdate.Where = &ast.BinaryExpr{
   187  				Op: opcode.LogicAnd,
   188  				L:  newUpdate.Where,
   189  				R:  cnf,
   190  			}
   191  		} else {
   192  			newUpdate.Where = cnf
   193  		}
   194  		n = newUpdate
   195  	case *ast.DeleteStmt:
   196  		newDelete := &ast.DeleteStmt{}
   197  		*newDelete = *stmt
   198  		if newDelete.Where != nil {
   199  			newDelete.Where = &ast.BinaryExpr{
   200  				Op: opcode.LogicAnd,
   201  				L:  newDelete.Where,
   202  				R:  cnf,
   203  			}
   204  		} else {
   205  			newDelete.Where = cnf
   206  		}
   207  		n = newDelete
   208  	case *ast.SelectStmt:
   209  		newSelect := &ast.SelectStmt{}
   210  		*newSelect = *stmt
   211  		if newSelect.Where != nil {
   212  			newSelect.Where = &ast.BinaryExpr{
   213  				Op: opcode.LogicAnd,
   214  				L:  newSelect.Where,
   215  				R:  cnf,
   216  			}
   217  		} else {
   218  			newSelect.Where = cnf
   219  		}
   220  		n = newSelect
   221  	default:
   222  		return n, true
   223  	}
   224  
   225  	return n, true
   226  }