Skip to content

Commit 0b98ec7

Browse files
authored
Add sample and mpeg command (#24)
1 parent e28c71c commit 0b98ec7

File tree

12 files changed

+144
-76
lines changed

12 files changed

+144
-76
lines changed

README.md

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,15 @@ pip install hotaru
3737

3838
### Demonstration
3939
```shell
40-
cd sample
41-
python make.py
42-
hotaru --tag default trial
43-
hotaru --tag d12 trial
40+
hotaru sample --outdir mysample
41+
cd mysample
42+
hotaru config
43+
# edit mysample/hotaru.ini if you need
44+
hotaru trial
4445
hotaru auto
45-
python mpeg.py d12
46+
hotaru mpeg --has-truth
4647
```
48+
- see `mysample/hotaru/figure/`
4749

4850
[Sample Video](https://drive.google.com/file/d/12jl1YTZDuNAq94ciJ-_Cj5tBcKmCqgRH)
4951

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "hotaru"
3-
version = "4.0.0"
3+
version = "4.0.1"
44
description = "High performance Optimizer to extract spike Timing And cell location from calcium imaging data via lineaR impUlse"
55
license = "GPL-3.0-only"
66
authors = ["TAKEKAWA Takashi <[email protected]>"]

sample/hotaru.ini

Lines changed: 0 additions & 8 deletions
This file was deleted.

sample/make.py

Lines changed: 0 additions & 5 deletions
This file was deleted.

sample/mpeg.py

Lines changed: 0 additions & 6 deletions
This file was deleted.

src/hotaru/console/config.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import pkgutil
2+
3+
import click
4+
5+
6+
@click.command()
7+
def config():
8+
'''Make hotaru.ini'''
9+
10+
data = pkgutil.get_data('hotaru.console', 'hotaru.ini').decode('utf-8')
11+
with open('hotaru.ini', 'w') as f:
12+
f.write(data)

src/hotaru/console/hotaru.ini

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,20 @@
1+
[main]
2+
tag = default
3+
4+
[default]
5+
6+
[radius32]
7+
data_tag = default
8+
radius_max = 32.0
9+
10+
[distance12]
11+
data_tag = default
12+
find_tag = default
13+
distance = 1.2
14+
15+
116
[DEFAULT]
217
workdir = hotaru
3-
tag = default
418
data_tag =
519
find_tag =
620
init_tag =
@@ -36,8 +50,3 @@ epoch = 100
3650
steps = 100
3751
batch = 100
3852
window = 100
39-
40-
[main]
41-
tag = default
42-
43-
[default]

src/hotaru/console/main.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import click
55

66
from .obj import Obj
7+
from .config import config
78
from .run.data import data
89
from .run.find import find
910
from .run.init import init
@@ -14,6 +15,8 @@
1415
from .trial import trial
1516
from .auto import auto
1617
from .figure import figure
18+
from .mpeg import mpeg
19+
from .sample import sample
1720

1821

1922
def configure(ctx, param, configfile):
@@ -64,6 +67,7 @@ def main(ctx, config, tag, **args):
6467
del obj['quit']
6568

6669

70+
main.add_command(config)
6771
main.add_command(data)
6872
main.add_command(find)
6973
main.add_command(init)
@@ -74,3 +78,5 @@ def main(ctx, config, tag, **args):
7478
main.add_command(trial)
7579
main.add_command(auto)
7680
main.add_command(figure)
81+
main.add_command(mpeg)
82+
main.add_command(sample)

src/hotaru/console/mpeg.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import click
2+
3+
from hotaru.sim.mpeg import make_mpeg
4+
5+
6+
@click.command()
7+
@click.option('--stage', '-s', type=int)
8+
@click.option('--has-truth', is_flag=True)
9+
@click.pass_obj
10+
def mpeg(obj, stage, has_truth):
11+
'''Make Mp4'''
12+
13+
if stage is None:
14+
stage = 'curr'
15+
else:
16+
stage = f'{stage:03}'
17+
make_mpeg(obj.data_tag, obj.tag, stage, has_truth)

src/hotaru/console/sample.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import click
2+
3+
from hotaru.sim.make import make_sim
4+
5+
6+
@click.command()
7+
@click.option('--outdir', type=click.Path(), default='sample', show_default=True)
8+
@click.option('--width', type=int, default=200, show_default=True)
9+
@click.option('--height', type=int, default=200, show_default=True)
10+
@click.option('--frames', type=int, default=1000, show_default=True)
11+
@click.option('--hz', type=float, default=20.0, show_default=True)
12+
@click.option('--num-neurons', type=int, default=1000, show_default=True)
13+
@click.option('--intensity-min', type=float, default=0.5, show_default=True)
14+
@click.option('--intensity-max', type=float, default=2.0, show_default=True)
15+
@click.option('--radius-min', type=float, default=4.0, show_default=True)
16+
@click.option('--radius-max', type=float, default=8.0, show_default=True)
17+
@click.option('--radius-min', type=float, default=4.0, show_default=True)
18+
@click.option('--distance', type=float, default=1.8, show_default=True)
19+
@click.option('--firingrate-min', type=float, default=0.2, show_default=True)
20+
@click.option('--firingrate-max', type=float, default=2.2, show_default=True)
21+
@click.option('--tau_rise', type=float, default=0.08, show_default=True)
22+
@click.option('--tau_fall', type=float, default=0.18, show_default=True)
23+
@click.option('--seed', type=int)
24+
def sample(**args):
25+
'''Make Sample'''
26+
27+
make_sim(**args)

src/hotaru/sim/make.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import random
2+
import os
23

34
import tensorflow as tf
45

@@ -11,11 +12,24 @@
1112
from hotaru.train.dynamics import SpikeToCalcium
1213

1314

14-
def make_sim(nv=1000, h=200, w=200, hz=20.0, nk=1000, sig_min=0.5, sig_max=2.0, min_dist=1.8, seed=None):
15+
def make_sim(
16+
outdir='sample',
17+
frames=1000, height=200, width=200, hz=20.0, num_neurons=1000,
18+
intensity_min=0.5, intensity_max=2.0, radius_min=4.0, radius_max=8.0,
19+
firingrate_min=0.2, firingrate_max=2.2, distance=1.8,
20+
tau_rise=0.08, tau_fall=0.16, seed=None,
21+
):
22+
os.makedirs(f'{outdir}/truth', exist_ok=True)
23+
w, h, nv = width, height, frames
24+
nk = num_neurons
25+
sig_min, sig_max = intensity_min, intensity_max
26+
rad_min, rad_max = radius_min, radius_max
27+
fr_min, fr_max = firingrate_min, firingrate_max
28+
min_dist = distance
1529
np.random.seed(seed)
1630

1731
gamma = SpikeToCalcium()
18-
gamma.set_double_exp(hz, 0.08, 0.16, 6.0)
32+
gamma.set_double_exp(hz, tau_rise, tau_fall, 6.0)
1933
nu = nv + gamma.pad
2034

2135
a_true = []
@@ -35,14 +49,14 @@ def make_sim(nv=1000, h=200, w=200, hz=20.0, nk=1000, sig_min=0.5, sig_max=2.0,
3549
if ylist.size == 0:
3650
break
3751
for i in range(10):
38-
radius = st.uniform(4.0, 4.0).rvs()
52+
radius = st.uniform(rad_min, rad_max - rad_min).rvs()
3953
r = np.random.randint(ylist.size)
4054
y0, x0 = ylist[r], xlist[r]
4155
if (len(xs) == 0) or np.all((np.array(xs) - x0)**2 + (np.array(ys) - y0)**2 > (min_dist*radius)**2):
4256
break
4357
if i == 9:
4458
break
45-
rate = st.uniform(0.2 / hz, 2.0 / hz).rvs()
59+
rate = st.uniform(fr_min / hz, (fr_max - fr_min) / hz).rvs()
4660
ui = st.bernoulli(rate).rvs(nu)
4761
signal = st.uniform(sig_min, sig_max).rvs()
4862
cond = (xlist - x0)**2 + (ylist - y0)**2 > (min_dist*radius)**2
@@ -75,17 +89,7 @@ def make_sim(nv=1000, h=200, w=200, hz=20.0, nk=1000, sig_min=0.5, sig_max=2.0,
7589
base_x = -5.0 * (((x - hh) / hh)**2 + ((y - wh) / wh)**2)
7690
imgs = (f_t + n_t + base_t + base_x).numpy()
7791

78-
np.save('./a0.npy', a_t.numpy())
79-
np.save('./u0.npy', u_t.numpy())
80-
np.save('./v0.npy', v_t.numpy())
81-
np.save('./f0.npy', imgs)
82-
tifffile.imwrite('./imgs.tif', imgs)
83-
84-
imgs -= imgs.mean(axis=0, keepdims=True)
85-
imgs -= imgs.mean(axis=(1, 2), keepdims=True)
86-
imgs /= imgs.std()
87-
tifffile.imwrite('./norm.tif', imgs)
88-
89-
90-
if __name__ == '__main__':
91-
make_sim(nk=100, sig_min=1.0)
92+
np.save(f'{outdir}/truth/a0.npy', a_t.numpy())
93+
np.save(f'{outdir}/truth/u0.npy', u_t.numpy())
94+
np.save(f'{outdir}/truth/v0.npy', v_t.numpy())
95+
tifffile.imwrite(f'{outdir}/imgs.tif', imgs)

src/hotaru/sim/mpeg.py

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
from hotaru.util.mpeg import MpegStream
1111

1212

13-
def make_mpeg(tag='defautl', outfile='sample.mp4', stage='curr', data_tag='default'):
13+
def make_mpeg(data_tag, tag, stage, has_truth=False):
14+
outfile = f'hotaru/figure/{tag}_{stage}.mp4'
1415
cmap = get_cmap('Greens')
1516
p = load_pickle(f'hotaru/log/{data_tag}_data.pickle')
1617
mask = p['mask']
@@ -25,11 +26,13 @@ def make_mpeg(tag='defautl', outfile='sample.mp4', stage='curr', data_tag='defau
2526
data = load_tfrecord(f'hotaru/data/{data_tag}.tfrecord')
2627
data = unmasked(data, mask)
2728

28-
a0 = load_numpy('a0.npy')
29-
a0 = a0 > 0.5
30-
v0 = load_numpy('v0.npy')
31-
v0 -= v0.min(axis=1, keepdims=True)
32-
v0 /= (smax - smin)
29+
if has_truth:
30+
a0 = load_numpy('truth/a0.npy')
31+
a0 = a0 > 0.5
32+
v0 = load_numpy('truth/v0.npy')
33+
v0 -= v0.min(axis=1, keepdims=True)
34+
v0 /= (smax - smin)
35+
v0 *= 2.0
3336

3437
a2 = load_numpy(f'hotaru/segment/{tag}_{stage}.npy')
3538
a2 = a2.reshape(-1, h, w)
@@ -41,26 +44,33 @@ def make_mpeg(tag='defautl', outfile='sample.mp4', stage='curr', data_tag='defau
4144
v2 = g(u2).numpy()
4245
v2 *= sstd / (smax - smin)
4346
v2 -= v2.min(axis=1, keepdims=True)
44-
45-
v0 *= 2.0
4647
v2 *= 2.0
4748

48-
filters = [
49-
('drawtext', dict(x=10, y=10, text='orig')),
50-
('drawtext', dict(x=10 + w, y=10, text='true')),
51-
('drawtext', dict(x=10 + 2 * w, y=10, text='HOTARU')),
52-
]
53-
with MpegStream(3 * w, h, hz, outfile, filters) as mpeg:
54-
for t, d in enumerate(data.as_numpy_iterator()):
55-
imgo = d * sstd + avgt[t] + avgx
56-
imgo = (imgo - smin) / (smax - smin)
57-
img0 = np.clip((a0 * v0[:, t, None, None]).sum(axis=0), 0.0, 1.0)
58-
img2 = np.clip((a2 * v2[:, t, None, None]).sum(axis=0), 0.0, 1.0)
59-
img = np.concatenate([imgo, img0, img2], axis=1)
60-
img = (255 * cmap(img)).astype(np.uint8)
61-
mpeg.write(img)
62-
63-
64-
if __name__ == '__main__':
65-
import sys
66-
make_mpeg(sys.argv[1])
49+
if has_truth:
50+
filters = [
51+
('drawtext', dict(x=10, y=10, text='orig')),
52+
('drawtext', dict(x=10 + w, y=10, text='true')),
53+
('drawtext', dict(x=10 + 2 * w, y=10, text='HOTARU')),
54+
]
55+
with MpegStream(3 * w, h, hz, outfile, filters) as mpeg:
56+
for t, d in enumerate(data.as_numpy_iterator()):
57+
imgo = d * sstd + avgt[t] + avgx
58+
imgo = (imgo - smin) / (smax - smin)
59+
img0 = np.clip((a0 * v0[:, t, None, None]).sum(axis=0), 0.0, 1.0)
60+
img2 = np.clip((a2 * v2[:, t, None, None]).sum(axis=0), 0.0, 1.0)
61+
img = np.concatenate([imgo, img0, img2], axis=1)
62+
img = (255 * cmap(img)).astype(np.uint8)
63+
mpeg.write(img)
64+
else:
65+
filters = [
66+
('drawtext', dict(x=10, y=10, text='orig')),
67+
('drawtext', dict(x=10 + w, y=10, text='HOTARU')),
68+
]
69+
with MpegStream(2 * w, h, hz, outfile, filters) as mpeg:
70+
for t, d in enumerate(data.as_numpy_iterator()):
71+
imgo = d * sstd + avgt[t] + avgx
72+
imgo = (imgo - smin) / (smax - smin)
73+
img2 = np.clip((a2 * v2[:, t, None, None]).sum(axis=0), 0.0, 1.0)
74+
img = np.concatenate([imgo, img2], axis=1)
75+
img = (255 * cmap(img)).astype(np.uint8)
76+
mpeg.write(img)

0 commit comments

Comments
 (0)