Page Menu
Home
c4science
Search
Configure Global Search
Log In
Files
F88373802
dgemm-perm-4x4x2.c
No One
Temporary
Actions
Download File
Edit File
Delete File
View Transforms
Subscribe
Mute Notifications
Award Token
Subscribers
None
File Metadata
Details
File Info
Storage
Attached
Created
Fri, Oct 18, 11:23
Size
9 KB
Mime Type
text/x-c
Expires
Sun, Oct 20, 11:23 (2 d)
Engine
blob
Format
Raw Data
Handle
21761660
Attached To
R31 arm64-hpc
dgemm-perm-4x4x2.c
View Options
//
//
//
#include <stdio.h>
#include <stdlib.h>
#include <sys/time.h>
//
#warning "arm neon"
#include <arm_neon.h>
#include "gemm_ncopy_4.c"
#include "gemm_tcopy_4.c"
/*
 * Cache-blocking sizes used by dgemm() for the packing buffers and loops.
 *
 * Defining BLOCK_SIZE forces one size for all three dimensions. Otherwise
 * each of M_BLOCK_SIZE / N_BLOCK_SIZE / K_BLOCK_SIZE can be overridden on its
 * own (e.g. -DN_BLOCK_SIZE=1024) and defaults to 300 / 2000 / 300.
 */
#ifdef BLOCK_SIZE
#  define M_BLOCK_SIZE BLOCK_SIZE
#  define N_BLOCK_SIZE BLOCK_SIZE
#  define K_BLOCK_SIZE BLOCK_SIZE
#else
#  ifndef M_BLOCK_SIZE
#    define M_BLOCK_SIZE 300
#  endif
#  ifndef N_BLOCK_SIZE
#    define N_BLOCK_SIZE 2000
#  endif
#  ifndef K_BLOCK_SIZE
#    define K_BLOCK_SIZE 300
#  endif
#endif

/* Smaller of a and b (returns b when they compare equal or unordered).
 * Arguments may be evaluated twice, so pass side-effect-free expressions. */
#define min(a, b) (((a) < (b)) ? (a) : (b))
/*
 * Current wall-clock time in seconds (with microsecond resolution), for
 * measuring elapsed intervals by subtracting two readings.
 *
 * Returns 0.0 if gettimeofday() fails, instead of reading an uninitialized
 * timeval.
 */
double myseconds(void)
{
    struct timeval tp;

    /* The timezone argument is obsolete; POSIX only specifies behavior for
     * a null pointer, and struct timezone is not declared in strict-ISO
     * builds. */
    if (gettimeofday(&tp, NULL) != 0)
        return 0.0;

    return (double) tp.tv_sec + (double) tp.tv_usec * 1.e-6;
}
//
// Short aliases for the AArch64 NEON double-precision type and intrinsics
// used by the kernel code in this file.
// NOTE(review): the __ARM_NEON__ guard below is commented out, so these are
// defined unconditionally (the file also includes <arm_neon.h> unguarded).
//#ifdef __ARM_NEON__
#define vtype float64x2_t            // 128-bit vector of 2 doubles
#define vtype_2 float64x2x2_t        // pair of such vectors; not referenced elsewhere in this file
#define PREFETCH __builtin_prefetch  // GCC/Clang cache-prefetch hint
#define LOAD vld1q_f64               // load 2 consecutive doubles into a vtype
#define STORE vst1q_f64              // store a vtype to 2 consecutive doubles
//#endif
/*
 * Return a 2-lane double vector with both lanes set to val.
 *
 * vdupq_n_f64 broadcasts in one instruction. Building the vector lane by lane
 * with vsetq_lane_f64 starting from an uninitialized local read an
 * indeterminate value, which is undefined behavior in C.
 */
vtype set_vector(double val)
{
    return vdupq_n_f64(val);
}
//
//vtype set_//
//
//
/*
 * Debug helper: print the two double lanes of vec on one line,
 * lane 0 first, each with "%f".
 */
void print128(vtype vec)
{
    const double lane0 = vgetq_lane_f64(vec, 0);
    const double lane1 = vgetq_lane_f64(vec, 1);

    printf("%f %f\n", lane0, lane1);
}
/*
 * One k step of the 4x4 micro-kernel, accumulated into acc[0..7].
 *
 * a_lo = [a0, a1], a_hi = [a2, a3] hold four A rows; b_lo = [b0, b1],
 * b_hi = [b2, b3] hold four B columns. Rather than broadcasting A lanes, A is
 * multiplied both as-is and lane-swapped (vextq_f64), which produces every
 * a_r * b_c product in some lane. The trailing comments record which; the
 * write-back in dgemm_kernel_4x4 relies on exactly this mapping.
 */
static inline void dgemm_fma_step(vtype acc[8], vtype a_lo, vtype a_hi,
                                  vtype b_lo, vtype b_hi)
{
    const vtype a_lo_sw = vextq_f64(a_lo, a_lo, 1);   /* [a1, a0] */
    const vtype a_hi_sw = vextq_f64(a_hi, a_hi, 1);   /* [a3, a2] */

    acc[0] = vfmaq_f64(acc[0], b_lo, a_lo);      /* [a0*b0, a1*b1] */
    acc[1] = vfmaq_f64(acc[1], b_hi, a_lo);      /* [a0*b2, a1*b3] */
    acc[2] = vfmaq_f64(acc[2], b_lo, a_lo_sw);   /* [a1*b0, a0*b1] */
    acc[3] = vfmaq_f64(acc[3], b_hi, a_lo_sw);   /* [a1*b2, a0*b3] */
    acc[4] = vfmaq_f64(acc[4], b_lo, a_hi);      /* [a2*b0, a3*b1] */
    acc[5] = vfmaq_f64(acc[5], b_hi, a_hi);      /* [a2*b2, a3*b3] */
    acc[6] = vfmaq_f64(acc[6], b_lo, a_hi_sw);   /* [a3*b0, a2*b1] */
    acc[7] = vfmaq_f64(acc[7], b_hi, a_hi_sw);   /* [a3*b2, a2*b3] */
}

/*
 * 4x4 micro-kernel: Ct(r, c) += sum_{k < Kb} pA[4k + r] * pB[4k + c]
 * for r, c in 0..3.
 *
 * pA - packed A panel: 4 doubles (rows i..i+3) per k step, Kb steps.
 * pB - packed B panel: 4 doubles (columns j..j+3) per k step, Kb steps.
 * Ct - address of C(i, j); C is column-major with leading dimension ldc.
 *
 * NOTE(review): the panel layout is inferred from how dgemm indexes Ab/Bb
 * (offset i*Kb / j*Kb, 4 doubles per k step); it must match what
 * tcopy_4 / ncopy_4 produce — confirm in gemm_tcopy_4.c / gemm_ncopy_4.c.
 */
static void dgemm_kernel_4x4(int Kb, const double *pA, const double *pB,
                             double *Ct, int ldc)
{
    const size_t ld = (size_t) ldc;
    vtype acc[8];
    vtype a_lo, a_hi, b_lo, b_hi;       /* operands of the current k step */
    vtype a_lo1, a_hi1, b_lo1, b_hi1;   /* operands of the following k step */
    int k = Kb;

    if (k <= 0)
        return;

    for (int n = 0; n < 8; ++n)
        acc[n] = vdupq_n_f64(0.0);

    a_lo = LOAD(pA + 0);
    a_hi = LOAD(pA + 2);
    b_lo = LOAD(pB + 0);
    b_hi = LOAD(pB + 2);

    /*
     * Software-pipelined main loop: two k steps per iteration, plus a
     * look-ahead load of the step after them. It runs only while that
     * look-ahead is still inside the panel (k > 2), so it terminates for odd
     * Kb and never reads past the end of the packed buffers.
     */
    while (k > 2) {
        /* Cache hints 512 bytes ahead in both panels. */
        PREFETCH((const char *) pA + 512);
        PREFETCH((const char *) pB + 512);

        a_lo1 = LOAD(pA + 4);
        a_hi1 = LOAD(pA + 6);
        b_lo1 = LOAD(pB + 4);
        b_hi1 = LOAD(pB + 6);
        dgemm_fma_step(acc, a_lo, a_hi, b_lo, b_hi);

        a_lo = LOAD(pA + 8);
        a_hi = LOAD(pA + 10);
        b_lo = LOAD(pB + 8);
        b_hi = LOAD(pB + 10);
        dgemm_fma_step(acc, a_lo1, a_hi1, b_lo1, b_hi1);

        pA += 8;
        pB += 8;
        k -= 2;
    }

    /* Tail: one or two k steps remain; operands of the first are loaded. */
    dgemm_fma_step(acc, a_lo, a_hi, b_lo, b_hi);
    if (k == 2) {
        a_lo = LOAD(pA + 4);
        a_hi = LOAD(pA + 6);
        b_lo = LOAD(pB + 4);
        b_hi = LOAD(pB + 6);
        dgemm_fma_step(acc, a_lo, a_hi, b_lo, b_hi);
    }

    /* Scatter lanes back to C(row, col) = Ct[col * ld + row] using the
     * product mapping documented in dgemm_fma_step. */
    Ct[0 * ld + 0] += vgetq_lane_f64(acc[0], 0);
    Ct[1 * ld + 0] += vgetq_lane_f64(acc[2], 1);
    Ct[2 * ld + 0] += vgetq_lane_f64(acc[1], 0);
    Ct[3 * ld + 0] += vgetq_lane_f64(acc[3], 1);

    Ct[0 * ld + 1] += vgetq_lane_f64(acc[2], 0);
    Ct[1 * ld + 1] += vgetq_lane_f64(acc[0], 1);
    Ct[2 * ld + 1] += vgetq_lane_f64(acc[3], 0);
    Ct[3 * ld + 1] += vgetq_lane_f64(acc[1], 1);

    Ct[0 * ld + 2] += vgetq_lane_f64(acc[4], 0);
    Ct[1 * ld + 2] += vgetq_lane_f64(acc[6], 1);
    Ct[2 * ld + 2] += vgetq_lane_f64(acc[5], 0);
    Ct[3 * ld + 2] += vgetq_lane_f64(acc[7], 1);

    Ct[0 * ld + 3] += vgetq_lane_f64(acc[6], 0);
    Ct[1 * ld + 3] += vgetq_lane_f64(acc[4], 1);
    Ct[2 * ld + 3] += vgetq_lane_f64(acc[7], 0);
    Ct[3 * ld + 3] += vgetq_lane_f64(acc[5], 1);
}

/*
 * Scalar C += A*B for the part of an Mb x Nb block not covered by full 4x4
 * tiles: rows >= Mb4 of every column, and all rows of columns >= Nb4.
 *
 * Reads A and B unpacked (column-major), so it does not depend on how the
 * packing routines lay out partial panels.
 * Ablk -> A(ib, kb), Bblk -> B(kb, jb), Cblk -> C(ib, jb).
 * Returns the number of C elements updated.
 */
static long dgemm_edges(int Mb, int Nb, int Kb, int Mb4, int Nb4,
                        const double *Ablk, int lda,
                        const double *Bblk, int ldb,
                        double *Cblk, int ldc)
{
    long updated = 0;

    for (int j = 0; j < Nb; ++j) {
        const double *b = Bblk + (size_t) j * ldb;
        double *c = Cblk + (size_t) j * ldc;

        /* Columns covered by tiles only miss their bottom Mb - Mb4 rows. */
        for (int i = (j < Nb4) ? Mb4 : 0; i < Mb; ++i) {
            double s = 0.0;
            for (int k = 0; k < Kb; ++k)
                s += Ablk[(size_t) k * lda + i] * b[k];
            c[i] += s;
            ++updated;
        }
    }
    return updated;
}

/*
 * Blocked, packed double-precision matrix multiply for AArch64 NEON:
 *     C(0:M, 0:N) += A(0:M, 0:K) * B(0:K, 0:N)
 * with A, B and C column-major (element (r, c) of X at X[c * ldX + r]).
 *
 * For each K_BLOCK_SIZE x N_BLOCK_SIZE block of B (packed into Bb by ncopy_4)
 * and M_BLOCK_SIZE x K_BLOCK_SIZE block of A (packed into Ab by tcopy_4):
 * - full 4x4 tiles of C are updated by the NEON micro-kernel;
 * - leftover rows/columns (M or N not a multiple of 4) use a scalar loop.
 * At the end it prints copy time/bandwidth, compute time/GFlops and the
 * ops/mem counters to stdout. Compute time covers the kernels including
 * their C write-back.
 *
 * On allocation failure it reports to stderr and returns with C unchanged.
 *
 * NOTE(review): alpha and beta are accepted for a BLAS-style signature but are
 * NOT applied; the result is always C += A*B. Results match BLAS dgemm only
 * for alpha == 1 and beta == 1. TODO decide whether to implement them.
 */
void dgemm(const int M, const int N, const int K,
           const double alpha, const double *A, const int lda,
           const double *B, const int ldb,
           const double beta, double *C, const int ldc)
{
    (void) alpha;
    (void) beta;

    double *Ab = malloc((size_t) M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof *Ab);
    double *Bb = malloc((size_t) K_BLOCK_SIZE * N_BLOCK_SIZE * sizeof *Bb);
    if (Ab == NULL || Bb == NULL) {
        fprintf(stderr, "dgemm: cannot allocate packing buffers\n");
        free(Ab);
        free(Bb);
        return;
    }

    double copytime = 0.;
    double computetime = 0.;
    long int ops = 0, mem = 0;

    //#pragma omp parallel for private(i, j, k, kb, jb, ib)
    for (int kb = 0; kb < K; kb += K_BLOCK_SIZE) {
        const int Kb = min(K_BLOCK_SIZE, K - kb);

        for (int jb = 0; jb < N; jb += N_BLOCK_SIZE) {
            const int Nb = min(N_BLOCK_SIZE, N - jb);
            const int Nb4 = Nb - Nb % 4;   /* columns covered by 4x4 tiles */
            const double *Bblk = B + (size_t) jb * ldb + kb;

            copytime -= myseconds();
            ncopy_4(Kb, Nb, Bblk, ldb, Bb);
            copytime += myseconds();
            mem += (long) Kb * Nb * 8;

            for (int ib = 0; ib < M; ib += M_BLOCK_SIZE) {
                const int Mb = min(M_BLOCK_SIZE, M - ib);
                const int Mb4 = Mb - Mb % 4;   /* rows covered by 4x4 tiles */
                const double *Ablk = A + (size_t) kb * lda + ib;
                double *Cblk = C + (size_t) jb * ldc + ib;

                copytime -= myseconds();
                tcopy_4(Kb, Mb, Ablk, lda, Ab);
                copytime += myseconds();
                mem += (long) Kb * Mb * 8;

                for (int i = 0; i < Mb4; i += 4) {
                    for (int j = 0; j < Nb4; j += 4) {
                        double *Ct = Cblk + (size_t) j * ldc + i;

                        /* Warm the four C columns this tile updates. */
                        for (int c = 0; c < 4; ++c)
                            PREFETCH(Ct + (size_t) c * ldc);

                        computetime -= myseconds();
                        dgemm_kernel_4x4(Kb, &Ab[(size_t) i * Kb],
                                         &Bb[(size_t) j * Kb], Ct, ldc);
                        computetime += myseconds();

                        ops += 2L * 16 * Kb;   /* 16 multiply-adds per k step */
                        mem += 16 * 8;         /* 16 doubles of C updated */
                    }
                }

                if (Mb4 < Mb || Nb4 < Nb) {
                    computetime -= myseconds();
                    const long edge = dgemm_edges(Mb, Nb, Kb, Mb4, Nb4,
                                                  Ablk, lda, Bblk, ldb,
                                                  Cblk, ldc);
                    computetime += myseconds();

                    ops += 2L * edge * Kb;
                    mem += edge * 8;
                }
            }
        }
    }

    free(Ab);
    free(Bb);

    printf("copy time = %f, %f GB/s\n", copytime,
           mem / copytime / 1024 / 1024 / 1024);
    printf("compute time = %f, %f GFlops\n", computetime,
           2. * M * N * K / computetime / 1.e9);
    printf("ops = %ld\n", ops);
    printf("mem = %ld\n", mem);
}
Event Timeline
Log In to Comment