github.com/vescale/zgraph@v0.0.0-20230410094002-959c02d50f95/compiler/preprocess.go (about)

     1  // Copyright 2022 zGraph Authors. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package compiler
    16  
    17  import (
    18  	"github.com/pingcap/errors"
    19  	"github.com/vescale/zgraph/catalog"
    20  	"github.com/vescale/zgraph/meta"
    21  	"github.com/vescale/zgraph/parser/ast"
    22  	"github.com/vescale/zgraph/stmtctx"
    23  )
    24  
    25  // Preprocess is used to validate the AST to ensure the AST is valid.
    26  type Preprocess struct {
    27  	sc  *stmtctx.Context
    28  	err error
    29  }
    30  
    31  // NewPreprocess returns a preprocess visitor.
    32  func NewPreprocess(sc *stmtctx.Context) *Preprocess {
    33  	return &Preprocess{
    34  		sc: sc,
    35  	}
    36  }
    37  
    38  // Enter implements the ast.Visitor interface.
    39  func (p *Preprocess) Enter(n ast.Node) (node ast.Node, skipChildren bool) {
    40  	switch stmt := n.(type) {
    41  	case *ast.CreateGraphStmt:
    42  		p.checkCreateGraphStmt(stmt)
    43  	case *ast.CreateLabelStmt:
    44  		p.checkCreateLabelStmt(stmt)
    45  	case *ast.CreateIndexStmt:
    46  		p.checkCreateIndexStmt(stmt)
    47  	case *ast.DropGraphStmt:
    48  		p.checkDropGraphStmt(stmt)
    49  	case *ast.DropLabelStmt:
    50  		p.checkDropLabelStmt(stmt)
    51  	case *ast.DropIndexStmt:
    52  		p.checkDropIndexStmt(stmt)
    53  	case *ast.UseStmt:
    54  		p.checkUseStmt(stmt)
    55  	case *ast.InsertStmt:
    56  		p.checkInsertStmt(stmt)
    57  	case *ast.SelectStmt:
    58  		p.checkSelectStmt(stmt)
    59  	case *ast.ShowStmt:
    60  		p.checkShowStmt(stmt)
    61  	case *ast.GroupByClause:
    62  		// FIXME: reject all unsupported clauses/expressions
    63  		p.err = errors.Errorf("unsupported clause/expression: %T", n)
    64  	}
    65  	return n, p.err != nil
    66  }
    67  
    68  // Leave implements the ast.Visitor interface.
    69  func (p *Preprocess) Leave(n ast.Node) (node ast.Node, ok bool) {
    70  	return n, p.err != nil
    71  }
    72  
    73  // Error returns the internal error of preprocess.
    74  func (p *Preprocess) Error() error {
    75  	return p.err
    76  }
    77  
    78  func (p *Preprocess) checkCreateGraphStmt(stmt *ast.CreateGraphStmt) {
    79  	if isIncorrectName(stmt.Graph.L) {
    80  		p.err = ErrIncorrectGraphName
    81  		return
    82  	}
    83  
    84  	if !stmt.IfNotExists {
    85  		graph := p.sc.Catalog().Graph(stmt.Graph.L)
    86  		if graph != nil {
    87  			p.err = meta.ErrGraphExists
    88  			return
    89  		}
    90  	}
    91  }
    92  
    93  func (p *Preprocess) checkCreateLabelStmt(stmt *ast.CreateLabelStmt) {
    94  	if isIncorrectName(stmt.Label.L) {
    95  		p.err = ErrIncorrectLabelName
    96  		return
    97  	}
    98  
    99  	if !stmt.IfNotExists {
   100  		// An ErrGraphNotChosen are expected if users didn't choose graph via USE <graphName>
   101  		graphName := p.sc.CurrentGraphName()
   102  		if graphName == "" {
   103  			p.err = ErrGraphNotChosen
   104  			return
   105  		}
   106  		// change it from CurrentGraph to Catalog.Graph,  this will reduce the overhead of a lock,
   107  		// and since we got the name above, there is no need to use CurrentGraph
   108  		graph := p.sc.Catalog().Graph(graphName)
   109  		if graph == nil {
   110  			p.err = meta.ErrGraphNotExists
   111  			return
   112  		}
   113  		label := graph.Label(stmt.Label.L)
   114  		if label != nil {
   115  			p.err = meta.ErrLabelExists
   116  			return
   117  		}
   118  	}
   119  }
   120  
   121  func (p *Preprocess) checkCreateIndexStmt(stmt *ast.CreateIndexStmt) {
   122  	if isIncorrectName(stmt.IndexName.L) {
   123  		p.err = ErrIncorrectIndexName
   124  		return
   125  	}
   126  
   127  	graphName := p.sc.CurrentGraphName()
   128  	if graphName == "" {
   129  		p.err = ErrGraphNotChosen
   130  		return
   131  	}
   132  
   133  	graph := p.sc.Catalog().Graph(graphName)
   134  	if graph == nil {
   135  		p.err = meta.ErrGraphNotExists
   136  		return
   137  	}
   138  
   139  	index := graph.Index(stmt.IndexName.L)
   140  	if index != nil && !stmt.IfNotExists {
   141  		p.err = meta.ErrIndexExists
   142  		return
   143  	}
   144  
   145  	for _, prop := range stmt.Properties {
   146  		property := graph.Property(prop.L)
   147  		if property == nil {
   148  			p.err = errors.Annotatef(meta.ErrPropertyNotExists, "property %s", prop.L)
   149  			return
   150  		}
   151  	}
   152  }
   153  
   154  func (p *Preprocess) checkDropGraphStmt(stmt *ast.DropGraphStmt) {
   155  	graph := p.sc.Catalog().Graph(stmt.Graph.L)
   156  	if graph == nil && !stmt.IfExists {
   157  		p.err = meta.ErrGraphNotExists
   158  		return
   159  	}
   160  }
   161  
   162  func (p *Preprocess) checkDropLabelStmt(stmt *ast.DropLabelStmt) {
   163  	graph := p.sc.CurrentGraph()
   164  	if graph == nil && !stmt.IfExists {
   165  		p.err = meta.ErrGraphNotExists
   166  		return
   167  	}
   168  	label := graph.Label(stmt.Label.L)
   169  	if label == nil && !stmt.IfExists {
   170  		p.err = meta.ErrLabelNotExists
   171  		return
   172  	}
   173  }
   174  
   175  func (p *Preprocess) checkDropIndexStmt(stmt *ast.DropIndexStmt) {
   176  	graph := p.sc.CurrentGraph()
   177  	if graph == nil && !stmt.IfExists {
   178  		p.err = meta.ErrGraphNotExists
   179  		return
   180  	}
   181  	index := graph.Index(stmt.IndexName.L)
   182  	if index == nil && !stmt.IfExists {
   183  		p.err = meta.ErrIndexNotExists
   184  		return
   185  	}
   186  }
   187  
   188  func (p *Preprocess) checkUseStmt(stmt *ast.UseStmt) {
   189  	if isIncorrectName(stmt.GraphName.L) {
   190  		p.err = ErrIncorrectGraphName
   191  		return
   192  	}
   193  
   194  	graph := p.sc.Catalog().Graph(stmt.GraphName.L)
   195  	if graph == nil {
   196  		p.err = meta.ErrGraphNotExists
   197  		return
   198  	}
   199  }
   200  
   201  // isIncorrectName checks if the identifier is incorrect.
   202  // See https://dev.mysql.com/doc/refman/5.7/en/identifiers.html
   203  func isIncorrectName(name string) bool {
   204  	if len(name) == 0 {
   205  		return true
   206  	}
   207  	if name[len(name)-1] == ' ' {
   208  		return true
   209  	}
   210  	return false
   211  }
   212  
   213  func (p *Preprocess) checkInsertStmt(stmt *ast.InsertStmt) {
   214  	var intoGraph string
   215  	if !stmt.IntoGraphName.IsEmpty() {
   216  		if isIncorrectName(stmt.IntoGraphName.L) {
   217  			p.err = ErrIncorrectGraphName
   218  			return
   219  		}
   220  		intoGraph = stmt.IntoGraphName.L
   221  	}
   222  	var graph *catalog.Graph
   223  	if intoGraph == "" {
   224  		graph = p.sc.CurrentGraph()
   225  	} else {
   226  		graph = p.sc.Catalog().Graph(intoGraph)
   227  	}
   228  	if graph == nil {
   229  		p.err = errors.Annotatef(meta.ErrGraphNotExists, "graph: %s", intoGraph)
   230  		return
   231  	}
   232  
   233  	// Make sure all labels exists.
   234  	for _, insertion := range stmt.Insertions {
   235  		variables := map[string]struct{}{}
   236  		switch insertion.InsertionType {
   237  		case ast.InsertionTypeVertex:
   238  			if !insertion.VariableName.IsEmpty() {
   239  				variables[insertion.VariableName.L] = struct{}{}
   240  			}
   241  		case ast.InsertionTypeEdge:
   242  			if !insertion.VariableName.IsEmpty() {
   243  				variables[insertion.VariableName.L] = struct{}{}
   244  			}
   245  			variables[insertion.From.L] = struct{}{}
   246  			variables[insertion.From.L] = struct{}{}
   247  		}
   248  
   249  		// All labels reference variable names need exists.
   250  		lps := insertion.LabelsAndProperties
   251  		if len(lps.Labels) > 0 {
   252  			for _, lbl := range lps.Labels {
   253  				label := graph.Label(lbl.L)
   254  				if label == nil {
   255  					p.err = errors.Annotatef(meta.ErrLabelNotExists, "label: %s", lbl.L)
   256  					return
   257  				}
   258  			}
   259  		}
   260  		if len(lps.Assignments) > 0 {
   261  			for _, a := range lps.Assignments {
   262  				_, ok := variables[a.PropertyAccess.VariableName.L]
   263  				if !ok {
   264  					p.err = errors.Annotatef(ErrVariableReferenceNotExits, "variable: %s", a.PropertyAccess.VariableName.L)
   265  					return
   266  				}
   267  			}
   268  		}
   269  	}
   270  }
   271  
   272  func (p *Preprocess) checkSelectStmt(_ *ast.SelectStmt) {}
   273  
   274  func (p *Preprocess) checkShowStmt(stmt *ast.ShowStmt) {
   275  	if stmt.Tp == ast.ShowTargetLabels && stmt.GraphName.IsEmpty() {
   276  		graph := p.sc.CurrentGraph()
   277  		if graph == nil {
   278  			p.err = meta.ErrNoGraphSelected
   279  			return
   280  		}
   281  	}
   282  }