Skip to content
29 changes: 29 additions & 0 deletions include/gauge_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,13 @@ namespace quda {
*/
virtual void copy(const GaugeField &src) = 0;

/**
 * @brief Generic gauge field shift: populate this field from src displaced by dx
 * @param[in] src Source from which we are shifting (extended field in case of MPI);
 * for a non-zero shift this must not be the destination field itself
 * @param[in] dx Host array of shifts to apply to the field, one entry per dimension;
 * when every entry is zero the visible implementations fall back to copy(src)
 */
virtual void shift(const GaugeField &src, const array<int, 4> &dx) = 0;

/**
@brief Compute the L1 norm of the field
@param[in] dim Which dimension we are taking the norm of (dim=-1 mean all dimensions)
Expand Down Expand Up @@ -535,6 +542,13 @@ namespace quda {
*/
void copy(const GaugeField &src);

/**
 * @brief Generic gauge field shift: populate this field from src displaced by dx
 * @param[in] src Source from which we are shifting (extended field in case of MPI);
 * for a non-zero shift this must not be the destination field itself
 * @param[in] dx Host array of shifts to apply to the field, one entry per dimension;
 * when every entry is zero the shift reduces to copy(src)
 */
void shift(const GaugeField &src, const array<int, 4> &dx);

/**
@brief Download into this field from a CPU field
@param[in] cpu The CPU field source
Expand Down Expand Up @@ -672,6 +686,13 @@ namespace quda {
*/
void copy(const GaugeField &src);

/**
 * @brief Generic gauge field shift: populate this field from src displaced by dx
 * @param[in] src Source from which we are shifting (extended field in case of MPI);
 * for a non-zero shift this must not be the destination field itself
 * @param[in] dx Host array of shifts to apply to the field, one entry per dimension;
 * when every entry is zero the shift reduces to copy(src)
 */
void shift(const GaugeField &src, const array<int, 4> &dx);

/** @brief Untyped pointer to the field's `gauge` member (mutable access) */
void* Gauge_p() { return gauge; }
/** @brief Untyped pointer to the field's `gauge` member (read-only access) */
const void* Gauge_p() const { return gauge; }

Expand Down Expand Up @@ -864,4 +885,12 @@ namespace quda {

#define checkReconstruct(...) Reconstruct_(__func__, __FILE__, __LINE__, __VA_ARGS__)

/**
 * @brief Generic gauge field shift
 * @param[out] dst Gauge field to store output; must agree with src in precision,
 * location, reconstruct type and geometry
 * @param[in] src Source from which we are shifting (extended field in case of MPI)
 * @param[in] dx Host array of shifts to apply to the field
 */
void gaugeShift(GaugeField &dst, const GaugeField &src, const array<int, 4> &dx);
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be moved to e.g. gauge_tools.h?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, that seems like a consistent place for it.


} // namespace quda
61 changes: 61 additions & 0 deletions include/kernels/gauge_shift.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#pragma once

#include <gauge_field_order.h>
#include <quda_matrix.h>
#include <index_helper.cuh>
#include <kernel.h>

namespace quda
{

/**
   @brief Kernel parameters for the gauge field shift.  Bundles the
   accessors for the destination and source fields together with the
   lattice geometry needed to map an output site onto the (possibly
   extended) source lattice.
   @tparam Float_ Storage precision of the fields
   @tparam nColor_ Number of colors (only 3 is enabled)
   @tparam recon_u Reconstruction type of the fields
 */
template <typename Float_, int nColor_, QudaReconstructType recon_u> struct GaugeShiftArg : kernel_param<> {
  using Float = Float_;
  static constexpr int nColor = nColor_;
  static_assert(nColor == 3, "Only nColor=3 enabled at this time");
  typedef typename gauge_mapper<Float, recon_u>::type Gauge;

  Gauge out;      // accessor for the destination field
  const Gauge in; // accessor for the source field

  int S[4];      // the shift applied in each dimension
  int X[4];      // dimensions of the destination field
  int E[4];      // dimensions of the source field (extended volume when src is extended)
  int border[4]; // radius of border, i.e. half the size difference between E and X
  int P;         // change of parity: 1 if the total shift is odd, 0 otherwise

  /**
     @param[out] out Destination field
     @param[in] in Source field
     @param[in] dx Host array of shifts, one entry per dimension
   */
  GaugeShiftArg(GaugeField &out, const GaugeField &in, const array<int, 4> &dx) :
    // The kernel decodes the thread index with X (the destination dimensions) and
    // writes out(dir, x_cb, parity), so the thread extent has to be the destination
    // checkerboard volume.  Sizing it by the source volume would let threads run past
    // the end of the destination whenever the source is larger (extended).
    kernel_param(dim3(out.VolumeCB(), 2, in.Geometry())),
    out(out),
    in(in)
  {
    int shift_sum = 0;
    for (int i = 0; i < 4; i++) {
      S[i] = dx[i];
      X[i] = out.X()[i];
      E[i] = in.X()[i];
      border[i] = (E[i] - X[i]) / 2;
      shift_sum += dx[i];
    }
    // an odd total displacement lands on the opposite checkerboard
    P = std::abs(shift_sum) % 2;
  }
};

/**
   @brief Functor applied per (checkerboard site, parity, direction):
   loads one link from the source field at the shifted site and stores
   it at the unshifted site of the destination field.
   @tparam Arg Parameter struct providing in, out, X, E, S, border and P
 */
template <typename Arg> struct GaugeShift {
  const Arg &arg;
  constexpr GaugeShift(const Arg &arg) : arg(arg) { }
  static constexpr const char *filename() { return KERNEL_FILE; }

  /**
     @param[in] x_cb Checkerboard site index in the destination field
     @param[in] parity Parity of the destination site
     @param[in] dir Link direction (geometry index)
   */
  __device__ __host__ void operator()(int x_cb, int parity, int dir)
  {
    using Link = Matrix<complex<typename Arg::Float>, Arg::nColor>;

    // coordinates of the destination site, decoded using the destination dimensions
    int coord[4] = {0, 0, 0, 0};
    getCoords(coord, x_cb, arg.X, parity);

    // move into the coordinate system of the source field by adding the border
    for (int d = 0; d < 4; ++d) coord[d] += arg.border[d];

    // the source site sits on the other checkerboard when the parity change flag is set
    int src_parity = parity;
    if (arg.P == 1) src_parity = parity ^ 1;

    // index of the displaced site on the source lattice (as computed by linkIndexShift)
    auto src_idx = linkIndexShift(coord, arg.S, arg.E);

    Link shifted = arg.in(dir, src_idx, src_parity);
    arg.out(dir, x_cb, parity) = shifted;
  }
};

} // namespace quda
1 change: 1 addition & 0 deletions lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ set (QUDA_OBJS
copy_gauge_half.cu copy_gauge_quarter.cu
copy_gauge.cpp copy_clover.cu
copy_gauge_offset.cu copy_color_spinor_offset.cu copy_clover_offset.cu
gauge_shift.cu
staggered_oprod.cu clover_trace_quda.cu
hisq_paths_force_quda.cu
unitarize_force_quda.cu unitarize_links_quda.cu milc_interface.cpp
Expand Down
20 changes: 20 additions & 0 deletions lib/cpu_gauge_field.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,26 @@ namespace quda {
}
}

void cpuGaugeField::shift(const GaugeField &src, const array<int, 4> &dx)
{
for (int i = 0; i < this->nDim; i++) {
if (dx[i] != 0) break;
// if zero shift, we simply copy
if (i == this->nDim - 1) return this->copy(src);
}
if (this == &src) errorQuda("Cannot copy in itself");

checkField(src);

// TODO: check src extension (needs to be enough for shifting)

if (typeid(src) == typeid(cudaGaugeField)) {
errorQuda("Not Implemented");
} else {
errorQuda("Not compatible type");
}
}

void cpuGaugeField::setGauge(void **gauge_)
{
if(create != QUDA_REFERENCE_FIELD_CREATE) {
Expand Down
19 changes: 19 additions & 0 deletions lib/cuda_gauge_field.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,25 @@ namespace quda {
qudaDeviceSynchronize(); // include sync here for accurate host-device profiling
}

void cudaGaugeField::shift(const GaugeField &src, const array<int, 4> &dx)
{
for (int i = 0; i < this->nDim; i++) {
if (dx[i] != 0) break;
if (i == this->nDim - 1) return this->copy(src);
}
if (this == &src) errorQuda("Cannot copy in itself");

checkField(src);

// TODO: check src extension (needs to be enough for shifting)

if (typeid(src) == typeid(cudaGaugeField)) {
gaugeShift(*this, src, dx);
} else {
errorQuda("Not compatible type");
}
}

void cudaGaugeField::loadCPUField(const cpuGaugeField &cpu) {
copy(cpu);
qudaDeviceSynchronize();
Expand Down
53 changes: 53 additions & 0 deletions lib/gauge_shift.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#include <tunable_nd.h>
#include <instantiate.h>
#include <gauge_field.h>
#include <kernels/gauge_shift.cuh>

namespace quda
{

/**
   @brief Tunable launcher for the GaugeShift kernel.  Constructing an
   instance builds the tuning key and immediately launches the kernel
   on the default stream.
   @tparam Float Storage precision of the fields
   @tparam nColor Number of colors
   @tparam recon_u Reconstruction type of the fields
 */
template <typename Float, int nColor, QudaReconstructType recon_u> class ShiftGauge : public TunableKernel3D
{
  GaugeField &out;
  const GaugeField &in;
  const array<int, 4> &dx;

  // The kernel does one unit of work per destination checkerboard site, so the
  // launch is sized by the destination field rather than the source field.
  unsigned int minThreads() const { return out.VolumeCB(); }

public:
  /**
     @param[out] out Destination field
     @param[in] in Source field
     @param[in] dx Host array of shifts, one entry per dimension
   */
  ShiftGauge(GaugeField &out, const GaugeField &in, const array<int, 4> &dx) :
    TunableKernel3D(in, 2, in.Geometry()), out(out), in(in), dx(dx)
  {
    strcat(aux, ",shift=");
    for (int i = 0; i < in.Ndim(); i++) {
      // separate the components so that distinct shifts cannot produce the same
      // key, e.g. {1,11,0,0} and {11,1,0,0} would otherwise both give "11100"
      if (i > 0) strcat(aux, ",");
      strcat(aux, std::to_string(dx[i]).c_str());
    }
    strcat(aux, comm_dim_partitioned_string());
    apply(device::get_default_stream());
  }

  /**
     @brief Fetch the launch parameters from the autotuner and launch the kernel
     @param[in] stream Stream on which the kernel is launched
   */
  void apply(const qudaStream_t &stream)
  {
    TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity());
    launch<GaugeShift>(tp, stream, GaugeShiftArg<Float, nColor, recon_u>(out, in, dx));
  }

  // nothing is saved or restored around tuning
  void preTune() { }
  void postTune() { }

  // the kernel only loads and stores links, it performs no floating point arithmetic
  long long flops() const { return 0; }
  // one link read and one link written per destination link
  long long bytes() const { return 2 * out.Bytes(); }
};

/**
   Validate that the two fields are compatible and dispatch the shift
   kernel for the matching precision / reconstruct instantiation.
 */
void gaugeShift(GaugeField &out, const GaugeField &in, const array<int, 4> &dx)
{
  // both fields have to agree in precision, location and reconstruct type
  checkPrecision(in, out);
  checkLocation(in, out);
  checkReconstruct(in, out);

  // ... and carry the same geometry
  const auto out_geometry = out.Geometry();
  const auto in_geometry = in.Geometry();
  if (out_geometry != in_geometry) { errorQuda("Field geometries %d %d do not match", out_geometry, in_geometry); }

  // instantiate peels the reconstruct type off its first argument, which
  // therefore has to be the gauge field
  instantiate<ShiftGauge>(out, in, dx);
}

} // namespace quda