Skip to content

Commit

Permalink
Add sample and mpeg command (#24)
Browse files Browse the repository at this point in the history
  • Loading branch information
tk2lab authored Apr 9, 2022
1 parent e28c71c commit 0b98ec7
Show file tree
Hide file tree
Showing 12 changed files with 144 additions and 76 deletions.
12 changes: 7 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,15 @@ pip install hotaru

### Demonstration
```shell
cd sample
python make.py
hotaru --tag default trial
hotaru --tag d12 trial
hotaru sample --outdir mysample
cd mysample
hotaru config
# edit mysample/hotaru.ini if you need
hotaru trial
hotaru auto
python mpeg.py d12
hotaru mpeg --has-truth
```
- see `mysample/hotaru/figure/`

[Sample Video](https://drive.google.com/file/d/12jl1YTZDuNAq94ciJ-_Cj5tBcKmCqgRH)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "hotaru"
version = "4.0.0"
version = "4.0.1"
description = "High performance Optimizer to extract spike Timing And cell location from calcium imaging data via lineaR impUlse"
license = "GPL-3.0-only"
authors = ["TAKEKAWA Takashi <[email protected]>"]
Expand Down
8 changes: 0 additions & 8 deletions sample/hotaru.ini

This file was deleted.

5 changes: 0 additions & 5 deletions sample/make.py

This file was deleted.

6 changes: 0 additions & 6 deletions sample/mpeg.py

This file was deleted.

12 changes: 12 additions & 0 deletions src/hotaru/console/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import pkgutil

import click


@click.command()
def config():
'''Make hotaru.ini'''

data = pkgutil.get_data('hotaru.console', 'hotaru.ini').decode('utf-8')
with open('hotaru.ini', 'w') as f:
f.write(data)
21 changes: 15 additions & 6 deletions src/hotaru/console/hotaru.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
[main]
tag = default

[default]

[radius32]
data_tag = default
radius_max = 32.0

[distance12]
data_tag = default
find_tag = default
distance = 1.2


[DEFAULT]
workdir = hotaru
tag = default
data_tag =
find_tag =
init_tag =
Expand Down Expand Up @@ -36,8 +50,3 @@ epoch = 100
steps = 100
batch = 100
window = 100

[main]
tag = default

[default]
6 changes: 6 additions & 0 deletions src/hotaru/console/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import click

from .obj import Obj
from .config import config
from .run.data import data
from .run.find import find
from .run.init import init
Expand All @@ -14,6 +15,8 @@
from .trial import trial
from .auto import auto
from .figure import figure
from .mpeg import mpeg
from .sample import sample


def configure(ctx, param, configfile):
Expand Down Expand Up @@ -64,6 +67,7 @@ def main(ctx, config, tag, **args):
del obj['quit']


main.add_command(config)
main.add_command(data)
main.add_command(find)
main.add_command(init)
Expand All @@ -74,3 +78,5 @@ def main(ctx, config, tag, **args):
main.add_command(trial)
main.add_command(auto)
main.add_command(figure)
main.add_command(mpeg)
main.add_command(sample)
17 changes: 17 additions & 0 deletions src/hotaru/console/mpeg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import click

from hotaru.sim.mpeg import make_mpeg


@click.command()
@click.option('--stage', '-s', type=int)
@click.option('--has-truth', is_flag=True)
@click.pass_obj
def mpeg(obj, stage, has_truth):
'''Make Mp4'''

if stage is None:
stage = 'curr'
else:
stage = f'{stage:03}'
make_mpeg(obj.data_tag, obj.tag, stage, has_truth)
27 changes: 27 additions & 0 deletions src/hotaru/console/sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import click

from hotaru.sim.make import make_sim


@click.command()
@click.option('--outdir', type=click.Path(), default='sample', show_default=True)
@click.option('--width', type=int, default=200, show_default=True)
@click.option('--height', type=int, default=200, show_default=True)
@click.option('--frames', type=int, default=1000, show_default=True)
@click.option('--hz', type=float, default=20.0, show_default=True)
@click.option('--num-neurons', type=int, default=1000, show_default=True)
@click.option('--intensity-min', type=float, default=0.5, show_default=True)
@click.option('--intensity-max', type=float, default=2.0, show_default=True)
@click.option('--radius-min', type=float, default=4.0, show_default=True)
@click.option('--radius-max', type=float, default=8.0, show_default=True)
@click.option('--radius-min', type=float, default=4.0, show_default=True)
@click.option('--distance', type=float, default=1.8, show_default=True)
@click.option('--firingrate-min', type=float, default=0.2, show_default=True)
@click.option('--firingrate-max', type=float, default=2.2, show_default=True)
@click.option('--tau_rise', type=float, default=0.08, show_default=True)
@click.option('--tau_fall', type=float, default=0.18, show_default=True)
@click.option('--seed', type=int)
def sample(**args):
'''Make Sample'''

make_sim(**args)
40 changes: 22 additions & 18 deletions src/hotaru/sim/make.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import random
import os

import tensorflow as tf

Expand All @@ -11,11 +12,24 @@
from hotaru.train.dynamics import SpikeToCalcium


def make_sim(nv=1000, h=200, w=200, hz=20.0, nk=1000, sig_min=0.5, sig_max=2.0, min_dist=1.8, seed=None):
def make_sim(
outdir='sample',
frames=1000, height=200, width=200, hz=20.0, num_neurons=1000,
intensity_min=0.5, intensity_max=2.0, radius_min=4.0, radius_max=8.0,
firingrate_min=0.2, firingrate_max=2.2, distance=1.8,
tau_rise=0.08, tau_fall=0.16, seed=None,
):
os.makedirs(f'{outdir}/truth', exist_ok=True)
w, h, nv = width, height, frames
nk = num_neurons
sig_min, sig_max = intensity_min, intensity_max
rad_min, rad_max = radius_min, radius_max
fr_min, fr_max = firingrate_min, firingrate_max
min_dist = distance
np.random.seed(seed)

gamma = SpikeToCalcium()
gamma.set_double_exp(hz, 0.08, 0.16, 6.0)
gamma.set_double_exp(hz, tau_rise, tau_fall, 6.0)
nu = nv + gamma.pad

a_true = []
Expand All @@ -35,14 +49,14 @@ def make_sim(nv=1000, h=200, w=200, hz=20.0, nk=1000, sig_min=0.5, sig_max=2.0,
if ylist.size == 0:
break
for i in range(10):
radius = st.uniform(4.0, 4.0).rvs()
radius = st.uniform(rad_min, rad_max - rad_min).rvs()
r = np.random.randint(ylist.size)
y0, x0 = ylist[r], xlist[r]
if (len(xs) == 0) or np.all((np.array(xs) - x0)**2 + (np.array(ys) - y0)**2 > (min_dist*radius)**2):
break
if i == 9:
break
rate = st.uniform(0.2 / hz, 2.0 / hz).rvs()
rate = st.uniform(fr_min / hz, (fr_max - fr_min) / hz).rvs()
ui = st.bernoulli(rate).rvs(nu)
signal = st.uniform(sig_min, sig_max).rvs()
cond = (xlist - x0)**2 + (ylist - y0)**2 > (min_dist*radius)**2
Expand Down Expand Up @@ -75,17 +89,7 @@ def make_sim(nv=1000, h=200, w=200, hz=20.0, nk=1000, sig_min=0.5, sig_max=2.0,
base_x = -5.0 * (((x - hh) / hh)**2 + ((y - wh) / wh)**2)
imgs = (f_t + n_t + base_t + base_x).numpy()

np.save('./a0.npy', a_t.numpy())
np.save('./u0.npy', u_t.numpy())
np.save('./v0.npy', v_t.numpy())
np.save('./f0.npy', imgs)
tifffile.imwrite('./imgs.tif', imgs)

imgs -= imgs.mean(axis=0, keepdims=True)
imgs -= imgs.mean(axis=(1, 2), keepdims=True)
imgs /= imgs.std()
tifffile.imwrite('./norm.tif', imgs)


if __name__ == '__main__':
make_sim(nk=100, sig_min=1.0)
np.save(f'{outdir}/truth/a0.npy', a_t.numpy())
np.save(f'{outdir}/truth/u0.npy', u_t.numpy())
np.save(f'{outdir}/truth/v0.npy', v_t.numpy())
tifffile.imwrite(f'{outdir}/imgs.tif', imgs)
64 changes: 37 additions & 27 deletions src/hotaru/sim/mpeg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from hotaru.util.mpeg import MpegStream


def make_mpeg(tag='defautl', outfile='sample.mp4', stage='curr', data_tag='default'):
def make_mpeg(data_tag, tag, stage, has_truth=False):
outfile = f'hotaru/figure/{tag}_{stage}.mp4'
cmap = get_cmap('Greens')
p = load_pickle(f'hotaru/log/{data_tag}_data.pickle')
mask = p['mask']
Expand All @@ -25,11 +26,13 @@ def make_mpeg(tag='defautl', outfile='sample.mp4', stage='curr', data_tag='defau
data = load_tfrecord(f'hotaru/data/{data_tag}.tfrecord')
data = unmasked(data, mask)

a0 = load_numpy('a0.npy')
a0 = a0 > 0.5
v0 = load_numpy('v0.npy')
v0 -= v0.min(axis=1, keepdims=True)
v0 /= (smax - smin)
if has_truth:
a0 = load_numpy('truth/a0.npy')
a0 = a0 > 0.5
v0 = load_numpy('truth/v0.npy')
v0 -= v0.min(axis=1, keepdims=True)
v0 /= (smax - smin)
v0 *= 2.0

a2 = load_numpy(f'hotaru/segment/{tag}_{stage}.npy')
a2 = a2.reshape(-1, h, w)
Expand All @@ -41,26 +44,33 @@ def make_mpeg(tag='defautl', outfile='sample.mp4', stage='curr', data_tag='defau
v2 = g(u2).numpy()
v2 *= sstd / (smax - smin)
v2 -= v2.min(axis=1, keepdims=True)

v0 *= 2.0
v2 *= 2.0

filters = [
('drawtext', dict(x=10, y=10, text='orig')),
('drawtext', dict(x=10 + w, y=10, text='true')),
('drawtext', dict(x=10 + 2 * w, y=10, text='HOTARU')),
]
with MpegStream(3 * w, h, hz, outfile, filters) as mpeg:
for t, d in enumerate(data.as_numpy_iterator()):
imgo = d * sstd + avgt[t] + avgx
imgo = (imgo - smin) / (smax - smin)
img0 = np.clip((a0 * v0[:, t, None, None]).sum(axis=0), 0.0, 1.0)
img2 = np.clip((a2 * v2[:, t, None, None]).sum(axis=0), 0.0, 1.0)
img = np.concatenate([imgo, img0, img2], axis=1)
img = (255 * cmap(img)).astype(np.uint8)
mpeg.write(img)


if __name__ == '__main__':
import sys
make_mpeg(sys.argv[1])
if has_truth:
filters = [
('drawtext', dict(x=10, y=10, text='orig')),
('drawtext', dict(x=10 + w, y=10, text='true')),
('drawtext', dict(x=10 + 2 * w, y=10, text='HOTARU')),
]
with MpegStream(3 * w, h, hz, outfile, filters) as mpeg:
for t, d in enumerate(data.as_numpy_iterator()):
imgo = d * sstd + avgt[t] + avgx
imgo = (imgo - smin) / (smax - smin)
img0 = np.clip((a0 * v0[:, t, None, None]).sum(axis=0), 0.0, 1.0)
img2 = np.clip((a2 * v2[:, t, None, None]).sum(axis=0), 0.0, 1.0)
img = np.concatenate([imgo, img0, img2], axis=1)
img = (255 * cmap(img)).astype(np.uint8)
mpeg.write(img)
else:
filters = [
('drawtext', dict(x=10, y=10, text='orig')),
('drawtext', dict(x=10 + w, y=10, text='HOTARU')),
]
with MpegStream(2 * w, h, hz, outfile, filters) as mpeg:
for t, d in enumerate(data.as_numpy_iterator()):
imgo = d * sstd + avgt[t] + avgx
imgo = (imgo - smin) / (smax - smin)
img2 = np.clip((a2 * v2[:, t, None, None]).sum(axis=0), 0.0, 1.0)
img = np.concatenate([imgo, img2], axis=1)
img = (255 * cmap(img)).astype(np.uint8)
mpeg.write(img)

0 comments on commit 0b98ec7

Please sign in to comment.