gonum.org/v1/gonum@v0.14.0/blas/blas64/conv.go (about)

     1  // Copyright ©2015 The Gonum Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package blas64
     6  
     7  import "gonum.org/v1/gonum/blas"
     8  
// GeneralCols represents a matrix using the conventional column-major storage scheme.
// It shares its underlying struct with General, but Data is read column by
// column: element (i, j) is stored at Data[i+j*Stride].
type GeneralCols General
    11  
    12  // From fills the receiver with elements from a. The receiver
    13  // must have the same dimensions as a and have adequate backing
    14  // data storage.
    15  func (t GeneralCols) From(a General) {
    16  	if t.Rows != a.Rows || t.Cols != a.Cols {
    17  		panic("blas64: mismatched dimension")
    18  	}
    19  	if len(t.Data) < (t.Cols-1)*t.Stride+t.Rows {
    20  		panic("blas64: short data slice")
    21  	}
    22  	for i := 0; i < a.Rows; i++ {
    23  		for j, v := range a.Data[i*a.Stride : i*a.Stride+a.Cols] {
    24  			t.Data[i+j*t.Stride] = v
    25  		}
    26  	}
    27  }
    28  
    29  // From fills the receiver with elements from a. The receiver
    30  // must have the same dimensions as a and have adequate backing
    31  // data storage.
    32  func (t General) From(a GeneralCols) {
    33  	if t.Rows != a.Rows || t.Cols != a.Cols {
    34  		panic("blas64: mismatched dimension")
    35  	}
    36  	if len(t.Data) < (t.Rows-1)*t.Stride+t.Cols {
    37  		panic("blas64: short data slice")
    38  	}
    39  	for j := 0; j < a.Cols; j++ {
    40  		for i, v := range a.Data[j*a.Stride : j*a.Stride+a.Rows] {
    41  			t.Data[i*t.Stride+j] = v
    42  		}
    43  	}
    44  }
    45  
// TriangularCols represents a matrix using the conventional column-major storage scheme.
// It shares its underlying struct with Triangular; element (i, j) is stored at
// Data[i+j*Stride]. The From conversions only copy the part of the matrix
// selected by Uplo.
type TriangularCols Triangular
    48  
    49  // From fills the receiver with elements from a. The receiver
    50  // must have the same dimensions, uplo and diag as a and have
    51  // adequate backing data storage.
    52  func (t TriangularCols) From(a Triangular) {
    53  	if t.N != a.N {
    54  		panic("blas64: mismatched dimension")
    55  	}
    56  	if t.Uplo != a.Uplo {
    57  		panic("blas64: mismatched BLAS uplo")
    58  	}
    59  	if t.Diag != a.Diag {
    60  		panic("blas64: mismatched BLAS diag")
    61  	}
    62  	switch a.Uplo {
    63  	default:
    64  		panic("blas64: bad BLAS uplo")
    65  	case blas.Upper:
    66  		for i := 0; i < a.N; i++ {
    67  			for j := i; j < a.N; j++ {
    68  				t.Data[i+j*t.Stride] = a.Data[i*a.Stride+j]
    69  			}
    70  		}
    71  	case blas.Lower:
    72  		for i := 0; i < a.N; i++ {
    73  			for j := 0; j <= i; j++ {
    74  				t.Data[i+j*t.Stride] = a.Data[i*a.Stride+j]
    75  			}
    76  		}
    77  	case blas.All:
    78  		for i := 0; i < a.N; i++ {
    79  			for j := 0; j < a.N; j++ {
    80  				t.Data[i+j*t.Stride] = a.Data[i*a.Stride+j]
    81  			}
    82  		}
    83  	}
    84  }
    85  
    86  // From fills the receiver with elements from a. The receiver
    87  // must have the same dimensions, uplo and diag as a and have
    88  // adequate backing data storage.
    89  func (t Triangular) From(a TriangularCols) {
    90  	if t.N != a.N {
    91  		panic("blas64: mismatched dimension")
    92  	}
    93  	if t.Uplo != a.Uplo {
    94  		panic("blas64: mismatched BLAS uplo")
    95  	}
    96  	if t.Diag != a.Diag {
    97  		panic("blas64: mismatched BLAS diag")
    98  	}
    99  	switch a.Uplo {
   100  	default:
   101  		panic("blas64: bad BLAS uplo")
   102  	case blas.Upper:
   103  		for i := 0; i < a.N; i++ {
   104  			for j := i; j < a.N; j++ {
   105  				t.Data[i*t.Stride+j] = a.Data[i+j*a.Stride]
   106  			}
   107  		}
   108  	case blas.Lower:
   109  		for i := 0; i < a.N; i++ {
   110  			for j := 0; j <= i; j++ {
   111  				t.Data[i*t.Stride+j] = a.Data[i+j*a.Stride]
   112  			}
   113  		}
   114  	case blas.All:
   115  		for i := 0; i < a.N; i++ {
   116  			for j := 0; j < a.N; j++ {
   117  				t.Data[i*t.Stride+j] = a.Data[i+j*a.Stride]
   118  			}
   119  		}
   120  	}
   121  }
   122  
// BandCols represents a matrix using the band column-major storage scheme.
// It shares its underlying struct with Band. Element (i, j) inside the band
// (i-KL <= j <= i+KU) is stored at Data[j*Stride+i-j+KU], so column j uses
// the KL+KU+1 slots starting at Data[j*Stride].
type BandCols Band
   125  
   126  // From fills the receiver with elements from a. The receiver
   127  // must have the same dimensions and bandwidth as a and have
   128  // adequate backing data storage.
   129  func (t BandCols) From(a Band) {
   130  	if t.Rows != a.Rows || t.Cols != a.Cols {
   131  		panic("blas64: mismatched dimension")
   132  	}
   133  	if t.KL != a.KL || t.KU != a.KU {
   134  		panic("blas64: mismatched bandwidth")
   135  	}
   136  	if a.Stride < a.KL+a.KU+1 {
   137  		panic("blas64: short stride for source")
   138  	}
   139  	if t.Stride < t.KL+t.KU+1 {
   140  		panic("blas64: short stride for destination")
   141  	}
   142  	for i := 0; i < a.Rows; i++ {
   143  		for j := max(0, i-a.KL); j < min(i+a.KU+1, a.Cols); j++ {
   144  			t.Data[i+t.KU-j+j*t.Stride] = a.Data[j+a.KL-i+i*a.Stride]
   145  		}
   146  	}
   147  }
   148  
   149  // From fills the receiver with elements from a. The receiver
   150  // must have the same dimensions and bandwidth as a and have
   151  // adequate backing data storage.
   152  func (t Band) From(a BandCols) {
   153  	if t.Rows != a.Rows || t.Cols != a.Cols {
   154  		panic("blas64: mismatched dimension")
   155  	}
   156  	if t.KL != a.KL || t.KU != a.KU {
   157  		panic("blas64: mismatched bandwidth")
   158  	}
   159  	if a.Stride < a.KL+a.KU+1 {
   160  		panic("blas64: short stride for source")
   161  	}
   162  	if t.Stride < t.KL+t.KU+1 {
   163  		panic("blas64: short stride for destination")
   164  	}
   165  	for j := 0; j < a.Cols; j++ {
   166  		for i := max(0, j-a.KU); i < min(j+a.KL+1, a.Rows); i++ {
   167  			t.Data[j+a.KL-i+i*a.Stride] = a.Data[i+t.KU-j+j*t.Stride]
   168  		}
   169  	}
   170  }
   171  
// TriangularBandCols represents a triangular matrix using the band column-major storage scheme.
// It shares its underlying struct with TriangularBand. The From conversions
// treat it as a BandCols with KU = K for blas.Upper, or KL = K for blas.Lower,
// and with the other bandwidth set to zero.
type TriangularBandCols TriangularBand
   174  
   175  // From fills the receiver with elements from a. The receiver
   176  // must have the same dimensions, bandwidth and uplo as a and
   177  // have adequate backing data storage.
   178  func (t TriangularBandCols) From(a TriangularBand) {
   179  	if t.N != a.N {
   180  		panic("blas64: mismatched dimension")
   181  	}
   182  	if t.K != a.K {
   183  		panic("blas64: mismatched bandwidth")
   184  	}
   185  	if a.Stride < a.K+1 {
   186  		panic("blas64: short stride for source")
   187  	}
   188  	if t.Stride < t.K+1 {
   189  		panic("blas64: short stride for destination")
   190  	}
   191  	if t.Uplo != a.Uplo {
   192  		panic("blas64: mismatched BLAS uplo")
   193  	}
   194  	if t.Diag != a.Diag {
   195  		panic("blas64: mismatched BLAS diag")
   196  	}
   197  	dst := BandCols{
   198  		Rows: t.N, Cols: t.N,
   199  		Stride: t.Stride,
   200  		Data:   t.Data,
   201  	}
   202  	src := Band{
   203  		Rows: a.N, Cols: a.N,
   204  		Stride: a.Stride,
   205  		Data:   a.Data,
   206  	}
   207  	switch a.Uplo {
   208  	default:
   209  		panic("blas64: bad BLAS uplo")
   210  	case blas.Upper:
   211  		dst.KU = t.K
   212  		src.KU = a.K
   213  	case blas.Lower:
   214  		dst.KL = t.K
   215  		src.KL = a.K
   216  	}
   217  	dst.From(src)
   218  }
   219  
   220  // From fills the receiver with elements from a. The receiver
   221  // must have the same dimensions, bandwidth and uplo as a and
   222  // have adequate backing data storage.
   223  func (t TriangularBand) From(a TriangularBandCols) {
   224  	if t.N != a.N {
   225  		panic("blas64: mismatched dimension")
   226  	}
   227  	if t.K != a.K {
   228  		panic("blas64: mismatched bandwidth")
   229  	}
   230  	if a.Stride < a.K+1 {
   231  		panic("blas64: short stride for source")
   232  	}
   233  	if t.Stride < t.K+1 {
   234  		panic("blas64: short stride for destination")
   235  	}
   236  	if t.Uplo != a.Uplo {
   237  		panic("blas64: mismatched BLAS uplo")
   238  	}
   239  	if t.Diag != a.Diag {
   240  		panic("blas64: mismatched BLAS diag")
   241  	}
   242  	dst := Band{
   243  		Rows: t.N, Cols: t.N,
   244  		Stride: t.Stride,
   245  		Data:   t.Data,
   246  	}
   247  	src := BandCols{
   248  		Rows: a.N, Cols: a.N,
   249  		Stride: a.Stride,
   250  		Data:   a.Data,
   251  	}
   252  	switch a.Uplo {
   253  	default:
   254  		panic("blas64: bad BLAS uplo")
   255  	case blas.Upper:
   256  		dst.KU = t.K
   257  		src.KU = a.K
   258  	case blas.Lower:
   259  		dst.KL = t.K
   260  		src.KL = a.K
   261  	}
   262  	dst.From(src)
   263  }
   264  
   265  func min(a, b int) int {
   266  	if a < b {
   267  		return a
   268  	}
   269  	return b
   270  }
   271  
   272  func max(a, b int) int {
   273  	if a > b {
   274  		return a
   275  	}
   276  	return b
   277  }