/*
* Copyright (c) 2010-2013 Michael Pippig
*
* This file is part of PFFT.
*
* PFFT is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* PFFT is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with PFFT. If not, see <http://www.gnu.org/licenses/>.
*
*/
#include "pfft.h"
#include "ipfft.h"
#include "util.h"
/* ---- Forward declarations of the file-local (static) helpers ----
 * NOTE(review): several of the matching definitions are not intact in this
 * copy of the file; the notes below state only what the signatures and the
 * visible call sites show. */

/* Plan constructor, sized by the data rank and the process-mesh rank. */
static PX(plan) mkplan(
int rnk_n, int rnk_pm);
/* Determines *rnk_pm, allocates *comms_pm with *rnk_pm entries and fills it
 * via the split_cart_procmesh helpers. */
static void malloc_and_split_cart_procmesh(
int rnk_n, unsigned transp_flag,
MPI_Comm comm_cart,
int *rnk_pm, MPI_Comm **comms_pm);
/* Determines *rnk_pm, allocates *np_pm and *coords_pm (*rnk_pm ints each)
 * and fills them for process pid of comm_cart. */
static void malloc_and_compute_cart_np_and_coords(
int rnk_n, unsigned transp_flag,
MPI_Comm comm_cart, int pid,
int *rnk_pm, int **np_pm, int **coords_pm);
/* Sets up the local size/start pointers of the transposed-out (_to) and
 * transposed-in (_ti) steps from the caller's arrays.
 * NOTE(review): definition not visible in this copy - confirm details. */
static void init_param_local_size(
INT *lni, INT *lis, INT *dummy_ln, INT *dummy_ls, INT *lno, INT *los,
unsigned transp_flag,
INT **lni_to, INT **lis_to, INT **lno_to, INT **los_to,
INT **lni_ti, INT **lis_ti, INT **lno_ti, INT **los_ti);
/* Copies the planning parameters into the plan object ths. */
static void save_param_into_plan(
int rnk_n, const INT *n, const INT *ni, const INT *no,
INT howmany, const INT *iblock, const INT *mblock, const INT *oblock,
MPI_Comm comm_cart, int rnk_pm, MPI_Comm *comms_pm,
R *in, R *out, int sign, const X(r2r_kind) *kinds, const int *skip_trafos,
unsigned transp_flag, unsigned trafo_flag,
unsigned opt_flag, unsigned fftw_flags,
unsigned pfft_flags,
PX(plan) ths);
/* Splits the trafo parameters into those of the transposed-out (_to) and
 * transposed-in (_ti) steps. */
static void init_param_size_and_trafo_flags(
int rnk_n, const INT *n, const INT *ni, const INT *no, int rnk_pm,
unsigned trafo_flag, unsigned transp_flag, const int *skip_trafos,
INT *pn_to, INT *pni_to, INT *pno_to, unsigned *trafo_flags_to,
INT *pn_ti, INT *pni_ti, INT *pno_ti, unsigned *trafo_flags_ti);
/* NOTE(review): definition not visible in this copy. */
static unsigned get_skip_flag(
const int *skip_trafos, int t);
/* Used by the planner for the step (transposed-out or transposed-in) that is
 * not planned. NOTE(review): definition not visible in this copy. */
static void set_plans_to_null(
int rnk_pm, unsigned transp_flag,
outrafo_plan *trafos, gtransp_plan *remaps);
/* Block size evaluation; this variant takes the process-mesh communicators. */
static void evaluate_blocks_by_comms(
int rnk_n, const INT *ni, const INT *no,
const INT *iblock_user, const INT *oblock_user,
int rnk_pm, const MPI_Comm *comms_pm,
unsigned trafo_flag, unsigned transp_flag,
INT *iblk, INT *mblk, INT *oblk);
/* Block size evaluation; this variant takes the number of processes per
 * process-mesh dimension. */
static void evaluate_blocks(
int rnk_n, const INT *ni, const INT *no,
const INT *iblock_user, const INT *oblock_user,
int rnk_pm, const int *np_pm,
unsigned trafo_flag, unsigned transp_flag,
INT *iblk, INT *mblk, INT *oblk);
/* Extractors for the individual flag groups of the combined pfft_flags word. */
static unsigned extract_transp_flag(
unsigned pfft_flags);
static unsigned extract_opt_flag(
unsigned pfft_flags);
static unsigned extract_io_flag(
unsigned pfft_flags);
static unsigned extract_shift_index_flag(
unsigned pfft_flags);
static unsigned extract_fftw_flags(
unsigned pfft_flags);
/* TRANSPOSED_OUT:
* - iblock gives block size of non-transposed input
* - oblock gives block size of transposed output
* TRANSPOSED_IN:
* - iblock gives block size of transposed input
* - oblock gives block size of non-transposed output
* NOT TRANSPOSED:
* - iblock gives block size of non-transposed input
* - defblock gives block size of intermediate transposed data layout
* - oblock gives block size of non-transposed output
*/
/* Computes, for process pid of comm, the local input/output block sizes
 * (local_ni, local_no) and offsets (local_i_start, local_o_start).
 *
 * NOTE(review): this region of the file is damaged. The tail of this function,
 * several complete helper definitions and the head of the planning routine are
 * missing (see the notes at the two broken `for` lines below). Everything
 * after the first broken line belongs to the planner, not to this function.
 * Restore the lost text from a pristine copy before compiling. */
void PX(local_block_partrafo)(
int rnk_n, const INT *ni, const INT *no,
const INT *iblock_user, const INT *oblock_user,
MPI_Comm comm, int pid,
unsigned trafo_flag_user, unsigned pfft_flags,
INT *local_ni, INT *local_i_start,
INT *local_no, INT *local_o_start
)
{
unsigned transp_flag = extract_transp_flag(pfft_flags);
unsigned *trafo_flags_to, *trafo_flags_ti;
int rnk_pm, *np_pm, *coords_pm;
INT *pni_to, *pn_to, *pno_to, *pni_ti, *pn_ti, *pno_ti;
INT *iblk, *mblk, *oblk;
INT *dummy_ln, *dummy_ls;
INT *lni_to, *lis_to, *lno_to, *los_to;
INT *lni_ti, *lis_ti, *lno_ti, *los_ti;
MPI_Comm comm_cart = PX(assure_cart_comm)(comm);
/* process-mesh rank, mesh extents and the mesh coordinates of process pid */
malloc_and_compute_cart_np_and_coords(rnk_n, transp_flag, comm_cart, pid,
&rnk_pm, &np_pm, &coords_pm);
/* work arrays; presumably released in the lost tail of this function */
pni_to = PX(malloc_INT)(rnk_n);
pn_to = PX(malloc_INT)(rnk_n);
pno_to = PX(malloc_INT)(rnk_n);
pni_ti = PX(malloc_INT)(rnk_n);
pn_ti = PX(malloc_INT)(rnk_n);
pno_ti = PX(malloc_INT)(rnk_n);
trafo_flags_to = PX(malloc_unsigned)(rnk_pm + 1);
trafo_flags_ti = PX(malloc_unsigned)(rnk_pm + 1);
dummy_ln = PX(malloc_INT)(rnk_n);
dummy_ls = PX(malloc_INT)(rnk_n);
iblk = PX(malloc_INT)(rnk_pm);
mblk = PX(malloc_INT)(rnk_pm);
oblk = PX(malloc_INT)(rnk_pm);
/* calculate blocksizes according to trafo and transp flags */
evaluate_blocks(rnk_n, ni, no, iblock_user, oblock_user,
rnk_pm, np_pm, trafo_flag_user, transp_flag,
iblk, mblk, oblk);
/* split trafo into transposed out and transposed in step */
/* NOTE(review): ni is passed for both the `n` and the `ni` parameter (this
 * function has no separate n), and no skip_trafos array is supplied. */
init_param_size_and_trafo_flags(
rnk_n, ni, ni, no, rnk_pm, trafo_flag_user, transp_flag, NULL,
pn_to, pni_to, pno_to, trafo_flags_to,
pn_ti, pni_ti, pno_ti, trafo_flags_ti);
init_param_local_size(
local_ni, local_i_start, dummy_ln, dummy_ls,
local_no, local_o_start, transp_flag,
&lni_to, &lis_to, &lno_to, &los_to,
&lni_ti, &lis_ti, &lno_ti, &los_ti);
/* transposed-out step, taken unless PFFT_TRANSPOSED_IN is set; the second
 * call writes local_ni/local_i_start, i.e. it overwrites the input blocks
 * if remap_3dto2d is used */
if( ~transp_flag & PFFT_TRANSPOSED_IN ){
PX(local_block_partrafo_transposed)(
rnk_n, pni_to, pno_to, iblk, mblk,
rnk_pm, coords_pm, PFFT_TRANSPOSED_OUT, trafo_flags_to[rnk_pm],
lni_to, lis_to, lno_to, los_to);
PX(local_block_remap_3dto2d_transposed)(
rnk_n, pni_to, comm_cart, pid, PFFT_TRANSPOSED_OUT, trafo_flags_to[rnk_pm],
local_ni, local_i_start, dummy_ln, dummy_ls);
}
/* transposed-in step, taken unless PFFT_TRANSPOSED_OUT is set; the second
 * call writes local_no/local_o_start, i.e. it overwrites the output blocks
 * if remap_3dto2d is used */
if( ~transp_flag & PFFT_TRANSPOSED_OUT ){
PX(local_block_partrafo_transposed)(
rnk_n, pni_ti, pno_ti, mblk, oblk,
rnk_pm, coords_pm, PFFT_TRANSPOSED_IN, trafo_flags_ti[rnk_pm],
lni_ti, lis_ti, lno_ti, los_ti);
PX(local_block_remap_3dto2d_transposed)(
rnk_n, pno_ti, comm_cart, pid, PFFT_TRANSPOSED_IN, trafo_flags_ti[rnk_pm],
dummy_ln, dummy_ls, local_no, local_o_start);
}
if(pfft_flags & PFFT_SHIFTED_IN){
/* NOTE(review): the next line is corrupted. The text from the loop condition
 * up to an argument list ending in ths->local_ni... has been lost. That gap
 * contained the rest of this function and the beginning of the planning
 * routine, including the declarations of ths, comms_pm, n, howmany, in, out,
 * sign, kinds, skip_trafos_user, trafo_flag, opt_flag, io_flag, si_flag and
 * fftw_flags used below. */
for(int t=0; tlocal_ni, ths->local_ni_start,
ths->local_no, ths->local_no_start);
/* calculate blocksizes according to trafo and transp flags */
evaluate_blocks_by_comms(rnk_n, ni, no, iblock_user, oblock_user,
rnk_pm, comms_pm, trafo_flag, transp_flag,
iblk, mblk, oblk);
/* Avoid recalculation of the same parameters all the time. */
save_param_into_plan(rnk_n, n, ni, no, howmany, iblk, mblk, oblk,
comm_cart, rnk_pm, comms_pm, in, out, sign, kinds, skip_trafos_user,
transp_flag, trafo_flag, opt_flag, fftw_flags, pfft_flags,
ths);
/* split trafo into transposed out and transposed in step */
init_param_size_and_trafo_flags(
rnk_n, n, ni, no, rnk_pm, trafo_flag, transp_flag, skip_trafos_user,
pn_to, pni_to, pno_to, trafo_flags_to,
pn_ti, pni_ti, pno_ti, trafo_flags_ti);
/* For C2R trafos the output of the forward (transpose) step ends
 * up in pointer 'in', since we skip the last local transposition.
 * For all other trafos the input of the backward step is given by
 * pointer 'out', since we omit its first local transposition.
 * So, for forward and backward steps we use 'in' for input and
 * 'out' for output. */
/* conjugate inputs because fftw only supports backward trafo for c2r */
/* conjugate outputs because fftw only supports forward trafo for r2c */
/* In both special cases the sign stored in the plan is flipped as well. */
if((sign == PFFT_FORWARD) && (trafo_flag & PFFTI_TRAFO_C2R)) {
if(io_flag & PFFT_DESTROY_INPUT)
ths->conjugate_in = ths->conjugate_out = in;
else {
ths->conjugate_in = in;
ths->conjugate_out = out;
/* Go on with in-place transforms in order to preserve input. */
in = out;
}
sign = ths->sign = PFFT_BACKWARD;
} else if((sign == PFFT_BACKWARD) && (trafo_flag & PFFTI_TRAFO_R2C)) {
ths->conjugate_in = ths->conjugate_out = out;
sign = ths->sign = PFFT_FORWARD;
} else
ths->conjugate_in = ths->conjugate_out = NULL;
/* twiddle inputs in order to get outputs shifted by n/2 */
if(pfft_flags & PFFT_SHIFTED_OUT){
if(io_flag & PFFT_DESTROY_INPUT){
ths->itwiddle_in = ths->itwiddle_out = in;
} else {
ths->itwiddle_in = in;
ths->itwiddle_out = out;
/* Go on with in-place transforms in order to preserve input. */
in = out;
}
} else
ths->itwiddle_in = ths->itwiddle_out = NULL;
/* plan with transposed output */
/* (skipped, with its plan slots set to NULL, if the input is already transposed) */
if(transp_flag & PFFT_TRANSPOSED_IN){
ths->remap_3dto2d[0] = NULL;
set_plans_to_null(rnk_pm, PFFT_TRANSPOSED_OUT,
ths->serial_trafo, ths->global_remap);
} else {
ths->remap_3dto2d[0] = PX(plan_remap_3dto2d_transposed)(
rnk_n, pni_to, howmany, comm_cart, in, out,
PFFT_TRANSPOSED_OUT, trafo_flags_to[rnk_pm], opt_flag, io_flag, fftw_flags);
/* If remap_3dto2d exists, go on with in-place transforms in order to preserve input. */
if( (ths->remap_3dto2d[0] != NULL) && (~io_flag & PFFT_DESTROY_INPUT) )
in = out;
PX(plan_partrafo_transposed)(
rnk_n, pn_to, pni_to, pno_to, howmany, iblk, mblk,
rnk_pm, comms_pm, in, out, sign, kinds,
PFFT_TRANSPOSED_OUT, trafo_flags_to, opt_flag, io_flag, si_flag, fftw_flags,
ths->serial_trafo, ths->global_remap);
/* Go on with in-place transforms in order to preserve input. */
if( ~io_flag & PFFT_DESTROY_INPUT)
in = out;
}
/* plan with transposed input */
/* (skipped, with its plan slots set to NULL, if the output stays transposed) */
if(transp_flag & PFFT_TRANSPOSED_OUT){
set_plans_to_null(rnk_pm, PFFT_TRANSPOSED_IN,
ths->serial_trafo, ths->global_remap);
ths->remap_3dto2d[1] = NULL;
} else {
PX(plan_partrafo_transposed)(
rnk_n, pn_ti, pni_ti, pno_ti, howmany, mblk, oblk,
rnk_pm, comms_pm, in, out, sign, kinds,
PFFT_TRANSPOSED_IN, trafo_flags_ti, opt_flag, io_flag, si_flag, fftw_flags,
ths->serial_trafo, ths->global_remap);
/* Go on with in-place transforms in order to preserve input. */
if( ~io_flag & PFFT_DESTROY_INPUT )
in = out;
/* note: the buffers are passed in the order (out, in) here, the reverse of
 * the remap planned for the transposed-out step */
ths->remap_3dto2d[1] = PX(plan_remap_3dto2d_transposed)(
rnk_n, pno_ti, howmany, comm_cart, out, in,
PFFT_TRANSPOSED_IN, trafo_flags_ti[rnk_pm], opt_flag, io_flag, fftw_flags);
}
/* twiddle outputs in order to get inputs shifted by n/2 */
if(pfft_flags & PFFT_SHIFTED_IN)
ths->otwiddle_in = ths->otwiddle_out = out;
else
ths->otwiddle_in = ths->otwiddle_out = NULL;
/* free one-dimensional comms */
/* NOTE(review): the next line is corrupted as well. The loop body, the rest of
 * the planning routine and the opening of the comment that introduces the
 * following function have been lost; only the end of that comment survives. */
for(int t=0; t First remap to 2d decomposition. */
/* Split the Cartesian communicator into the communicators of the process mesh.
 *
 * On return *rnk_pm holds the process-mesh rank used for the transform and
 * *comms_pm points to a freshly malloc'ed array of *rnk_pm communicators.
 * The function does not free the array; the caller receives it through
 * *comms_pm.
 *
 * If PX(needs_3dto2d_remap) reports that a remap is required, the mesh rank is
 * forced to 2 and the two communicators come from the p0q0/p1q1 split helpers.
 * Otherwise the rank reported by MPI_Cartdim_get is kept and
 * PX(split_cart_procmesh) fills the array.
 *
 * transp_flag is not used. */
static void malloc_and_split_cart_procmesh(
    int rnk_n, unsigned transp_flag,
    MPI_Comm comm_cart,
    int *rnk_pm, MPI_Comm **comms_pm
    )
{
  MPI_Comm *mesh_comms;

  MPI_Cartdim_get(comm_cart, rnk_pm);

  const int remap_to_2d = PX(needs_3dto2d_remap)(rnk_n, comm_cart);
  if(remap_to_2d){
    *rnk_pm = 2;
  }

  mesh_comms = malloc(sizeof *mesh_comms * (size_t) *rnk_pm);
  *comms_pm = mesh_comms;

  if(!remap_to_2d){
    PX(split_cart_procmesh)(comm_cart, mesh_comms);
    return;
  }

  PX(split_cart_procmesh_3dto2d_p0q0)(comm_cart, &mesh_comms[0]);
  PX(split_cart_procmesh_3dto2d_p1q1)(comm_cart, &mesh_comms[1]);
}
/* Determine the process-mesh extents and the mesh coordinates of process pid.
 *
 * On return *rnk_pm holds the process-mesh rank, and *np_pm and *coords_pm
 * point to arrays of *rnk_pm ints obtained from PX(malloc_int). The function
 * does not free them; the caller receives them through the out-parameters.
 *
 * If PX(needs_3dto2d_remap) reports that a remap is required, the rank is
 * forced to 2, the 3d coordinates of pid are converted with PX(coords_3dto2d)
 * and the extents are set to p0*q0 and p1*q1. Otherwise extents and
 * coordinates are read directly from the Cartesian communicator.
 *
 * transp_flag is not used. */
static void malloc_and_compute_cart_np_and_coords(
    int rnk_n, unsigned transp_flag,
    MPI_Comm comm_cart, int pid,
    int *rnk_pm, int **np_pm, int **coords_pm
    )
{
  MPI_Cartdim_get(comm_cart, rnk_pm);

  const int remap_to_2d = PX(needs_3dto2d_remap)(rnk_n, comm_cart);
  if(remap_to_2d){
    *rnk_pm = 2;
  }

  int *mesh_np = PX(malloc_int)(*rnk_pm);
  int *mesh_coords = PX(malloc_int)(*rnk_pm);
  *np_pm = mesh_np;
  *coords_pm = mesh_coords;

  if(!remap_to_2d){
    /* MPI_Cart_get needs a periods buffer, but only the extents are wanted;
     * its coordinates are replaced by those of pid right afterwards */
    int *unused_periods = PX(malloc_int)(*rnk_pm);
    MPI_Cart_get(comm_cart, *rnk_pm, mesh_np, unused_periods, mesh_coords);
    MPI_Cart_coords(comm_cart, pid, *rnk_pm, mesh_coords);
    free(unused_periods);
    return;
  }

  int dim_p0, dim_p1, dim_q0, dim_q1;
  int pid_coords_3d[3];

  PX(get_procmesh_dims_2d)(comm_cart, &dim_p0, &dim_p1, &dim_q0, &dim_q1);
  MPI_Cart_coords(comm_cart, pid, 3, pid_coords_3d);
  PX(coords_3dto2d)(dim_q0, dim_q1, pid_coords_3d, mesh_coords);
  mesh_np[0] = dim_p0*dim_q0;
  mesh_np[1] = dim_p1*dim_q1;
}
/* Store the planning parameters in the plan object ths.
 *
 * The arrays n, ni, no (rnk_n entries) and iblock, mblock, oblock (rnk_pm
 * entries) are copied element-wise into arrays the plan already owns.
 * skip_trafos_user may be NULL, in which case every entry is stored as 0.
 * kinds may be NULL; otherwise a private copy of rnk_n entries is allocated.
 *
 * comm_cart and every communicator in comms_pm are duplicated with
 * MPI_Comm_dup, so the plan keeps its own handles; the size of each
 * duplicated mesh communicator is cached in ths->np.
 *
 * The pointers in and out are stored as given; no data is copied. */
static void save_param_into_plan(
    int rnk_n, const INT *n, const INT *ni, const INT *no,
    INT howmany, const INT *iblock, const INT *mblock, const INT *oblock,
    MPI_Comm comm_cart, int rnk_pm, MPI_Comm *comms_pm,
    R *in, R *out, int sign, const X(r2r_kind) *kinds, const int *skip_trafos_user,
    unsigned transp_flag, unsigned trafo_flag,
    unsigned opt_flag, unsigned fftw_flags,
    unsigned pfft_flags,
    PX(plan) ths
    )
{
  ths->rnk_n = rnk_n;
  for(int t=0; t<rnk_n; t++){
    ths->n[t] = n[t];
    ths->ni[t] = ni[t];
    ths->no[t] = no[t];
    ths->skip_trafos[t] = (skip_trafos_user) ? skip_trafos_user[t] : 0;
  }

  ths->howmany = howmany;

  for(int t=0; t<rnk_pm; t++){
    ths->iblock[t] = iblock[t];
    ths->mblock[t] = mblock[t];
    ths->oblock[t] = oblock[t];
  }

  MPI_Comm_dup(comm_cart, &ths->comm_cart);

  ths->rnk_pm = rnk_pm;
  for(int t=0; t<rnk_pm; t++){
    /* the plan keeps private duplicates of the mesh communicators */
    MPI_Comm_dup(comms_pm[t], &ths->comms_pm[t]);
    MPI_Comm_size(ths->comms_pm[t], &ths->np[t]);
  }

  ths->in = in;
  ths->out = out;
  ths->sign = sign;

  if(kinds != NULL){
    ths->kinds = (X(r2r_kind)*) malloc(sizeof(X(r2r_kind)) * (size_t) rnk_n);
    for(int t=0; t<rnk_n; t++)
      ths->kinds[t] = kinds[t];
  } else
    ths->kinds = NULL;

  ths->fftw_flags = fftw_flags;
  ths->transp_flag = transp_flag;
  ths->trafo_flag = trafo_flag;
  ths->opt_flag = opt_flag;
  ths->pfft_flags = pfft_flags;
}
/* Block size evaluation for callers that hold the process-mesh communicators
 * instead of the per-dimension process counts.
 *
 * NOTE(review): only the first statement of this function is intact. The line
 * carrying the `for` loop is corrupted: the text from the loop condition up
 * to an assignment to ths->n has been lost. That gap contained the rest of
 * this function, further helper definitions and the head of the plan
 * constructor. Everything after the broken line is the tail of that
 * constructor. Restore the lost text from a pristine copy before compiling. */
static void evaluate_blocks_by_comms(
int rnk_n, const INT *ni, const INT *no,
const INT *iblock_user, const INT *oblock_user,
int rnk_pm, const MPI_Comm *comms_pm,
unsigned trafo_flag, unsigned transp_flag,
INT *iblk, INT *mblk, INT *oblk
)
{
/* one entry per process-mesh dimension; presumably filled with the size of
 * each communicator in the lost loop - TODO confirm */
int *np_pm = PX(malloc_int)(rnk_pm);
/* NOTE(review): corrupted line, see the header comment above */
for(int t=0; tn = PX(malloc_INT)(rnk_n);
/* per-dimension arrays of the data (rnk_n entries each) */
ths->ni = PX(malloc_INT)(rnk_n);
ths->no = PX(malloc_INT)(rnk_n);
ths->local_ni = PX(malloc_INT)(rnk_n);
ths->local_no = PX(malloc_INT)(rnk_n);
ths->local_ni_start = PX(malloc_INT)(rnk_n);
ths->local_no_start = PX(malloc_INT)(rnk_n);
/* per-dimension arrays of the process mesh (rnk_pm entries each) */
ths->iblock = PX(malloc_INT)(rnk_pm);
ths->mblock = PX(malloc_INT)(rnk_pm);
ths->oblock = PX(malloc_INT)(rnk_pm);
ths->comms_pm = (MPI_Comm*) malloc(sizeof(MPI_Comm) * (size_t) rnk_pm);
ths->np = PX(malloc_int)(rnk_pm);
ths->kinds = NULL; /* allocate later if needed */
ths->skip_trafos = PX(malloc_int)(rnk_n);
/* allocate array of plans */
/* 2*rnk_pm+2 serial trafo slots and 2*rnk_pm global remap slots; the plan
 * destructor iterates over the same counts */
ths->serial_trafo = (outrafo_plan *)
malloc(sizeof(outrafo_plan) * (size_t) (2*rnk_pm+2));
ths->global_remap = (gtransp_plan *)
malloc(sizeof(gtransp_plan) * (size_t) (2*rnk_pm));
/* allocate timer and set all times to zero */
ths->timer = PX(mktimer)(rnk_pm);
return ths;
}
/* Destroy a plan and release everything it owns.
 *
 * Passing NULL is allowed and does nothing. Otherwise the function destroys
 * the 2*rnk_pm+2 serial trafo plans, the 2*rnk_pm global remap plans and the
 * two remap_3dto2d plans, frees the communicators held by the plan, frees all
 * parameter arrays and the timer, and finally frees the plan structure.
 *
 * The caller's pointer is passed by value and is therefore left dangling;
 * it must not be used afterwards. */
void PX(rmplan)(
    PX(plan) ths
    )
{
  /* plan was already destroyed or never initialized */
  if(ths==NULL)
    return;

  /* sub-plans of the individual transform and remap steps */
  for(int t=0; t<2*ths->rnk_pm+2; t++)
    PX(outrafo_rmplan)(ths->serial_trafo[t]);
  free(ths->serial_trafo);

  for(int t=0; t<2*ths->rnk_pm; t++)
    PX(gtransp_rmplan)(ths->global_remap[t]);
  free(ths->global_remap);

  /* parameter arrays */
  free(ths->ni); free(ths->n); free(ths->no);
  free(ths->local_ni); free(ths->local_no);
  free(ths->local_ni_start); free(ths->local_no_start);
  free(ths->iblock); free(ths->mblock); free(ths->oblock);

  /* communicators held by the plan, one per process-mesh dimension */
  MPI_Comm_free(&ths->comm_cart);
  for(int t=0; t<ths->rnk_pm; t++)
    MPI_Comm_free(&ths->comms_pm[t]);

  free(ths->np);
  free(ths->comms_pm);

  /* kinds may be NULL; free(NULL) is a no-op */
  free(ths->kinds);
  free(ths->skip_trafos);

  for(int t=0; t<2; t++)
    PX(remap_3dto2d_rmplan)(ths->remap_3dto2d[t]);

  PX(destroy_timer)(ths->timer);

  /* free memory */
  free(ths);
  /* setting ths to NULL here would be pointless: only the local copy of
   * the pointer would change */
}
/* Reduce pfft_flags to its transposition part.
 *
 * The result starts from PFFT_TRANSPOSED_NONE and contains
 * PFFT_TRANSPOSED_IN and/or PFFT_TRANSPOSED_OUT exactly when the
 * corresponding flag is present in pfft_flags. */
static unsigned extract_transp_flag(
    unsigned pfft_flags
    )
{
  const unsigned in_part =
      (pfft_flags & PFFT_TRANSPOSED_IN) ? PFFT_TRANSPOSED_IN : 0u;
  const unsigned out_part =
      (pfft_flags & PFFT_TRANSPOSED_OUT) ? PFFT_TRANSPOSED_OUT : 0u;

  return PFFT_TRANSPOSED_NONE | in_part | out_part;
}
/* Reduce pfft_flags to its tuning part: PFFT_NO_TUNE if that flag is
 * present in pfft_flags, 0 otherwise. */
static unsigned extract_opt_flag(
    unsigned pfft_flags
    )
{
  if(pfft_flags & PFFT_NO_TUNE)
    return PFFT_NO_TUNE;

  return 0;
}
/* Reduce pfft_flags to a single input-handling flag.
 *
 * PFFT_DESTROY_INPUT is returned only if it was requested and
 * PFFT_PRESERVE_INPUT was not requested as well; in every other case the
 * result is PFFT_PRESERVE_INPUT. */
static unsigned extract_io_flag(
    unsigned pfft_flags
    )
{
  const int destroy_requested = (pfft_flags & PFFT_DESTROY_INPUT) != 0;
  const int preserve_absent = (~pfft_flags & PFFT_PRESERVE_INPUT) != 0;

  if(destroy_requested && preserve_absent)
    return PFFT_DESTROY_INPUT;

  return PFFT_PRESERVE_INPUT;
}
/* Reduce pfft_flags to its index-shift part.
 *
 * The result starts from PFFT_SHIFTED_NONE and contains PFFT_SHIFTED_IN
 * and/or PFFT_SHIFTED_OUT exactly when the corresponding flag is present
 * in pfft_flags. */
static unsigned extract_shift_index_flag(
    unsigned pfft_flags
    )
{
  const unsigned in_part =
      (pfft_flags & PFFT_SHIFTED_IN) ? PFFT_SHIFTED_IN : 0u;
  const unsigned out_part =
      (pfft_flags & PFFT_SHIFTED_OUT) ? PFFT_SHIFTED_OUT : 0u;

  return PFFT_SHIFTED_NONE | in_part | out_part;
}
/* Translate pfft_flags into the FFTW planner flags passed on to FFTW, so that
 * only PFFT-compatible FFTW flags are used.
 *
 * Planning effort: FFTW_MEASURE unless a PFFT effort flag is given. If several
 * are given, PFFT_EXHAUSTIVE takes precedence over PFFT_PATIENT, which takes
 * precedence over PFFT_ESTIMATE.
 *
 * Input handling: FFTW_DESTROY_INPUT only if PFFT_DESTROY_INPUT was requested
 * and PFFT_PRESERVE_INPUT was not requested as well; FFTW_PRESERVE_INPUT
 * otherwise. */
static unsigned extract_fftw_flags(
    unsigned pfft_flags
    )
{
  unsigned effort;

  if(pfft_flags & PFFT_EXHAUSTIVE)
    effort = FFTW_EXHAUSTIVE;
  else if(pfft_flags & PFFT_PATIENT)
    effort = FFTW_PATIENT;
  else if(pfft_flags & PFFT_ESTIMATE)
    effort = FFTW_ESTIMATE;
  else
    effort = FFTW_MEASURE;

  /* PFFT_PRESERVE_INPUT needs one out-of-place FFT with
   * FFTW_PRESERVE_INPUT followed by several in-place FFTs */
  const int may_destroy_input =
      (pfft_flags & PFFT_DESTROY_INPUT) && (~pfft_flags & PFFT_PRESERVE_INPUT);

  if(may_destroy_input)
    return effort | FFTW_DESTROY_INPUT;

  return effort | FFTW_PRESERVE_INPUT;
}