gonum.org/v1/gonum@v0.14.0/mat/shadow.go (about) 1 // Copyright ©2015 The Gonum Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package mat 6 7 import "gonum.org/v1/gonum/blas/blas64" 8 9 // checkOverlap returns false if the receiver does not overlap data elements 10 // referenced by the parameter and panics otherwise. 11 // 12 // checkOverlap methods return a boolean to allow the check call to be added to a 13 // boolean expression, making use of short-circuit operators. 14 func checkOverlap(a, b blas64.General) bool { 15 if cap(a.Data) == 0 || cap(b.Data) == 0 { 16 return false 17 } 18 19 off := offset(a.Data[:1], b.Data[:1]) 20 21 if off == 0 { 22 // At least one element overlaps. 23 if a.Cols == b.Cols && a.Rows == b.Rows && a.Stride == b.Stride { 24 panic(regionIdentity) 25 } 26 panic(regionOverlap) 27 } 28 29 if off > 0 && len(a.Data) <= off { 30 // We know a is completely before b. 31 return false 32 } 33 if off < 0 && len(b.Data) <= -off { 34 // We know a is completely after b. 35 return false 36 } 37 38 if a.Stride != b.Stride && a.Stride != 1 && b.Stride != 1 { 39 // Too hard, so assume the worst; if either stride 40 // is one it will be caught in rectanglesOverlap. 41 panic(mismatchedStrides) 42 } 43 44 if off < 0 { 45 off = -off 46 a.Cols, b.Cols = b.Cols, a.Cols 47 } 48 if rectanglesOverlap(off, a.Cols, b.Cols, min(a.Stride, b.Stride)) { 49 panic(regionOverlap) 50 } 51 return false 52 } 53 54 func (m *Dense) checkOverlap(a blas64.General) bool { 55 return checkOverlap(m.RawMatrix(), a) 56 } 57 58 func (m *Dense) checkOverlapMatrix(a Matrix) bool { 59 if m == a { 60 return false 61 } 62 var amat blas64.General 63 switch ar := a.(type) { 64 default: 65 return false 66 case RawMatrixer: 67 amat = ar.RawMatrix() 68 case RawSymmetricer: 69 amat = generalFromSymmetric(ar.RawSymmetric()) 70 case RawSymBander: 71 amat = generalFromSymmetricBand(ar.RawSymBand()) 72 case RawTriangular: 73 amat = generalFromTriangular(ar.RawTriangular()) 74 case RawVectorer: 75 r, c := a.Dims() 76 amat = generalFromVector(ar.RawVector(), r, c) 77 } 78 return m.checkOverlap(amat) 79 } 80 81 func (s *SymDense) checkOverlap(a blas64.General) bool { 82 return checkOverlap(generalFromSymmetric(s.RawSymmetric()), a) 83 } 84 85 func (s *SymDense) checkOverlapMatrix(a Matrix) bool { 86 if s == a { 87 return false 88 } 89 var amat blas64.General 90 switch ar := a.(type) { 91 default: 92 return false 93 case RawMatrixer: 94 amat = ar.RawMatrix() 95 case RawSymmetricer: 96 amat = generalFromSymmetric(ar.RawSymmetric()) 97 case RawSymBander: 98 amat = generalFromSymmetricBand(ar.RawSymBand()) 99 case RawTriangular: 100 amat = generalFromTriangular(ar.RawTriangular()) 101 case RawVectorer: 102 r, c := a.Dims() 103 amat = generalFromVector(ar.RawVector(), r, c) 104 } 105 return s.checkOverlap(amat) 106 } 107 108 // generalFromSymmetric returns a blas64.General with the backing 109 // data and dimensions of a. 110 func generalFromSymmetric(a blas64.Symmetric) blas64.General { 111 return blas64.General{ 112 Rows: a.N, 113 Cols: a.N, 114 Stride: a.Stride, 115 Data: a.Data, 116 } 117 } 118 119 func (t *TriDense) checkOverlap(a blas64.General) bool { 120 return checkOverlap(generalFromTriangular(t.RawTriangular()), a) 121 } 122 123 func (t *TriDense) checkOverlapMatrix(a Matrix) bool { 124 if t == a { 125 return false 126 } 127 var amat blas64.General 128 switch ar := a.(type) { 129 default: 130 return false 131 case RawMatrixer: 132 amat = ar.RawMatrix() 133 case RawSymmetricer: 134 amat = generalFromSymmetric(ar.RawSymmetric()) 135 case RawSymBander: 136 amat = generalFromSymmetricBand(ar.RawSymBand()) 137 case RawTriangular: 138 amat = generalFromTriangular(ar.RawTriangular()) 139 case RawVectorer: 140 r, c := a.Dims() 141 amat = generalFromVector(ar.RawVector(), r, c) 142 } 143 return t.checkOverlap(amat) 144 } 145 146 // generalFromTriangular returns a blas64.General with the backing 147 // data and dimensions of a. 148 func generalFromTriangular(a blas64.Triangular) blas64.General { 149 return blas64.General{ 150 Rows: a.N, 151 Cols: a.N, 152 Stride: a.Stride, 153 Data: a.Data, 154 } 155 } 156 157 func (v *VecDense) checkOverlap(a blas64.Vector) bool { 158 mat := v.mat 159 if cap(mat.Data) == 0 || cap(a.Data) == 0 { 160 return false 161 } 162 163 off := offset(mat.Data[:1], a.Data[:1]) 164 165 if off == 0 { 166 // At least one element overlaps. 167 if mat.Inc == a.Inc && len(mat.Data) == len(a.Data) { 168 panic(regionIdentity) 169 } 170 panic(regionOverlap) 171 } 172 173 if off > 0 && len(mat.Data) <= off { 174 // We know v is completely before a. 175 return false 176 } 177 if off < 0 && len(a.Data) <= -off { 178 // We know v is completely after a. 179 return false 180 } 181 182 if mat.Inc != a.Inc && mat.Inc != 1 && a.Inc != 1 { 183 // Too hard, so assume the worst; if either 184 // increment is one it will be caught below. 185 panic(mismatchedStrides) 186 } 187 inc := min(mat.Inc, a.Inc) 188 189 if inc == 1 || off&inc == 0 { 190 panic(regionOverlap) 191 } 192 return false 193 } 194 195 // generalFromVector returns a blas64.General with the backing 196 // data and dimensions of a. 197 func generalFromVector(a blas64.Vector, r, c int) blas64.General { 198 return blas64.General{ 199 Rows: r, 200 Cols: c, 201 Stride: a.Inc, 202 Data: a.Data, 203 } 204 } 205 206 func (s *SymBandDense) checkOverlap(a blas64.General) bool { 207 return checkOverlap(generalFromSymmetricBand(s.RawSymBand()), a) 208 } 209 210 //lint:ignore U1000 This will be used when we do shadow checks for banded matrices. 211 func (s *SymBandDense) checkOverlapMatrix(a Matrix) bool { 212 if s == a { 213 return false 214 } 215 var amat blas64.General 216 switch ar := a.(type) { 217 default: 218 return false 219 case RawMatrixer: 220 amat = ar.RawMatrix() 221 case RawSymmetricer: 222 amat = generalFromSymmetric(ar.RawSymmetric()) 223 case RawSymBander: 224 amat = generalFromSymmetricBand(ar.RawSymBand()) 225 case RawTriangular: 226 amat = generalFromTriangular(ar.RawTriangular()) 227 case RawVectorer: 228 r, c := a.Dims() 229 amat = generalFromVector(ar.RawVector(), r, c) 230 } 231 return s.checkOverlap(amat) 232 } 233 234 // generalFromSymmetricBand returns a blas64.General with the backing 235 // data and dimensions of a. 236 func generalFromSymmetricBand(a blas64.SymmetricBand) blas64.General { 237 return blas64.General{ 238 Rows: a.N, 239 Cols: a.K + 1, 240 Data: a.Data, 241 Stride: a.Stride, 242 } 243 }