Visual Servoing Platform version 3.5.0
vpGEMM.h
1/****************************************************************************
2 *
3 * ViSP, open source Visual Servoing Platform software.
4 * Copyright (C) 2005 - 2019 by Inria. All rights reserved.
5 *
6 * This software is free software; you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation; either version 2 of the License, or
9 * (at your option) any later version.
10 * See the file LICENSE.txt at the root directory of this source
11 * distribution for additional information about the GNU GPL.
12 *
13 * For using ViSP with software that can not be combined with the GNU
14 * GPL, please contact Inria about acquiring a ViSP Professional
15 * Edition License.
16 *
17 * See http://visp.inria.fr for more information.
18 *
19 * This software was developed at:
20 * Inria Rennes - Bretagne Atlantique
21 * Campus Universitaire de Beaulieu
22 * 35042 Rennes Cedex
23 * France
24 *
25 * If you have questions regarding the use of this file, please contact
26 * Inria at visp@inria.fr
27 *
28 * This file is provided AS IS with NO WARRANTY OF ANY KIND, INCLUDING THE
29 * WARRANTY OF DESIGN, MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE.
30 *
31 * Description:
32 * Matrix generalized multiplication.
33 *
34 * Authors:
35 * Laneurit Jean
36 *
37 *****************************************************************************/
38
39#ifndef _vpGEMM_h_
40#define _vpGEMM_h_
41
42#include <visp3/core/vpArray2D.h>
43#include <visp3/core/vpException.h>
44
45const vpArray2D<double> null(0, 0);
46
57typedef enum {
58 VP_GEMM_A_T = 1,
59 VP_GEMM_B_T = 2,
60 VP_GEMM_C_T = 4,
62
63template <unsigned int>
64inline void GEMMsize(const vpArray2D<double> & /*A*/, const vpArray2D<double> & /*B*/, unsigned int & /*Arows*/,
65 unsigned int & /*Acols*/, unsigned int & /*Brows*/, unsigned int & /*Bcols*/)
66{
67}
68
69template <>
70void inline GEMMsize<0>(const vpArray2D<double> &A, const vpArray2D<double> &B, unsigned int &Arows,
71 unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols)
72{
73 Arows = A.getRows();
74 Acols = A.getCols();
75 Brows = B.getRows();
76 Bcols = B.getCols();
77}
78
79template <>
80inline void GEMMsize<1>(const vpArray2D<double> &A, const vpArray2D<double> &B, unsigned int &Arows,
81 unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols)
82{
83 Arows = A.getCols();
84 Acols = A.getRows();
85 Brows = B.getRows();
86 Bcols = B.getCols();
87}
88template <>
89inline void GEMMsize<2>(const vpArray2D<double> &A, const vpArray2D<double> &B, unsigned int &Arows,
90 unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols)
91{
92 Arows = A.getRows();
93 Acols = A.getCols();
94 Brows = B.getCols();
95 Bcols = B.getRows();
96}
97template <>
98inline void GEMMsize<3>(const vpArray2D<double> &A, const vpArray2D<double> &B, unsigned int &Arows,
99 unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols)
100{
101 Arows = A.getCols();
102 Acols = A.getRows();
103 Brows = B.getCols();
104 Bcols = B.getRows();
105}
106
107template <>
108inline void GEMMsize<4>(const vpArray2D<double> &A, const vpArray2D<double> &B, unsigned int &Arows,
109 unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols)
110{
111 Arows = A.getRows();
112 Acols = A.getCols();
113 Brows = B.getRows();
114 Bcols = B.getCols();
115}
116
117template <>
118inline void GEMMsize<5>(const vpArray2D<double> &A, const vpArray2D<double> &B, unsigned int &Arows,
119 unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols)
120{
121 Arows = A.getCols();
122 Acols = A.getRows();
123 Brows = B.getRows();
124 Bcols = B.getCols();
125}
126
127template <>
128inline void GEMMsize<6>(const vpArray2D<double> &A, const vpArray2D<double> &B, unsigned int &Arows,
129 unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols)
130{
131 Arows = A.getRows();
132 Acols = A.getCols();
133 Brows = B.getCols();
134 Bcols = B.getRows();
135}
136
137template <>
138inline void GEMMsize<7>(const vpArray2D<double> &A, const vpArray2D<double> &B, unsigned int &Arows,
139 unsigned int &Acols, unsigned int &Brows, unsigned int &Bcols)
140{
141 Arows = A.getCols();
142 Acols = A.getRows();
143 Brows = B.getCols();
144 Bcols = B.getRows();
145}
146
147template <unsigned int>
148inline void GEMM1(const unsigned int & /*Arows*/, const unsigned int & /*Brows*/, const unsigned int & /*Bcols*/,
149 const vpArray2D<double> & /*A*/, const vpArray2D<double> & /*B*/, const double & /*alpha*/,
150 vpArray2D<double> & /*D*/)
151{
152}
153
154template <>
155inline void GEMM1<0>(const unsigned int &Arows, const unsigned int &Brows, const unsigned int &Bcols,
156 const vpArray2D<double> &A, const vpArray2D<double> &B, const double &alpha, vpArray2D<double> &D)
157{
158 for (unsigned int r = 0; r < Arows; r++)
159 for (unsigned int c = 0; c < Bcols; c++) {
160 double sum = 0;
161 for (unsigned int n = 0; n < Brows; n++)
162 sum += A[r][n] * B[n][c] * alpha;
163 D[r][c] = sum;
164 }
165}
166
167template <>
168inline void GEMM1<1>(const unsigned int &Arows, const unsigned int &Brows, const unsigned int &Bcols,
169 const vpArray2D<double> &A, const vpArray2D<double> &B, const double &alpha, vpArray2D<double> &D)
170{
171 for (unsigned int r = 0; r < Arows; r++)
172 for (unsigned int c = 0; c < Bcols; c++) {
173 double sum = 0;
174 for (unsigned int n = 0; n < Brows; n++)
175 sum += A[n][r] * B[n][c] * alpha;
176 D[r][c] = sum;
177 }
178}
179
180template <>
181inline void GEMM1<2>(const unsigned int &Arows, const unsigned int &Brows, const unsigned int &Bcols,
182 const vpArray2D<double> &A, const vpArray2D<double> &B, const double &alpha, vpArray2D<double> &D)
183{
184 for (unsigned int r = 0; r < Arows; r++)
185 for (unsigned int c = 0; c < Bcols; c++) {
186 double sum = 0;
187 for (unsigned int n = 0; n < Brows; n++)
188 sum += A[r][n] * B[c][n] * alpha;
189 D[r][c] = sum;
190 }
191}
192
193template <>
194inline void GEMM1<3>(const unsigned int &Arows, const unsigned int &Brows, const unsigned int &Bcols,
195 const vpArray2D<double> &A, const vpArray2D<double> &B, const double &alpha, vpArray2D<double> &D)
196{
197 for (unsigned int r = 0; r < Arows; r++)
198 for (unsigned int c = 0; c < Bcols; c++) {
199 double sum = 0;
200 for (unsigned int n = 0; n < Brows; n++)
201 sum += A[n][r] * B[c][n] * alpha;
202 D[r][c] = sum;
203 }
204}
205
206template <unsigned int>
207inline void GEMM2(const unsigned int & /*Arows*/, const unsigned int & /*Brows*/, const unsigned int & /*Bcols*/,
208 const vpArray2D<double> & /*A*/, const vpArray2D<double> & /*B*/, const double & /*alpha*/,
209 const vpArray2D<double> & /*C*/, const double & /*beta*/, vpArray2D<double> & /*D*/)
210{
211}
212
213template <>
214inline void GEMM2<0>(const unsigned int &Arows, const unsigned int &Brows, const unsigned int &Bcols,
215 const vpArray2D<double> &A, const vpArray2D<double> &B, const double &alpha,
216 const vpArray2D<double> &C, const double &beta, vpArray2D<double> &D)
217{
218 for (unsigned int r = 0; r < Arows; r++)
219 for (unsigned int c = 0; c < Bcols; c++) {
220 double sum = 0;
221 for (unsigned int n = 0; n < Brows; n++)
222 sum += A[r][n] * B[n][c] * alpha;
223 D[r][c] = sum + C[r][c] * beta;
224 }
225}
226
227template <>
228inline void GEMM2<1>(const unsigned int &Arows, const unsigned int &Brows, const unsigned int &Bcols,
229 const vpArray2D<double> &A, const vpArray2D<double> &B, const double &alpha,
230 const vpArray2D<double> &C, const double &beta, vpArray2D<double> &D)
231{
232 for (unsigned int r = 0; r < Arows; r++)
233 for (unsigned int c = 0; c < Bcols; c++) {
234 double sum = 0;
235 for (unsigned int n = 0; n < Brows; n++)
236 sum += A[n][r] * B[n][c] * alpha;
237 D[r][c] = sum + C[r][c] * beta;
238 }
239}
240
241template <>
242inline void GEMM2<2>(const unsigned int &Arows, const unsigned int &Brows, const unsigned int &Bcols,
243 const vpArray2D<double> &A, const vpArray2D<double> &B, const double &alpha,
244 const vpArray2D<double> &C, const double &beta, vpArray2D<double> &D)
245{
246 for (unsigned int r = 0; r < Arows; r++)
247 for (unsigned int c = 0; c < Bcols; c++) {
248 double sum = 0;
249 for (unsigned int n = 0; n < Brows; n++)
250 sum += A[r][n] * B[c][n] * alpha;
251 D[r][c] = sum + C[r][c] * beta;
252 }
253}
254
255template <>
256inline void GEMM2<3>(const unsigned int &Arows, const unsigned int &Brows, const unsigned int &Bcols,
257 const vpArray2D<double> &A, const vpArray2D<double> &B, const double &alpha,
258 const vpArray2D<double> &C, const double &beta, vpArray2D<double> &D)
259{
260 for (unsigned int r = 0; r < Arows; r++)
261 for (unsigned int c = 0; c < Bcols; c++) {
262 double sum = 0;
263 for (unsigned int n = 0; n < Brows; n++)
264 sum += A[n][r] * B[c][n] * alpha;
265 D[r][c] = sum + C[r][c] * beta;
266 }
267}
268
269template <>
270inline void GEMM2<4>(const unsigned int &Arows, const unsigned int &Brows, const unsigned int &Bcols,
271 const vpArray2D<double> &A, const vpArray2D<double> &B, const double &alpha,
272 const vpArray2D<double> &C, const double &beta, vpArray2D<double> &D)
273{
274 for (unsigned int r = 0; r < Arows; r++)
275 for (unsigned int c = 0; c < Bcols; c++) {
276 double sum = 0;
277 for (unsigned int n = 0; n < Brows; n++)
278 sum += A[r][n] * B[n][c] * alpha;
279 D[r][c] = sum + C[c][r] * beta;
280 }
281}
282
283template <>
284inline void GEMM2<5>(const unsigned int &Arows, const unsigned int &Brows, const unsigned int &Bcols,
285 const vpArray2D<double> &A, const vpArray2D<double> &B, const double &alpha,
286 const vpArray2D<double> &C, const double &beta, vpArray2D<double> &D)
287{
288 for (unsigned int r = 0; r < Arows; r++)
289 for (unsigned int c = 0; c < Bcols; c++) {
290 double sum = 0;
291 for (unsigned int n = 0; n < Brows; n++)
292 sum += A[n][r] * B[n][c] * alpha;
293 D[r][c] = sum + C[c][r] * beta;
294 }
295}
296
297template <>
298inline void GEMM2<6>(const unsigned int &Arows, const unsigned int &Brows, const unsigned int &Bcols,
299 const vpArray2D<double> &A, const vpArray2D<double> &B, const double &alpha,
300 const vpArray2D<double> &C, const double &beta, vpArray2D<double> &D)
301{
302 for (unsigned int r = 0; r < Arows; r++)
303 for (unsigned int c = 0; c < Bcols; c++) {
304 double sum = 0;
305 for (unsigned int n = 0; n < Brows; n++)
306 sum += A[r][n] * B[c][n] * alpha;
307 D[r][c] = sum + C[c][r] * beta;
308 }
309}
310
311template <>
312inline void GEMM2<7>(const unsigned int &Arows, const unsigned int &Brows, const unsigned int &Bcols,
313 const vpArray2D<double> &A, const vpArray2D<double> &B, const double &alpha,
314 const vpArray2D<double> &C, const double &beta, vpArray2D<double> &D)
315{
316 for (unsigned int r = 0; r < Arows; r++)
317 for (unsigned int c = 0; c < Bcols; c++) {
318 double sum = 0;
319 for (unsigned int n = 0; n < Brows; n++)
320 sum += A[n][r] * B[c][n] * alpha;
321 D[r][c] = sum + C[c][r] * beta;
322 }
323}
324
325template <unsigned int T>
326inline void vpTGEMM(const vpArray2D<double> &A, const vpArray2D<double> &B, const double &alpha,
327 const vpArray2D<double> &C, const double &beta, vpArray2D<double> &D)
328{
329 unsigned int Arows;
330 unsigned int Acols;
331 unsigned int Brows;
332 unsigned int Bcols;
333
334 GEMMsize<T>(A, B, Arows, Acols, Brows, Bcols);
335
336 try {
337 if ((Arows != D.getRows()) || (Bcols != D.getCols()))
338 D.resize(Arows, Bcols);
339 } catch (...) {
340 throw;
341 }
342
343 if (Acols != Brows) {
344 throw(vpException(vpException::dimensionError, "In vpGEMM, cannot multiply (%dx%d) matrix by (%dx%d) matrix", Arows,
345 Acols, Brows, Bcols));
346 }
347
348 if (C.getRows() != 0 && C.getCols() != 0) {
349 if ((Arows != C.getRows()) || (Bcols != C.getCols())) {
350 throw(vpException(vpException::dimensionError, "In vpGEMM, cannot add resulting (%dx%d) matrix to (%dx%d) matrix",
351 Arows, Bcols, C.getRows(), C.getCols()));
352 }
353
354 GEMM2<T>(Arows, Brows, Bcols, A, B, alpha, C, beta, D);
355 } else {
356 GEMM1<T>(Arows, Brows, Bcols, A, B, alpha, D);
357 }
358}
359
393inline void vpGEMM(const vpArray2D<double> &A, const vpArray2D<double> &B, const double &alpha,
394 const vpArray2D<double> &C, const double &beta, vpArray2D<double> &D, const unsigned int &ops = 0)
395{
396 switch (ops) {
397 case 0:
398 vpTGEMM<0>(A, B, alpha, C, beta, D);
399 break;
400 case 1:
401 vpTGEMM<1>(A, B, alpha, C, beta, D);
402 break;
403 case 2:
404 vpTGEMM<2>(A, B, alpha, C, beta, D);
405 break;
406 case 3:
407 vpTGEMM<3>(A, B, alpha, C, beta, D);
408 break;
409 case 4:
410 vpTGEMM<4>(A, B, alpha, C, beta, D);
411 break;
412 case 5:
413 vpTGEMM<5>(A, B, alpha, C, beta, D);
414 break;
415 case 6:
416 vpTGEMM<6>(A, B, alpha, C, beta, D);
417 break;
418 case 7:
419 vpTGEMM<7>(A, B, alpha, C, beta, D);
420 break;
421 default:
422 throw(vpException(vpException::functionNotImplementedError, "Operation on vpGEMM not implemented"));
423 break;
424 }
425}
426
427#endif
unsigned int getCols() const
Definition: vpArray2D.h:279
void vpGEMM(const vpArray2D< double > &A, const vpArray2D< double > &B, const double &alpha, const vpArray2D< double > &C, const double &beta, vpArray2D< double > &D, const unsigned int &ops=0)
Definition: vpGEMM.h:393
vpGEMMmethod
Definition: vpGEMM.h:57
unsigned int getRows() const
Definition: vpArray2D.h:289
Error that can be emitted by ViSP classes.
Definition: vpException.h:72
@ functionNotImplementedError
Function not implemented.
Definition: vpException.h:90
@ dimensionError
Bad dimension.
Definition: vpException.h:95