github.com/gopherd/gonum@v0.0.4/spatial/barneshut/barneshut3.go (about)

     1  // Copyright ©2019 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package barneshut
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"math"
    11  
    12  	"github.com/gopherd/gonum/spatial/r3"
    13  )
    14  
    15  // Particle3 is a particle in a volume.
    16  type Particle3 interface {
    17  	Coord3() r3.Vec
    18  	Mass() float64
    19  }
    20  
    21  // Force3 is a force modeling function for interactions between p1 and p2,
    22  // m1 is the mass of p1 and m2 of p2. The vector v is the vector from p1 to
    23  // p2. The returned value is the force vector acting on p1.
    24  //
    25  // In models where the identity of particles must be known, p1 and p2 may be
    26  // compared. Force3 may be passed nil for p2 when the Barnes-Hut approximation
    27  // is being used. A nil p2 indicates that the second mass center is an
    28  // aggregate.
    29  type Force3 func(p1, p2 Particle3, m1, m2 float64, v r3.Vec) r3.Vec
    30  
    31  // Gravity3 returns a vector force on m1 by m2, equal to (m1⋅m2)/‖v‖²
    32  // in the directions of v. Gravity3 ignores the identity of the interacting
    33  // particles and returns a zero vector when the two particles are
    34  // coincident, but performs no other sanity checks.
    35  func Gravity3(_, _ Particle3, m1, m2 float64, v r3.Vec) r3.Vec {
    36  	d2 := v.X*v.X + v.Y*v.Y + v.Z*v.Z
    37  	if d2 == 0 {
    38  		return r3.Vec{}
    39  	}
    40  	return v.Scale((m1 * m2) / (d2 * math.Sqrt(d2)))
    41  }
    42  
    43  // Volume implements Barnes-Hut force approximation calculations.
    44  type Volume struct {
    45  	root bucket
    46  
    47  	Particles []Particle3
    48  }
    49  
    50  // NewVolume returns a new Volume. If the volume is too large to allow
    51  // particle coordinates to be distinguished due to floating point
    52  // precision limits, NewVolume will return a non-nil error.
    53  func NewVolume(p []Particle3) (*Volume, error) {
    54  	q := Volume{Particles: p}
    55  	err := q.Reset()
    56  	if err != nil {
    57  		return nil, err
    58  	}
    59  	return &q, nil
    60  }
    61  
    62  // Reset reconstructs the Barnes-Hut tree. Reset must be called if the
    63  // Particles field or elements of Particles have been altered, unless
    64  // ForceOn is called with theta=0 or no data structures have been
    65  // previously built. If the volume is too large to allow particle
    66  // coordinates to be distinguished due to floating point precision
    67  // limits, Reset will return a non-nil error.
    68  func (q *Volume) Reset() (err error) {
    69  	if len(q.Particles) == 0 {
    70  		q.root = bucket{}
    71  		return nil
    72  	}
    73  
    74  	q.root = bucket{
    75  		particle: q.Particles[0],
    76  		center:   q.Particles[0].Coord3(),
    77  		mass:     q.Particles[0].Mass(),
    78  	}
    79  	q.root.bounds.Min = q.root.center
    80  	q.root.bounds.Max = q.root.center
    81  	for _, e := range q.Particles[1:] {
    82  		c := e.Coord3()
    83  		if c.X < q.root.bounds.Min.X {
    84  			q.root.bounds.Min.X = c.X
    85  		}
    86  		if c.X > q.root.bounds.Max.X {
    87  			q.root.bounds.Max.X = c.X
    88  		}
    89  		if c.Y < q.root.bounds.Min.Y {
    90  			q.root.bounds.Min.Y = c.Y
    91  		}
    92  		if c.Y > q.root.bounds.Max.Y {
    93  			q.root.bounds.Max.Y = c.Y
    94  		}
    95  		if c.Z < q.root.bounds.Min.Z {
    96  			q.root.bounds.Min.Z = c.Z
    97  		}
    98  		if c.Z > q.root.bounds.Max.Z {
    99  			q.root.bounds.Max.Z = c.Z
   100  		}
   101  	}
   102  
   103  	defer func() {
   104  		switch r := recover(); r {
   105  		case nil:
   106  		case volumeTooBig:
   107  			err = volumeTooBig
   108  		default:
   109  			panic(r)
   110  		}
   111  	}()
   112  
   113  	// TODO(kortschak): Partially parallelise this by
   114  	// choosing the direction and using one of eight
   115  	// goroutines to work on each root octant.
   116  	for _, e := range q.Particles[1:] {
   117  		q.root.insert(e)
   118  	}
   119  	q.root.summarize()
   120  	return nil
   121  }
   122  
   123  var volumeTooBig = errors.New("barneshut: volume too big")
   124  
   125  // ForceOn returns a force vector on p given p's mass and the force function, f,
   126  // using the Barnes-Hut theta approximation parameter.
   127  //
   128  // Calls to f will include p in the p1 position and a non-nil p2 if the force
   129  // interaction is with a non-aggregate mass center, otherwise p2 will be nil.
   130  //
   131  // It is safe to call ForceOn concurrently.
   132  func (q *Volume) ForceOn(p Particle3, theta float64, f Force3) (force r3.Vec) {
   133  	var empty bucket
   134  	if theta > 0 && q.root != empty {
   135  		return q.root.forceOn(p, p.Coord3(), p.Mass(), theta, f)
   136  	}
   137  
   138  	// For the degenerate case, just iterate over the
   139  	// slice of particles rather than walking the tree.
   140  	var v r3.Vec
   141  	m := p.Mass()
   142  	pv := p.Coord3()
   143  	for _, e := range q.Particles {
   144  		v = v.Add(f(p, e, m, e.Mass(), e.Coord3().Sub(pv)))
   145  	}
   146  	return v
   147  }
   148  
   149  // bucket is an oct tree octant with Barnes-Hut extensions.
   150  type bucket struct {
   151  	particle Particle3
   152  
   153  	bounds r3.Box
   154  
   155  	nodes [8]*bucket
   156  
   157  	center r3.Vec
   158  	mass   float64
   159  }
   160  
   161  // insert inserts p into the subtree rooted at b.
   162  func (b *bucket) insert(p Particle3) {
   163  	if b.particle == nil {
   164  		for _, q := range b.nodes {
   165  			if q != nil {
   166  				b.passDown(p)
   167  				return
   168  			}
   169  		}
   170  		b.particle = p
   171  		b.center = p.Coord3()
   172  		b.mass = p.Mass()
   173  		return
   174  	}
   175  
   176  	b.passDown(p)
   177  	b.passDown(b.particle)
   178  	b.particle = nil
   179  	b.center = r3.Vec{}
   180  	b.mass = 0
   181  }
   182  
   183  func (b *bucket) passDown(p Particle3) {
   184  	dir := octantOf(b.bounds, p)
   185  	if b.nodes[dir] == nil {
   186  		b.nodes[dir] = &bucket{bounds: splitVolume(b.bounds, dir)}
   187  	}
   188  	b.nodes[dir].insert(p)
   189  }
   190  
   191  const (
   192  	lne = iota
   193  	lse
   194  	lsw
   195  	lnw
   196  	une
   197  	use
   198  	usw
   199  	unw
   200  )
   201  
   202  // octantOf returns which octant of b that p should be placed in.
   203  func octantOf(b r3.Box, p Particle3) int {
   204  	center := r3.Vec{
   205  		X: (b.Min.X + b.Max.X) / 2,
   206  		Y: (b.Min.Y + b.Max.Y) / 2,
   207  		Z: (b.Min.Z + b.Max.Z) / 2,
   208  	}
   209  	c := p.Coord3()
   210  	if checkBounds && (c.X < b.Min.X || b.Max.X < c.X || c.Y < b.Min.Y || b.Max.Y < c.Y || c.Z < b.Min.Z || b.Max.Z < c.Z) {
   211  		panic(fmt.Sprintf("p out of range %+v: %#v", b, p))
   212  	}
   213  	if c.X < center.X {
   214  		if c.Y < center.Y {
   215  			if c.Z < center.Z {
   216  				return lnw
   217  			} else {
   218  				return unw
   219  			}
   220  		} else {
   221  			if c.Z < center.Z {
   222  				return lsw
   223  			} else {
   224  				return usw
   225  			}
   226  		}
   227  	} else {
   228  		if c.Y < center.Y {
   229  			if c.Z < center.Z {
   230  				return lne
   231  			} else {
   232  				return une
   233  			}
   234  		} else {
   235  			if c.Z < center.Z {
   236  				return lse
   237  			} else {
   238  				return use
   239  			}
   240  		}
   241  	}
   242  }
   243  
   244  // splitVolume returns an octant subdivision of b in the given direction.
   245  func splitVolume(b r3.Box, dir int) r3.Box {
   246  	old := b
   247  	halfX := (b.Max.X - b.Min.X) / 2
   248  	halfY := (b.Max.Y - b.Min.Y) / 2
   249  	halfZ := (b.Max.Z - b.Min.Z) / 2
   250  	switch dir {
   251  	case lne:
   252  		b.Min.X += halfX
   253  		b.Max.Y -= halfY
   254  		b.Max.Z -= halfZ
   255  	case lse:
   256  		b.Min.X += halfX
   257  		b.Min.Y += halfY
   258  		b.Max.Z -= halfZ
   259  	case lsw:
   260  		b.Max.X -= halfX
   261  		b.Min.Y += halfY
   262  		b.Max.Z -= halfZ
   263  	case lnw:
   264  		b.Max.X -= halfX
   265  		b.Max.Y -= halfY
   266  		b.Max.Z -= halfZ
   267  	case une:
   268  		b.Min.X += halfX
   269  		b.Max.Y -= halfY
   270  		b.Min.Z += halfZ
   271  	case use:
   272  		b.Min.X += halfX
   273  		b.Min.Y += halfY
   274  		b.Min.Z += halfZ
   275  	case usw:
   276  		b.Max.X -= halfX
   277  		b.Min.Y += halfY
   278  		b.Min.Z += halfZ
   279  	case unw:
   280  		b.Max.X -= halfX
   281  		b.Max.Y -= halfY
   282  		b.Min.Z += halfZ
   283  	}
   284  	if b == old {
   285  		panic(volumeTooBig)
   286  	}
   287  	return b
   288  }
   289  
   290  // summarize updates node masses and centers of mass.
   291  func (b *bucket) summarize() (center r3.Vec, mass float64) {
   292  	for _, d := range &b.nodes {
   293  		if d == nil {
   294  			continue
   295  		}
   296  		c, m := d.summarize()
   297  		b.center.X += c.X * m
   298  		b.center.Y += c.Y * m
   299  		b.center.Z += c.Z * m
   300  		b.mass += m
   301  	}
   302  	b.center.X /= b.mass
   303  	b.center.Y /= b.mass
   304  	b.center.Z /= b.mass
   305  	return b.center, b.mass
   306  }
   307  
   308  // forceOn returns a force vector on p given p's mass m and the force
   309  // calculation function, using the Barnes-Hut theta approximation parameter.
   310  func (b *bucket) forceOn(p Particle3, pt r3.Vec, m, theta float64, f Force3) (vector r3.Vec) {
   311  	s := ((b.bounds.Max.X - b.bounds.Min.X) + (b.bounds.Max.Y - b.bounds.Min.Y) + (b.bounds.Max.Z - b.bounds.Min.Z)) / 3
   312  	d := math.Hypot(math.Hypot(pt.X-b.center.X, pt.Y-b.center.Y), pt.Z-b.center.Z)
   313  	if s/d < theta || b.particle != nil {
   314  		return f(p, b.particle, m, b.mass, b.center.Sub(pt))
   315  	}
   316  
   317  	var v r3.Vec
   318  	for _, d := range &b.nodes {
   319  		if d == nil {
   320  			continue
   321  		}
   322  		v = v.Add(d.forceOn(p, pt, m, theta, f))
   323  	}
   324  	return v
   325  }