github.com/dolthub/go-mysql-server@v0.18.0/sql/analyzer/replace_window_names.go (about)

     1  // Copyright 2022 DoltHub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package analyzer
    16  
    17  import (
    18  	"github.com/dolthub/go-mysql-server/sql"
    19  	"github.com/dolthub/go-mysql-server/sql/expression"
    20  	"github.com/dolthub/go-mysql-server/sql/plan"
    21  	"github.com/dolthub/go-mysql-server/sql/transform"
    22  )
    23  
    24  // replaceNamedWindows will 1) extract window definitions from a *plan.NamedWindows node,
    25  // 2) resolve window name references, 3) embed resolved window definitions in sql.Window clauses
    26  // (currently in expression.UnresolvedFunction instances), and 4) replace the plan.NamedWindows
    27  // node with its child *plan.Window.
    28  func replaceNamedWindows(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector) (sql.Node, transform.TreeIdentity, error) {
    29  	return transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
    30  		switch n.(type) {
    31  		case *plan.NamedWindows:
    32  			wn, ok := n.(*plan.NamedWindows)
    33  			if !ok {
    34  				return n, transform.SameTree, nil
    35  			}
    36  
    37  			window, ok := wn.Child.(*plan.Window)
    38  			if !ok {
    39  				return n, transform.SameTree, nil
    40  			}
    41  
    42  			err := checkCircularWindowDef(wn.WindowDefs)
    43  			if err != nil {
    44  				return nil, transform.SameTree, err
    45  			}
    46  
    47  			// find and replace over expressions with new window definitions
    48  			// over sql.Windows are in unresolved aggregation functions
    49  			newExprs := make([]sql.Expression, len(window.SelectExprs))
    50  			same := transform.SameTree
    51  			for i, expr := range window.SelectExprs {
    52  				newExprs[i], _, err = transform.Expr(expr, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
    53  					uf, ok := e.(*expression.UnresolvedFunction)
    54  					if !ok {
    55  						return e, transform.SameTree, nil
    56  					}
    57  					if uf.Window == nil {
    58  						return e, transform.SameTree, nil
    59  					}
    60  					newWindow, sameDef, err := resolveWindowDef(uf.Window, wn.WindowDefs)
    61  					if err != nil {
    62  						return nil, transform.SameTree, err
    63  					}
    64  					same = same && sameDef
    65  					if sameDef {
    66  						return expr, transform.SameTree, nil
    67  					}
    68  					return uf.WithWindow(newWindow), transform.NewTree, nil
    69  				})
    70  				if err != nil {
    71  					return nil, transform.SameTree, err
    72  				}
    73  			}
    74  			if same {
    75  				return window, transform.SameTree, nil
    76  			}
    77  			return plan.NewWindow(newExprs, window.Child), transform.NewTree, nil
    78  		}
    79  		return n, transform.SameTree, nil
    80  	})
    81  }
    82  
    83  // checkCircularWindowDef verifies that window references terminate
    84  // with concrete definitions. We use a linked-list algorithm
    85  // because a sql.WindowDefinition can have at most one [Ref].
    86  func checkCircularWindowDef(windowDefs map[string]*sql.WindowDefinition) error {
    87  	var head, tail *sql.WindowDefinition
    88  	for _, def := range windowDefs {
    89  		if def.Ref == "" {
    90  			continue
    91  		}
    92  		head = def
    93  		head = windowDefs[head.Ref]
    94  		tail = def
    95  		for head != nil && tail != nil && head != tail {
    96  			tail = windowDefs[tail.Ref]
    97  			head = windowDefs[head.Ref]
    98  			if head != nil {
    99  				head = windowDefs[head.Ref]
   100  			}
   101  		}
   102  		if head != nil && head == tail {
   103  			return sql.ErrCircularWindowInheritance.New()
   104  		}
   105  	}
   106  	return nil
   107  }
   108  
   109  // resolveWindowDef uses DFS to walk the [windowDefs] adjacency list, resolving and merging
   110  // all named windows required to define the topmost window of concern.
   111  // A WindowDefinition is considered resolved when its [Ref] is empty. Otherwise, we recurse
   112  // to define that Ref'd window, before finally merging the resolved ref with the original window
   113  // definition.
   114  // A sql.WindowDef can have at most one named reference.
   115  // We cache merged definitions in [windowDefs] to aid subsequent lookups.
   116  func resolveWindowDef(n *sql.WindowDefinition, windowDefs map[string]*sql.WindowDefinition) (*sql.WindowDefinition, transform.TreeIdentity, error) {
   117  	// base case
   118  	if n.Ref == "" {
   119  		return n, transform.SameTree, nil
   120  	}
   121  
   122  	var err error
   123  	ref, ok := windowDefs[n.Ref]
   124  	if !ok {
   125  		return nil, transform.SameTree, sql.ErrUnknownWindowName.New(n.Ref)
   126  	}
   127  
   128  	// recursively resolve [n.Ref]
   129  	ref, _, err = resolveWindowDef(ref, windowDefs)
   130  	if err != nil {
   131  		return nil, transform.SameTree, err
   132  	}
   133  
   134  	// [n] is fully defined by its attributes merging with the named reference
   135  	n, err = mergeWindowDefs(n, ref)
   136  	if err != nil {
   137  		return nil, transform.SameTree, err
   138  	}
   139  
   140  	if n.Name != "" {
   141  		// cache lookup
   142  		windowDefs[n.Name] = n
   143  	}
   144  	return n, transform.NewTree, nil
   145  }
   146  
   147  // mergeWindowDefs combines the attributes of two window definitions or returns
   148  // an error if the two are incompatible. [def] should have a reference to
   149  // [ref] through [def.Ref], and the return value drops the reference to indicate
   150  // the two were properly combined.
   151  func mergeWindowDefs(def, ref *sql.WindowDefinition) (*sql.WindowDefinition, error) {
   152  	if ref.Ref != "" {
   153  		panic("unreachable; cannot merge unresolved window definition")
   154  	}
   155  
   156  	var orderBy sql.SortFields
   157  	switch {
   158  	case len(def.OrderBy) > 0 && len(ref.OrderBy) > 0:
   159  		return nil, sql.ErrInvalidWindowInheritance.New("", "", "both contain order by clause")
   160  	case len(def.OrderBy) > 0:
   161  		orderBy = def.OrderBy
   162  	case len(ref.OrderBy) > 0:
   163  		orderBy = ref.OrderBy
   164  	default:
   165  	}
   166  
   167  	var partitionBy []sql.Expression
   168  	switch {
   169  	case len(def.PartitionBy) > 0 && len(ref.PartitionBy) > 0:
   170  		return nil, sql.ErrInvalidWindowInheritance.New("", "", "both contain partition by clause")
   171  	case len(def.PartitionBy) > 0:
   172  		partitionBy = def.PartitionBy
   173  	case len(ref.PartitionBy) > 0:
   174  		partitionBy = ref.PartitionBy
   175  	default:
   176  		partitionBy = []sql.Expression{}
   177  	}
   178  
   179  	var frame sql.WindowFrame
   180  	switch {
   181  	case def.Frame != nil && ref.Frame != nil:
   182  		_, isDefDefaultFrame := def.Frame.(*plan.RowsUnboundedPrecedingToUnboundedFollowingFrame)
   183  		_, isRefDefaultFrame := ref.Frame.(*plan.RowsUnboundedPrecedingToUnboundedFollowingFrame)
   184  
   185  		// if both frames are set and one is RowsUnboundedPrecedingToUnboundedFollowingFrame (default),
   186  		// we should use the other frame
   187  		if isDefDefaultFrame {
   188  			frame = ref.Frame
   189  		} else if isRefDefaultFrame {
   190  			frame = def.Frame
   191  		} else {
   192  			// if both frames have identical string representations, use either one
   193  			df := def.Frame.String()
   194  			rf := ref.Frame.String()
   195  			if df != rf {
   196  				return nil, sql.ErrInvalidWindowInheritance.New("", "", "both contain different frame clauses")
   197  			}
   198  			frame = def.Frame
   199  		}
   200  	case def.Frame != nil:
   201  		frame = def.Frame
   202  	case ref.Frame != nil:
   203  		frame = ref.Frame
   204  	default:
   205  	}
   206  
   207  	return sql.NewWindowDefinition(partitionBy, orderBy, frame, "", def.Name), nil
   208  }