|
| 1 | +#pragma once |
| 2 | + |
| 3 | +#include <gauge_field_order.h> |
| 4 | +#include <quda_matrix.h> |
| 5 | +#include <index_helper.cuh> |
| 6 | +#include <kernel.h> |
| 7 | +#include <thread_array.h> |
| 8 | + |
| 9 | +namespace quda { |
| 10 | + |
| 11 | + template <int dim_> |
| 12 | + struct paths { |
| 13 | + static constexpr int dim = dim_; |
| 14 | + const int num_paths; |
| 15 | + const int max_length; |
| 16 | + int *input_path[dim]; |
| 17 | + const int *length; |
| 18 | + const double *path_coeff; |
| 19 | + int *buffer; |
| 20 | + int count; |
| 21 | + |
| 22 | + paths(std::vector<int**>& input_path, std::vector<int>& length_h, std::vector<double>& path_coeff_h, int num_paths, int max_length) : |
| 23 | + num_paths(num_paths), |
| 24 | + max_length(max_length), |
| 25 | + count(0) |
| 26 | + { |
| 27 | + if (static_cast<int>(input_path.size()) != dim) |
| 28 | + errorQuda("Input path vector is of size %lu, expected %d", input_path.size(), dim); |
| 29 | + if (static_cast<int>(length_h.size()) != num_paths) |
| 30 | + errorQuda("Path length vector is of size %lu, expected %d", length_h.size(), num_paths); |
| 31 | + if (static_cast<int>(path_coeff_h.size()) != num_paths) |
| 32 | + errorQuda("Path coefficient vector is of size %lu, expected %d", path_coeff_h.size(), num_paths); |
| 33 | + |
| 34 | + // create path struct in a single allocation |
| 35 | + size_t bytes = dim * num_paths * max_length * sizeof(int) + num_paths * sizeof(int); |
| 36 | + int pad = ((sizeof(double) - bytes % sizeof(double)) % sizeof(double))/sizeof(int); |
| 37 | + bytes += pad*sizeof(int) + num_paths*sizeof(double); |
| 38 | + |
| 39 | + buffer = static_cast<int*>(pool_device_malloc(bytes)); |
| 40 | + int *path_h = static_cast<int*>(safe_malloc(bytes)); |
| 41 | + memset(path_h, 0, bytes); |
| 42 | + |
| 43 | + for (int dir=0; dir<dim; dir++) { |
| 44 | + // flatten the input_path array for copying to the device |
| 45 | + for (int i = 0; i < num_paths; i++) { |
| 46 | + for (int j = 0; j < length_h[i]; j++) { |
| 47 | + path_h[dir * num_paths * max_length + i * max_length + j] = input_path[dir][i][j]; |
| 48 | + if (dir==0) count++; |
| 49 | + } |
| 50 | + } |
| 51 | + } |
| 52 | + |
| 53 | + // length array |
| 54 | + memcpy(path_h + dim * num_paths * max_length, length_h.data(), num_paths*sizeof(int)); |
| 55 | + |
| 56 | + // path_coeff array |
| 57 | + memcpy(path_h + dim * num_paths * max_length + num_paths + pad, path_coeff_h.data(), num_paths*sizeof(double)); |
| 58 | + |
| 59 | + qudaMemcpy(buffer, path_h, bytes, qudaMemcpyHostToDevice); |
| 60 | + host_free(path_h); |
| 61 | + |
| 62 | + // finally set the pointers to the correct offsets in the buffer |
| 63 | + for (int d=0; d < dim; d++) this->input_path[d] = buffer + d*num_paths*max_length; |
| 64 | + length = buffer + dim*num_paths*max_length; |
| 65 | + path_coeff = reinterpret_cast<double*>(buffer + dim * num_paths * max_length + num_paths + pad); |
| 66 | + } |
| 67 | + |
| 68 | + void free() { |
| 69 | + pool_device_free(buffer); |
| 70 | + } |
| 71 | + }; |
| 72 | + |
| 73 | + constexpr int flipDir(int dir) { return (7-dir); } |
| 74 | + constexpr bool isForwards(int dir) { return (dir <= 3); } |
| 75 | + |
| 76 | + /** |
| 77 | + @brief Calculates an arbitary gauge path, returning the product matrix |
| 78 | +
|
| 79 | + @return The product of the gauge path |
| 80 | + @param[in] arg Kernel argumnt |
| 81 | + @param[in] x Full index array |
| 82 | + @param[in] parity Parity index (note: assumes that an offset from a non-zero dx is baked in) |
| 83 | + @param[in] path Gauge link path |
| 84 | + @param[in] length Length of gauge path |
| 85 | + @param[in] dx Temporary shared memory storage for relative coordinate shift |
| 86 | + */ |
| 87 | + template <typename Arg, typename I> |
| 88 | + __device__ __host__ inline typename Arg::Link |
| 89 | + computeGaugePath(const Arg &arg, int x[4], int parity, const int* path, int length, I& dx) |
| 90 | + { |
| 91 | + using Link = typename Arg::Link; |
| 92 | + |
| 93 | + // linkA: current matrix |
| 94 | + // linkB: the loaded matrix in this round |
| 95 | + Link linkA, linkB; |
| 96 | + setIdentity(&linkA); |
| 97 | + |
| 98 | + int nbr_oddbit = parity; |
| 99 | + |
| 100 | + for (int j = 0; j < length; j++) { |
| 101 | + |
| 102 | + int pathj = path[j]; |
| 103 | + int lnkdir = isForwards(pathj) ? pathj : flipDir(pathj); |
| 104 | + |
| 105 | + if (isForwards(pathj)) { |
| 106 | + linkB = arg.u(lnkdir, linkIndexShift(x,dx,arg.E), nbr_oddbit); |
| 107 | + linkA = linkA * linkB; |
| 108 | + dx[lnkdir]++; // now have to update to new location |
| 109 | + nbr_oddbit = nbr_oddbit^1; |
| 110 | + } else { |
| 111 | + dx[lnkdir]--; // if we are going backwards the link is on the adjacent site |
| 112 | + nbr_oddbit = nbr_oddbit^1; |
| 113 | + linkB = arg.u(lnkdir, linkIndexShift(x,dx,arg.E), nbr_oddbit); |
| 114 | + linkA = linkA * conj(linkB); |
| 115 | + } |
| 116 | + } //j |
| 117 | + |
| 118 | + return linkA; |
| 119 | + } |
| 120 | + |
| 121 | +} |
| 122 | + |
0 commit comments