github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/spatial/st_distance.go (about)

     1  // Copyright 2023 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package spatial
    16  
    17  import (
    18  	"fmt"
    19  	"math"
    20  	"strings"
    21  
    22  	"gopkg.in/src-d/go-errors.v1"
    23  
    24  	"github.com/dolthub/go-mysql-server/sql"
    25  	"github.com/dolthub/go-mysql-server/sql/expression"
    26  	"github.com/dolthub/go-mysql-server/sql/types"
    27  )
    28  
    29  // Distance is a function that returns the shortest distance between two geometries
    30  type Distance struct {
    31  	expression.NaryExpression
    32  }
    33  
    34  var _ sql.FunctionExpression = (*Distance)(nil)
    35  var _ sql.CollationCoercible = (*Distance)(nil)
    36  
    37  // ErrNoUnits is thrown when the specified SRID does not have units
    38  var ErrNoUnits = errors.NewKind("the geometry passed to function st_distance is in SRID %v, which doesn't specify a length unit. Can't convert to '%v'.")
    39  
    40  // NewDistance creates a new Distance expression.
    41  func NewDistance(args ...sql.Expression) (sql.Expression, error) {
    42  	if len(args) != 2 && len(args) != 3 {
    43  		return nil, sql.ErrInvalidArgumentNumber.New("ST_DISTANCE", "2 or 3", len(args))
    44  	}
    45  	return &Distance{expression.NaryExpression{ChildExpressions: args}}, nil
    46  }
    47  
    48  // FunctionName implements sql.FunctionExpression
    49  func (d *Distance) FunctionName() string {
    50  	return "st_distance"
    51  }
    52  
    53  // Description implements sql.FunctionExpression
    54  func (d *Distance) Description() string {
    55  	return "returns the distance between g1 and g2."
    56  }
    57  
    58  // Type implements the sql.Expression interface.
    59  func (d *Distance) Type() sql.Type {
    60  	return types.Float64
    61  }
    62  
    63  // CollationCoercibility implements the interface sql.CollationCoercible.
    64  func (*Distance) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
    65  	return sql.Collation_binary, 5
    66  }
    67  
    68  func (d *Distance) String() string {
    69  	var args = make([]string, len(d.ChildExpressions))
    70  	for i, arg := range d.ChildExpressions {
    71  		args[i] = arg.String()
    72  	}
    73  	return fmt.Sprintf("%s(%s)", d.FunctionName(), strings.Join(args, ","))
    74  }
    75  
    76  // WithChildren implements the Expression interface.
    77  func (d *Distance) WithChildren(children ...sql.Expression) (sql.Expression, error) {
    78  	return NewDistance(children...)
    79  }
    80  
    81  // flattenGeometry recursively "flattens" the geometry value into a map of its points
    82  func flattenGeometry(g types.GeometryValue, points map[types.Point]bool) {
    83  	switch g := g.(type) {
    84  	case types.Point:
    85  		points[g] = true
    86  	case types.LineString:
    87  		for _, p := range g.Points {
    88  			flattenGeometry(p, points)
    89  		}
    90  	case types.Polygon:
    91  		for _, l := range g.Lines {
    92  			flattenGeometry(l, points)
    93  		}
    94  	case types.MultiPoint:
    95  		for _, p := range g.Points {
    96  			flattenGeometry(p, points)
    97  		}
    98  	case types.MultiLineString:
    99  		for _, l := range g.Lines {
   100  			flattenGeometry(l, points)
   101  		}
   102  	case types.MultiPolygon:
   103  		for _, p := range g.Polygons {
   104  			flattenGeometry(p, points)
   105  		}
   106  	case types.GeomColl:
   107  		for _, gg := range g.Geoms {
   108  			flattenGeometry(gg, points)
   109  		}
   110  	}
   111  }
   112  
   113  // calcPointDist calculates the distance between two points
   114  // Small Optimization: don't do square root
   115  func calcPointDist(a, b types.Point) float64 {
   116  	dx := b.X - a.X
   117  	dy := b.Y - a.Y
   118  	return math.Sqrt(dx*dx + dy*dy)
   119  }
   120  
   121  // calcDist finds the minimum distance from a Point in g1 to a Point g2
   122  func calcDist(g1, g2 types.GeometryValue) interface{} {
   123  	points1, points2 := map[types.Point]bool{}, map[types.Point]bool{}
   124  	flattenGeometry(g1, points1)
   125  	flattenGeometry(g2, points2)
   126  
   127  	if len(points1) == 0 || len(points2) == 0 {
   128  		return nil
   129  	}
   130  
   131  	minDist := math.MaxFloat64
   132  	for a := range points1 {
   133  		for b := range points2 {
   134  			minDist = math.Min(minDist, calcPointDist(a, b))
   135  		}
   136  	}
   137  
   138  	return minDist
   139  }
   140  
   141  // Eval implements the sql.Expression interface.
   142  func (d *Distance) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   143  	g1, err := d.ChildExpressions[0].Eval(ctx, row)
   144  	if err != nil {
   145  		return nil, err
   146  	}
   147  
   148  	g2, err := d.ChildExpressions[1].Eval(ctx, row)
   149  	if err != nil {
   150  		return nil, err
   151  	}
   152  
   153  	if g1 == nil || g2 == nil {
   154  		return nil, nil
   155  	}
   156  
   157  	geom1, ok := g1.(types.GeometryValue)
   158  	if !ok {
   159  		return nil, sql.ErrInvalidGISData.New(d.FunctionName())
   160  	}
   161  
   162  	geom2, ok := g2.(types.GeometryValue)
   163  	if !ok {
   164  		return nil, sql.ErrInvalidGISData.New(d.FunctionName())
   165  	}
   166  
   167  	srid1 := geom1.GetSRID()
   168  	srid2 := geom2.GetSRID()
   169  	if srid1 != srid2 {
   170  		return nil, sql.ErrDiffSRIDs.New(d.FunctionName(), srid1, srid2)
   171  	}
   172  
   173  	if srid1 != types.CartesianSRID {
   174  		return nil, sql.ErrUnsupportedSRID.New(srid1)
   175  	}
   176  
   177  	dist := calcDist(geom1, geom2)
   178  
   179  	if len(d.ChildExpressions) == 3 {
   180  		if srid1 == types.CartesianSRID {
   181  			return nil, ErrNoUnits.New(srid1)
   182  		}
   183  		// TODO: check valid unit arguments
   184  	}
   185  
   186  	return dist, nil
   187  }