/*
 * dgemm-perm-4x4x2.c
 *
 * Page header text captured from the c4science file viewer:
 *   "No One", "Temporary", "File Metadata", "Created Fri, Oct 18, 11:23".
 */

//
//
//
#include <stdio.h>
#include <stdlib.h>
//
#warning "arm neon"
#include <arm_neon.h>
#include <sys/time.h>
#include "gemm_ncopy_4.c"
#include "gemm_tcopy_4.c"
/*
 * Cache-blocking sizes used by dgemm.
 * If BLOCK_SIZE is defined, it is used for all three dimensions.
 * Otherwise M_BLOCK_SIZE, N_BLOCK_SIZE and K_BLOCK_SIZE may each be
 * overridden individually; the defaults are 300, 2000 and 300.
 */
#ifdef BLOCK_SIZE
#define M_BLOCK_SIZE BLOCK_SIZE
#define N_BLOCK_SIZE BLOCK_SIZE
#define K_BLOCK_SIZE BLOCK_SIZE
#else
#ifndef M_BLOCK_SIZE
#define M_BLOCK_SIZE 300
#endif
#ifndef N_BLOCK_SIZE
#define N_BLOCK_SIZE 2000
#endif
#ifndef K_BLOCK_SIZE
#define K_BLOCK_SIZE 300
#endif
#endif
/*
 * Smaller of two values: yields a when a < b, otherwise b.
 * Both arguments may be evaluated twice, so pass only expressions
 * without side effects.
 * (The x86 SSE prefetch/store helpers that used to be commented out here
 * are not needed on this NEON path.)
 */
#define min(a,b) (!((a) < (b)) ? (b) : (a))
/*
 * Current wall-clock time in seconds, with microsecond resolution, read
 * from gettimeofday().
 *
 * Returns 0.0 if the clock cannot be read, so callers never see a value
 * built from an uninitialised struct.
 */
double myseconds(void)
{
    struct timeval tp;

    /* A non-NULL timezone argument is obsolete (unspecified by POSIX),
       so none is requested. */
    if (gettimeofday(&tp, NULL) != 0)
        return 0.0;
    return (double) tp.tv_sec + (double) tp.tv_usec * 1.e-6;
}
//
/*
 * Short aliases for the AArch64 NEON double-precision type and intrinsics
 * used by the kernels in this file (two doubles per 128-bit register).
 */
#define vtype    float64x2_t    /* one 128-bit register holding 2 doubles */
#define vtype_2  float64x2x2_t  /* a pair of such registers               */
#define LOAD     vld1q_f64      /* load 2 consecutive doubles             */
#define STORE    vst1q_f64      /* store 2 consecutive doubles            */
#define PREFETCH __builtin_prefetch
/*
 * Return a two-lane vector with both lanes set to val.
 *
 * vdupq_n_f64 builds the result directly. The previous version inserted
 * lanes into an uninitialised vector, which reads an indeterminate value.
 */
vtype set_vector(double val)
{
    return vdupq_n_f64(val);
}
//
/*
 * Debug helper: print both lanes of a two-double NEON vector on one line.
 *
 * Lanes are read with the ACLE accessor vgetq_lane_f64 rather than by
 * subscripting the vector type, which is a GCC/Clang extension.
 */
void print128(vtype vec)
{
    printf("%f %f\n", vgetq_lane_f64(vec, 0), vgetq_lane_f64(vec, 1));
}
/*
 * Scale the M x N column-major matrix C (leading dimension ldc) by beta.
 *
 * - beta == 1 leaves C untouched.
 * - beta == 0 stores zeros without reading C. This follows the BLAS
 *   convention: C need not be initialised, and NaN/Inf already in C must
 *   not leak into the result.
 */
static void dgemm_scale_c(const int M, const int N, const double beta,
                          double *C, const int ldc)
{
    if (beta == 1.0)
        return;
    for (int j = 0; j < N; ++j) {
        double *cj = C + (size_t) j * ldc;
        for (int i = 0; i < M; ++i)
            cj[i] = (beta == 0.0) ? 0.0 : beta * cj[i];
    }
}

/*
 * One k-step of the 4x4 register-blocked update.
 *
 * pa points at 4 packed values A(i0..i3, k) and pb at 4 packed values
 * B(k, j0..j3). This is the layout the original kernel read. It is assumed
 * to be what tcopy_4/ncopy_4 produce for full 4-wide panels; verify
 * against gemm_tcopy_4.c / gemm_ncopy_4.c.
 *
 * Instead of broadcasting A, the A pair is also used lane-swapped
 * (vextq_f64). Each accumulator therefore holds two C entries:
 *
 *   acc[0] = (C(i0,j0), C(i1,j1))   acc[1] = (C(i0,j2), C(i1,j3))
 *   acc[2] = (C(i1,j0), C(i0,j1))   acc[3] = (C(i1,j2), C(i0,j3))
 *   acc[4] = (C(i2,j0), C(i3,j1))   acc[5] = (C(i2,j2), C(i3,j3))
 *   acc[6] = (C(i3,j0), C(i2,j1))   acc[7] = (C(i3,j2), C(i2,j3))
 */
static inline void dgemm_step_4x4(vtype acc[8], const double *pa,
                                  const double *pb)
{
    const vtype a0 = LOAD(pa + 0);          /* (A(i0,k), A(i1,k)) */
    const vtype a2 = LOAD(pa + 2);          /* (A(i2,k), A(i3,k)) */
    const vtype b0 = LOAD(pb + 0);          /* (B(k,j0), B(k,j1)) */
    const vtype b2 = LOAD(pb + 2);          /* (B(k,j2), B(k,j3)) */
    const vtype a1 = vextq_f64(a0, a0, 1);  /* (A(i1,k), A(i0,k)) */
    const vtype a3 = vextq_f64(a2, a2, 1);  /* (A(i3,k), A(i2,k)) */

    acc[0] = vfmaq_f64(acc[0], b0, a0);
    acc[1] = vfmaq_f64(acc[1], b2, a0);
    acc[2] = vfmaq_f64(acc[2], b0, a1);
    acc[3] = vfmaq_f64(acc[3], b2, a1);
    acc[4] = vfmaq_f64(acc[4], b0, a2);
    acc[5] = vfmaq_f64(acc[5], b2, a2);
    acc[6] = vfmaq_f64(acc[6], b0, a3);
    acc[7] = vfmaq_f64(acc[7], b2, a3);
}

/*
 * Add alpha * (packed A panel) * (packed B panel) into one 4x4 tile of C.
 *
 * pa and pb each hold Kb groups of 4 doubles. c points at C(i0, j0) of the
 * tile, with columns ldc apart.
 */
static void dgemm_kernel_4x4(const int Kb, const double *pa, const double *pb,
                             const double alpha, double *c, const int ldc)
{
    vtype acc[8];
    for (int r = 0; r < 8; ++r)
        acc[r] = vdupq_n_f64(0.0);

    for (int jj = 0; jj < 4; ++jj)
        PREFETCH(c + (size_t) jj * ldc);

    int k = 0;
    /* Main loop, unrolled by two k-steps; loads stay inside the panels. */
    for (; k + 2 <= Kb; k += 2, pa += 8, pb += 8) {
        PREFETCH((const char *) pa + 512);
        PREFETCH((const char *) pb + 512);
        dgemm_step_4x4(acc, pa + 0, pb + 0);
        dgemm_step_4x4(acc, pa + 4, pb + 4);
    }
    /* Odd Kb: one remaining k-step (the old k-=2 loop never reached 0). */
    if (k < Kb)
        dgemm_step_4x4(acc, pa, pb);

    /* Undo the lane permutation documented on dgemm_step_4x4. */
    double *c0 = c;
    double *c1 = c + (size_t) ldc;
    double *c2 = c + 2 * (size_t) ldc;
    double *c3 = c + 3 * (size_t) ldc;

    c0[0] += alpha * vgetq_lane_f64(acc[0], 0);
    c0[1] += alpha * vgetq_lane_f64(acc[2], 0);
    c0[2] += alpha * vgetq_lane_f64(acc[4], 0);
    c0[3] += alpha * vgetq_lane_f64(acc[6], 0);

    c1[0] += alpha * vgetq_lane_f64(acc[2], 1);
    c1[1] += alpha * vgetq_lane_f64(acc[0], 1);
    c1[2] += alpha * vgetq_lane_f64(acc[6], 1);
    c1[3] += alpha * vgetq_lane_f64(acc[4], 1);

    c2[0] += alpha * vgetq_lane_f64(acc[1], 0);
    c2[1] += alpha * vgetq_lane_f64(acc[3], 0);
    c2[2] += alpha * vgetq_lane_f64(acc[5], 0);
    c2[3] += alpha * vgetq_lane_f64(acc[7], 0);

    c3[0] += alpha * vgetq_lane_f64(acc[3], 1);
    c3[1] += alpha * vgetq_lane_f64(acc[1], 1);
    c3[2] += alpha * vgetq_lane_f64(acc[7], 1);
    c3[3] += alpha * vgetq_lane_f64(acc[5], 1);
}

/*
 * Scalar update of C(i, j) for i in [i0, i1) and j in [j0, j1), using the
 * unpacked column-major block origins a, b and c.
 *
 * This covers the rows/columns that do not fill a whole 4x4 tile. Reading
 * A and B directly avoids any assumption about how the copy routines pack
 * partial panels.
 */
static void dgemm_edge(const int i0, const int i1, const int j0, const int j1,
                       const int Kb, const double *a, const int lda,
                       const double *b, const int ldb, const double alpha,
                       double *c, const int ldc)
{
    for (int j = j0; j < j1; ++j) {
        const double *bj = b + (size_t) j * ldb;
        double *cj = c + (size_t) j * ldc;
        for (int i = i0; i < i1; ++i) {
            double sum = 0.0;
            for (int k = 0; k < Kb; ++k)
                sum += a[(size_t) k * lda + i] * bj[k];
            cj[i] += alpha * sum;
        }
    }
}

/*
 * C = alpha * A * B + beta * C for column-major matrices.
 *
 * Shapes: C is M x N (leading dimension ldc), A is M x K (lda) and
 * B is K x N (ldb).
 *
 * K, N and M are blocked by K_BLOCK_SIZE, N_BLOCK_SIZE and M_BLOCK_SIZE.
 * Each B and A block is packed into a scratch buffer (ncopy_4 / tcopy_4),
 * and full 4x4 tiles of C are computed by the NEON kernel. Rows/columns
 * left over when a block size is not a multiple of 4 are computed by a
 * scalar loop.
 *
 * Copy/compute times, a flop count and a byte count are printed to stdout
 * before returning. If the scratch buffers cannot be allocated, a message
 * goes to stderr and C is left unchanged.
 */
void dgemm( const int M, const int N, const int K, const double alpha, const double *A, const int lda, const double *B, const int ldb, const double beta, double* C, const int ldc)
{
    double copytime = 0.;
    double computetime = 0.;
    long int ops = 0, mem = 0;

    double *Ab = malloc((size_t) M_BLOCK_SIZE * K_BLOCK_SIZE * sizeof *Ab);
    double *Bb = malloc((size_t) K_BLOCK_SIZE * N_BLOCK_SIZE * sizeof *Bb);
    if (Ab == NULL || Bb == NULL) {
        fprintf(stderr, "dgemm: cannot allocate packing buffers\n");
        free(Ab);
        free(Bb);
        return;
    }

    dgemm_scale_c(M, N, beta, C, ldc);

    /* alpha == 0: only C = beta*C is required, so A and B are not read. */
    const int Kloop = (alpha != 0.0) ? K : 0;

    for (int kb = 0; kb < Kloop; kb += K_BLOCK_SIZE) {
        const int Kb = min(K_BLOCK_SIZE, K - kb);
        for (int jb = 0; jb < N; jb += N_BLOCK_SIZE) {
            const int Nb = min(N_BLOCK_SIZE, N - jb);
            const int Nb4 = Nb - Nb % 4;
            const double *Bblk = B + (size_t) jb * ldb + kb;

            copytime -= myseconds();
            ncopy_4(Kb, Nb, Bblk, ldb, Bb);
            copytime += myseconds();
            mem += (long int) Kb * Nb * 8;

            for (int ib = 0; ib < M; ib += M_BLOCK_SIZE) {
                const int Mb = min(M_BLOCK_SIZE, M - ib);
                const int Mb4 = Mb - Mb % 4;
                const double *Ablk = A + (size_t) kb * lda + ib;
                double *Cblk = C + (size_t) jb * ldc + ib;

                copytime -= myseconds();
                tcopy_4(Kb, Mb, Ablk, lda, Ab);
                copytime += myseconds();
                mem += (long int) Kb * Mb * 8;

                computetime -= myseconds();
                for (int i = 0; i < Mb4; i += 4)
                    for (int j = 0; j < Nb4; j += 4)
                        dgemm_kernel_4x4(Kb, &Ab[(size_t) i * Kb],
                                         &Bb[(size_t) j * Kb], alpha,
                                         Cblk + (size_t) j * ldc + i, ldc);
                /* Right strip (all rows, leftover columns), then bottom
                   strip (leftover rows, tiled columns): disjoint regions. */
                dgemm_edge(0, Mb, Nb4, Nb, Kb, Ablk, lda, Bblk, ldb,
                           alpha, Cblk, ldc);
                dgemm_edge(Mb4, Mb, 0, Nb4, Kb, Ablk, lda, Bblk, ldb,
                           alpha, Cblk, ldc);
                computetime += myseconds();

                ops += 2L * Mb * Nb * Kb;
                mem += 16L * 8 * (Mb4 / 4) * (Nb4 / 4);
            }
        }
    }

    free(Ab);
    free(Bb);

    printf("copy time = %f, %f GB/s\n", copytime, mem/copytime/1024/1024/1024);
    printf("compute time = %f, %f GFlops\n", computetime, 2.*M*N*K/computetime/1.e9);
    printf("ops = %ld\n", ops);
    printf("mem = %ld\n", mem);
}

/* (c4science page footer: "Event Timeline") */