github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/spatial/st_distance.go (about) 1 // Copyright 2023 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package spatial 16 17 import ( 18 "fmt" 19 "math" 20 "strings" 21 22 "gopkg.in/src-d/go-errors.v1" 23 24 "github.com/dolthub/go-mysql-server/sql" 25 "github.com/dolthub/go-mysql-server/sql/expression" 26 "github.com/dolthub/go-mysql-server/sql/types" 27 ) 28 29 // Distance is a function that returns the shortest distance between two geometries 30 type Distance struct { 31 expression.NaryExpression 32 } 33 34 var _ sql.FunctionExpression = (*Distance)(nil) 35 var _ sql.CollationCoercible = (*Distance)(nil) 36 37 // ErrNoUnits is thrown when the specified SRID does not have units 38 var ErrNoUnits = errors.NewKind("the geometry passed to function st_distance is in SRID %v, which doesn't specify a length unit. Can't convert to '%v'.") 39 40 // NewDistance creates a new Distance expression. 41 func NewDistance(args ...sql.Expression) (sql.Expression, error) { 42 if len(args) != 2 && len(args) != 3 { 43 return nil, sql.ErrInvalidArgumentNumber.New("ST_DISTANCE", "2 or 3", len(args)) 44 } 45 return &Distance{expression.NaryExpression{ChildExpressions: args}}, nil 46 } 47 48 // FunctionName implements sql.FunctionExpression 49 func (d *Distance) FunctionName() string { 50 return "st_distance" 51 } 52 53 // Description implements sql.FunctionExpression 54 func (d *Distance) Description() string { 55 return "returns the distance between g1 and g2." 56 } 57 58 // Type implements the sql.Expression interface. 59 func (d *Distance) Type() sql.Type { 60 return types.Float64 61 } 62 63 // CollationCoercibility implements the interface sql.CollationCoercible. 64 func (*Distance) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 65 return sql.Collation_binary, 5 66 } 67 68 func (d *Distance) String() string { 69 var args = make([]string, len(d.ChildExpressions)) 70 for i, arg := range d.ChildExpressions { 71 args[i] = arg.String() 72 } 73 return fmt.Sprintf("%s(%s)", d.FunctionName(), strings.Join(args, ",")) 74 } 75 76 // WithChildren implements the Expression interface. 77 func (d *Distance) WithChildren(children ...sql.Expression) (sql.Expression, error) { 78 return NewDistance(children...) 79 } 80 81 // flattenGeometry recursively "flattens" the geometry value into a map of its points 82 func flattenGeometry(g types.GeometryValue, points map[types.Point]bool) { 83 switch g := g.(type) { 84 case types.Point: 85 points[g] = true 86 case types.LineString: 87 for _, p := range g.Points { 88 flattenGeometry(p, points) 89 } 90 case types.Polygon: 91 for _, l := range g.Lines { 92 flattenGeometry(l, points) 93 } 94 case types.MultiPoint: 95 for _, p := range g.Points { 96 flattenGeometry(p, points) 97 } 98 case types.MultiLineString: 99 for _, l := range g.Lines { 100 flattenGeometry(l, points) 101 } 102 case types.MultiPolygon: 103 for _, p := range g.Polygons { 104 flattenGeometry(p, points) 105 } 106 case types.GeomColl: 107 for _, gg := range g.Geoms { 108 flattenGeometry(gg, points) 109 } 110 } 111 } 112 113 // calcPointDist calculates the distance between two points 114 // Small Optimization: don't do square root 115 func calcPointDist(a, b types.Point) float64 { 116 dx := b.X - a.X 117 dy := b.Y - a.Y 118 return math.Sqrt(dx*dx + dy*dy) 119 } 120 121 // calcDist finds the minimum distance from a Point in g1 to a Point g2 122 func calcDist(g1, g2 types.GeometryValue) interface{} { 123 points1, points2 := map[types.Point]bool{}, map[types.Point]bool{} 124 flattenGeometry(g1, points1) 125 flattenGeometry(g2, points2) 126 127 if len(points1) == 0 || len(points2) == 0 { 128 return nil 129 } 130 131 minDist := math.MaxFloat64 132 for a := range points1 { 133 for b := range points2 { 134 minDist = math.Min(minDist, calcPointDist(a, b)) 135 } 136 } 137 138 return minDist 139 } 140 141 // Eval implements the sql.Expression interface. 142 func (d *Distance) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 143 g1, err := d.ChildExpressions[0].Eval(ctx, row) 144 if err != nil { 145 return nil, err 146 } 147 148 g2, err := d.ChildExpressions[1].Eval(ctx, row) 149 if err != nil { 150 return nil, err 151 } 152 153 if g1 == nil || g2 == nil { 154 return nil, nil 155 } 156 157 geom1, ok := g1.(types.GeometryValue) 158 if !ok { 159 return nil, sql.ErrInvalidGISData.New(d.FunctionName()) 160 } 161 162 geom2, ok := g2.(types.GeometryValue) 163 if !ok { 164 return nil, sql.ErrInvalidGISData.New(d.FunctionName()) 165 } 166 167 srid1 := geom1.GetSRID() 168 srid2 := geom2.GetSRID() 169 if srid1 != srid2 { 170 return nil, sql.ErrDiffSRIDs.New(d.FunctionName(), srid1, srid2) 171 } 172 173 if srid1 != types.CartesianSRID { 174 return nil, sql.ErrUnsupportedSRID.New(srid1) 175 } 176 177 dist := calcDist(geom1, geom2) 178 179 if len(d.ChildExpressions) == 3 { 180 if srid1 == types.CartesianSRID { 181 return nil, ErrNoUnits.New(srid1) 182 } 183 // TODO: check valid unit arguments 184 } 185 186 return dist, nil 187 }