Fixed params? How could I make this more elegant? #1388
Hello! I'm trying to use Flax in my research. I want a module or class that calculates the Laplacian of a matrix using the spectral method. For this, I need my module to store a filter (…). The central problems are: …
Here is my attempt, where I calculate …:

```python
import numpy as np
import jax.numpy as jnp
import flax.linen as nn


def applyfilter(phi, A):
    # Multiply by the filter in Fourier space and transform back.
    return jnp.fft.ifft2(phi * jnp.fft.fft2(A)).real


class SpectralLaplacian(nn.Module):
    dz: float
    dx: float

    def gen_filter(self, A, dz=1., dx=1.):
        # Build -(kz^2 + kx^2) on the FFT wavenumber grid matching A's shape.
        nz, nx = A.shape
        dkz = 2*np.pi/(nz*dz)
        dkx = 2*np.pi/(nx*dx)
        # Integer division keeps the frequency indices integral for odd sizes too.
        kz, kx = jnp.mgrid[-(nz//2):nz - nz//2, -(nx//2):nx - nx//2]
        kz = jnp.fft.ifftshift(kz) * dkz
        kx = jnp.fft.ifftshift(kx) * dkx
        phi = -(kz**2 + kx**2)
        return phi

    def setup(self):
        # I wanted to maybe set phi up here somehow...
        pass

    def __call__(self, A):
        phi = self.gen_filter(A, self.dz, self.dx)  # phi should only be generated once
        return applyfilter(phi, A)
```

How could I solve this problem? What would be an elegant way to do it? I'm open to tips! Thank you in advance!
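For reference, `gen_filter` and `applyfilter` together are the Fourier-space form of the 2-D Laplacian on a periodic grid. With integer frequency indices $m$ and $l$ (the values produced by `mgrid` and reordered by `ifftshift`):

$$
k_z = \frac{2\pi m}{n_z\,\Delta z}, \qquad
k_x = \frac{2\pi l}{n_x\,\Delta x}, \qquad
\nabla^2 A = \mathcal{F}^{-1}\!\left[-\left(k_z^2 + k_x^2\right)\mathcal{F}[A]\right]
$$

So `phi` is the multiplier $-(k_z^2 + k_x^2)$. It depends only on the grid shape and the spacings $\Delta z$, $\Delta x$, not on the values in `A`, which is why it only needs to be computed once per grid.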
Consider using a variable that is not a param for `phi` (https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.variable). You can define your own collection name for this filter, for example: …