LASs - Linear Algebra Routines on OmpSs  1.0.0
LASs
ddss_dgemm.c
Go to the documentation of this file.
1 #include "../include/lass.h"
2 
3 /**
4  *
5  * @file ddss_dgemm.c
6  *
7  * @brief LASs-DDSs ddss_dgemm routine.
8  *
9  * LASs-DDSs is a software package provided by:
10  * Barcelona Supercomputing Center - Centro Nacional de Supercomputacion
11  *
12  * @author Pedro Valero-Lara pedro.valero@bsc.es
13  * @date 2017-01-02
14  * @reviewer
15  * @modified
16  *
17  **/
18 
19 /**
20  *
21  * @ingroup DDSS
22  *
23  * Performs the matrix-matrix operation:
24  *
25  * C = ALPHA * op( A ) * op( B ) + BETA * C
26  *
27  * where op( X ) is one of:
28  *
29  * op( X ) = X or
30  * op( X ) = X**T
31  *
32  * ALPHA and BETA are scalars, and A, B and C are matrices,
33  * with op( A ) an M by K matrix, op( B ) a K by N matrix and C
34  * an M by N matrix.
35  *
36 **/
37 
38 /**
39  *
40  * @param[in]
41  * TRANS_A enum DDSS_TRANS.
42  * TRANS_A specifies the form of op( A ) to be used in
43  * the matrix multiplication as follows:
44  * - NoTrans: op( A ) = A.
45  * - Trans: op( A ) = A**T.
46  *
47  * @param[in]
48  * TRANS_B enum DDSS_TRANS.
49  * TRANS_B specifies the form of op( B ) to be used in
50  * the matrix multiplication as follows:
51  * - NoTrans: op( B ) = B.
52  * - Trans: op( B ) = B**T.
53  *
54  * @param[in]
55  * M int.
56  * M specifies the number of rows of the matrix op( A )
57  * and the number of rows of the matrix C.
58  * M must be greater than or equal to zero.
59  *
60  * @param[in]
61  * N int.
62  * N specifies the number of columns of the matrix op( B )
63  * and the number of columns of the matrix C.
64  * N must be greater than or equal to zero.
65  *
66  * @param[in]
67  * K int.
68  * K specifies the number of columns of the matrix op( A )
69  * and the number of rows of the matrix op( B ).
70  * K must be greater than or equal to zero.
71  *
72  * @param[in]
73  * ALPHA double.
74  *
75  * @param[in]
76  * A double *.
77  * A is a pointer to a matrix of dimension Ma ( rows ) by Ka
78  * ( columns ), where Ma is M and Ka is K when TRANS_A = NoTrans,
79  * and Ma is K and Ka is M otherwise.
80  *
81  * @param[in]
82  * LDA int.
83  * LDA specifies the number of columns of A ( row-major order ).
84  * When TRANS_A = NoTrans then LDA must be at least max( 1, K ),
85  * otherwise LDA must be at least max( 1, M ).
86  *
87  * @param[in]
88  * B double *.
89  * B is a pointer to a matrix of dimension Kb ( rows ) by Nb
90  * ( columns ), where Kb is K and Nb is N when TRANS_B = NoTrans,
91  * and Kb is N and Nb is K otherwise.
92  *
93  * @param[in]
94  * LDB int.
95  * LDB specifies the number of columns of B ( row-major order ).
96  * When TRANS_B = NoTrans then LDB must be at least max( 1, N ),
97  * otherwise LDB must be at least max( 1, K ).
98  *
99  * @param[in]
100  * BETA double.
101  *
102  * @param[in,out]
103  * C double *.
104  * C is a pointer to a matrix of dimension M by N.
105  * On exit, C is overwritten by the M by N
106  * matrix ( ALPHA*op( A )*op( B ) + BETA*C ).
107  *
108  * @param[in]
109  * LDC int.
110  * LDC specifies the number of columns of C ( row-major order ).
111  * LDC must be at least max( 1, N ).
112  *
113  **/
114 
115 /**
116  *
117  * @retval Success successful exit
118  * @retval NoSuccess unsuccessful exit
119  *
120  **/
121 
122 /**
123  *
124  * @sa kdgemm
125  *
126  **/
127 
128 int ddss_dgemm( enum DDSS_TRANS TRANS_A, enum DDSS_TRANS TRANS_B,
129  int M, int N, int K,
130  double ALPHA, double *A, int LDA,
131  double *B, int LDB,
132  double BETA, double *C, int LDC )
133 {
134 
135  // Local variables
136  int An, Bn;
137 
138  // Argument checking
139  if ( ( TRANS_A != NoTrans ) && ( TRANS_A != Trans ) )
140  {
141  fprintf( stderr, "Illegal value of TRANS_A, in ddss_dgemm code\n" );
142  return NoSuccess;
143  }
144 
145  if ( ( TRANS_B != NoTrans ) && ( TRANS_B != Trans ) )
146  {
147  fprintf( stderr, "Illegal value of TRANS_B, in ddss_dgemm code\n" );
148  return NoSuccess;
149  }
150 
151  if ( M < 0 )
152  {
153  fprintf( stderr, "Illegal value of M, in ddss_dgemm code\n" );
154  return NoSuccess;
155  }
156 
157  if ( N < 0 )
158  {
159  fprintf( stderr, "Illegal value of N, in ddss_dgemm code\n" );
160  return NoSuccess;
161  }
162 
163  if ( K < 0 )
164  {
165  fprintf( stderr, "Illegal value of K, in ddss_dgemm code\n" );
166  return NoSuccess;
167  }
168 
169  if ( TRANS_A == NoTrans )
170  {
171  An = K;
172  }
173  else
174  {
175  An = M;
176  }
177 
178  if ( LDA < MAX( 1, An ) )
179  {
180  fprintf( stderr, "Illegal value of LDA, in ddss_dgemm code\n" );
181  return NoSuccess;
182  }
183 
184  if ( TRANS_B == NoTrans )
185  {
186  Bn = N;
187  }
188  else
189  {
190  Bn = K;
191  }
192 
193  if ( LDB < MAX( 1, Bn ) )
194  {
195  fprintf( stderr, "Illegal value of LDB, in ddss_dgemm code\n" );
196  return NoSuccess;
197  }
198 
199  if ( LDC < MAX( 1, N ) )
200  {
201  fprintf( stderr, "Illegal value of LDC, in ddss_dgemm code\n" );
202  return NoSuccess;
203  }
204 
205  // Quick return
206  if ( M == 0 || N == 0 || ( ( ALPHA == 0.0 || K == 0 ) && BETA == 1.0 ) )
207  {
208  return Success;
209  }
210 
211  return kdgemm( TRANS_A, TRANS_B, M, N, K,
212  (const double) ALPHA, A, LDA,
213  B, LDB,
214  (const double) BETA, C, LDC );
215 
216 }
enum LASS_RETURN kdgemm(enum DDSS_TRANS TRANS_A, enum DDSS_TRANS TRANS_B, int M, int N, int K, const double ALPHA, double *A, int LDA, double *B, int LDB, const double BETA, double *C, int LDC)
Definition: kdgemm.c:131
int ddss_dgemm(enum DDSS_TRANS TRANS_A, enum DDSS_TRANS TRANS_B, int M, int N, int K, double ALPHA, double *A, int LDA, double *B, int LDB, double BETA, double *C, int LDC)
Definition: ddss_dgemm.c:128