Related questions BLAS with symmetry in higher order tensor in Fortran
How to speed up reshape in higher rank tensor contraction by BLAS in Fortran?
To evaluate the following tensor contraction
A[a,b] * B[b,c,d] = C[a,c,d]
(Einstein summation rule implied, repeated indices, b
, means summation.) with symmetries B[b,c,d] = B[b,d,c]
and hence C[a,c,d] = C[a,d,c]
. As suggested in the related question above,
do d=1,n
do c=1,d
C(:,c,d) = matmul(A(:,:), B(:,c,d)) !This block came from the answer in the first related question. I think it should be matmul(A(a,:), B(:,c,d))
enddo
do c=d+1,n
C(:,c,d) = C(:,d,c)
enddo
enddo
To utilize BLAS in matmul
, it seems I need to do summation in some dimension of an array. In using DGEMM, is there any way to specify the range of indices being used in the array B
? By the size of K
in stackoverflow.com/questions/66296334/openmp-with-blas? Should I account from the first index of B
?
I tried to introduce smaller arrays, e.g., B_small(:) = B(:,c,d), C_small(:) = C(:,c,d)
, manipulating new arrays take quite some time
An attempt to use DGEMM
and DGEMV
Program blas_symm_tensor
Use, Intrinsic :: iso_fortran_env, Only : wp => real64, li => int64
Implicit None
Real( wp ), Dimension( :, : ), Allocatable :: a
Real( wp ), Dimension( :, :, : ), Allocatable :: b
Real( wp ), Dimension( :, :, : ), Allocatable :: c1, c2, c3, c4, c5
Integer :: na, nb, nc, nd, ne
Integer :: la, lb, lc, ld
Integer( li ) :: start, finish, rate, numthreads
Write( *, * ) 'na, nb, nc, nd ?'
Read( *, * ) na, nb, nc, nd
ne = nc * nd
Allocate( a ( 1:na, 1:nb ) )
Allocate( b ( 1:nb, 1:nc, 1:nd ) )
Allocate( c1( 1:na, 1:nc, 1:nd ) )
Allocate( c2( 1:na, 1:nc, 1:nd ) )
Allocate( c3( 1:na, 1:nc, 1:nd ) )
Allocate( c4( 1:na, 1:nc, 1:nd ) )
Allocate( c5( 1:na, 1:nc, 1:nd ) )
! Set up some data
Call Random_number( a )
Call Random_number( b )
! symmetrize tensor b
do la = 1, na
do lc = 1, nc
do ld = lc+1, nd
b(la,lc,ld) = b(la,ld,lc)
enddo
enddo
enddo
Call System_clock( start, rate )
c1 = 0.0_wp
do ld = 1, nd
do lc = 1, nc
do lb = 1, nb
do la = 1, na
c1(la,lc,ld) = c1(la,lc,ld) + a(la,lb) * b(lb, lc, ld)
enddo
enddo
enddo
enddo
Call System_clock( finish, rate )
Write( *, * ) 'Time for loop', Real( finish - start, wp ) / rate
Call System_clock( start, rate )
c2 = 0.0_wp
do ld = 1, nd
do lc = 1, ld
do lb = 1, nb
do la = 1, na
c2(la,lc,ld) = c2(la,lc,ld) + a(la,lb) * b(lb, lc, ld)
enddo
enddo
enddo
enddo
do ld = 1, nd
do lc = ld+1, nc
do la = 1, na
c2(la,lc,ld) = c2(la,ld,lc)
end do
enddo
enddo
Call System_clock( finish, rate )
Write( *, * ) 'Time for symmetric loop', Real( finish - start, wp ) / rate
Call System_clock( start, rate )
c3 = 0.0_wp
do ld = 1, nd
do lc = 1, ld
Call dgemm( 'N', 'N', na, 1, nb, 1.0_wp, a , Size( a , Dim = 1 ), &
b(:,lc,ld) , Size( b , Dim = 1 ), &
0.0_wp, c3, Size( c3, Dim = 1 ) )
enddo
enddo
do ld = 1, nd
do lc = ld+1, nc
do la = 1, na
c3(la,lc,ld) = c3(la,ld,lc)
end do
enddo
enddo
Call System_clock( finish, rate )
Write( *, * ) 'Time for symmetric dgemm', Real( finish - start, wp ) / rate
! Direct
Call System_clock( start, rate )
Call dgemm( 'N', 'N', na, ne, nb, 1.0_wp, a , Size( a , Dim = 1 ), &
b , Size( b , Dim = 1 ), &
0.0_wp, c4, Size( c4, Dim = 1 ) )
Call System_clock( finish, rate )
Write( *, * ) 'Time for straight dgemm', Real( finish - start, wp ) / rate
Call System_clock( start, rate )
c5 = 0.0_wp
do ld = 1, nd
do lc = 1, ld
Call dgemv( 'N', na, nb, 1.0_wp, a, na, b(1,lc,ld), 1, 0.0_wp, c5, 1 )
enddo
enddo
do ld = 1, nd
do lc = ld+1, nc
do la = 1, na
c5(la,lc,ld) = c5(la,ld,lc)
end do
enddo
enddo
Call System_clock( finish, rate )
Write( *, * ) 'Time for symmetric dgemv', Real( finish - start, wp ) / rate
do la = 1, na
do lc = 1, nc
do ld = 1, nd
if ( dabs(c1(la,lc,ld) - c2(la,lc,ld)) > 1.e-6 ) then
write (*,*) '!!! c2', la,lc,ld, c1(la,lc,ld) - c2(la,lc,ld)
endif
if ( dabs(c1(la,lc,ld) - c3(la,lc,ld)) > 1.e-6 ) then
write (*,*) '!!! c3', la,lc,ld, c1(la,lc,ld) - c3(la,lc,ld)
endif
if ( dabs(c1(la,lc,ld) - c4(la,lc,ld)) > 1.e-6 ) then
write (*,*) '!!! c4', la,lc,ld, c1(la,lc,ld) - c4(la,lc,ld)
endif
if ( dabs(c1(la,lc,ld) - c5(la,lc,ld)) > 1.e-6 ) then
write (*,*) '!!! c5', la,lc,ld, c1(la,lc,ld) - c5(la,lc,ld)
endif
enddo
enddo
enddo
End Program blas_symm_tensor
Got different number in c3
and c5
than naive nested loops :(
PS: There seems to be some tensor contraction with symmetry in Julia https://jutho.github.io/TensorKit.jl/stable/man/intro/ I am not sure the time in interacting with Julia from Fortran :(