Skip to content

Commit 1b46170

Browse files
authored
fix no prox bug (#13)
1 parent 04caf74 commit 1b46170

File tree

16 files changed

+331
-278
lines changed

16 files changed

+331
-278
lines changed

poetry.lock

Lines changed: 159 additions & 133 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "hotaru"
3-
version = "3.2.4"
3+
version = "3.3.0"
44
description = "High performance Optimizer to extract spike Timing And cell location from calcium imaging data via lineaR impUlse"
55
license = "GPL-3.0-only"
66
authors = ["TAKEKAWA Takashi <[email protected]>"]
@@ -10,7 +10,7 @@ keywords = ["Calcium Imaging", "Spike Detection", "Cell Extraction"]
1010

1111
[tool.poetry.dependencies]
1212
python = "^3.7"
13-
tensorflow = "^2.2.1"
13+
tensorflow = "^2.3.1"
1414
tifffile = "^2020.5.11"
1515
matplotlib = "^3.2.1"
1616
cleo = "^0.8.1"

requirements.txt

Lines changed: 128 additions & 105 deletions
Large diffs are not rendered by default.

src/hotaru/console/clean.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class CleanCommand(Command):
1919
'''
2020

2121
options = [
22-
_option('job-dir', 'j', ''),
22+
_option('job-dir', 'j', 'target directory'),
2323
option('force', 'f', ''),
2424
]
2525

src/hotaru/console/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class DataCommand(Command):
2222
'''
2323

2424
options = [
25-
_option('job-dir'),
25+
_option('job-dir', 'j', 'target directory'),
2626
option('force', 'f'),
2727
]
2828

src/hotaru/console/footprint.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class FootprintCommand(Command):
1515
'''
1616

1717
options = [
18-
_option('job-dir'),
18+
_option('job-dir', 'j', 'target directory'),
1919
option('force', 'f'),
2020
]
2121

@@ -33,10 +33,10 @@ def create(self, data, prev, curr, logs, la, bx, bt):
3333
spike = load_numpy(f'{prev}-spike')
3434
nk = spike.shape[0]
3535
lu = self.status.params['lu']
36-
with self.application.strategy.scope():
37-
elems = self.get_model(data, tau, nk)
38-
model = FootprintModel(*elems)
39-
model.compile()
36+
#with self.application.strategy.scope():
37+
elems = self.get_model(data, tau, nk)
38+
model = FootprintModel(*elems)
39+
model.compile()
4040
model.set_penalty(la, lu, bx, bt)
4141
model.fit(
4242
spike, stage=curr[-3:],

src/hotaru/console/history.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class HistoryCommand(Command):
1111
'''
1212

1313
options = [
14-
_option('job-dir'),
14+
_option('job-dir', 'j', 'target directory'),
1515
]
1616

1717
def handle(self):

src/hotaru/console/output.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class OutputCommand(Command):
2020
'''
2121

2222
options = [
23-
_option('job-dir'),
23+
_option('job-dir', 'j', 'target directory'),
2424
]
2525

2626
def handle(self):

src/hotaru/console/peak.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class PeakCommand(Command):
1515
'''
1616

1717
options = [
18-
_option('job-dir'),
18+
_option('job-dir', 'j', 'target directory'),
1919
option('force', 'f'),
2020
]
2121

src/hotaru/console/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class RunCommand(Command):
1111
'''
1212

1313
options = [
14-
_option('job-dir'),
14+
_option('job-dir', 'j', 'target directory'),
1515
_option('goal'),
1616
]
1717

src/hotaru/console/segment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class SegmentCommand(Command):
1818
'''
1919

2020
options = [
21-
_option('job-dir'),
21+
_option('job-dir', 'j', 'target directory'),
2222
option('force', 'f'),
2323
]
2424

src/hotaru/console/spike.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class SpikeCommand(Command):
1313

1414
name = 'spike'
1515
options = [
16-
_option('job-dir'),
16+
_option('job-dir', 'j', 'target directory'),
1717
option('force', 'f'),
1818
]
1919

@@ -30,10 +30,10 @@ def create(self, data, prev, curr, logs, tau, lu, bx, bt):
3030
segment = load_numpy(f'{prev}-segment')
3131
nk = segment.shape[0]
3232
la = self.status.params['la']
33-
with self.application.strategy.scope():
34-
elems = self.get_model(data, tau, nk)
35-
model = SpikeModel(*elems)
36-
model.compile()
33+
#with self.application.strategy.scope():
34+
elems = self.get_model(data, tau, nk)
35+
model = SpikeModel(*elems)
36+
model.compile()
3737
model.set_penalty(la, lu, bx, bt)
3838
model.fit(
3939
segment, stage=curr[-3:],

src/hotaru/console/test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class TestCommand(Command):
1919
'''
2020

2121
options = [
22-
_option('job-dir', 'j', ''),
22+
_option('job-dir', 'j', 'target directory'),
2323
]
2424

2525
def handle(self):

src/hotaru/train/footprint.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
13
import tensorflow.keras.backend as K
24
import tensorflow as tf
35
import numpy as np
@@ -19,7 +21,7 @@ def __init__(self, footprint, spike, variance, **kwargs):
1921
def call(self, inputs):
2022
footprint = self.footprint(inputs)
2123
loss, footprint_penalty, spike_penalty = self.call_common(footprint)
22-
self.add_metric(footprint_penalty, 'mean', 'penalty')
24+
self.add_metric(footprint_penalty, 'penalty')
2325
return loss
2426

2527
def fit(self, spike, lr, batch, verbose, **kwargs):
@@ -50,21 +52,21 @@ def __init__(self, stage, *args, **kwargs):
5052
super().__init__(*args, **kwargs)
5153
self.stage = stage
5254

53-
def on_epoch_end(self, epoch, logs=None):
54-
super().on_epoch_end(logs)
55+
def set_model(self, model):
56+
super().set_model(model)
57+
self._train_dir = os.path.join(self._log_write_dir, 'footprint')
5558

59+
def on_epoch_end(self, epoch, logs=None):
5660
stage = self.stage
57-
writer = self._get_writer('footprint')
58-
with writer.as_default():
61+
with self._train_writer.as_default():
5962
val = self.model.footprint.val
6063
summary_stat(val, stage, step=epoch)
64+
super().on_epoch_end(logs)
6165

6266
def on_train_end(self, logs=None):
63-
super().on_train_end(logs)
64-
6567
stage = self.stage
66-
writer = self._get_writer('footprint')
67-
with writer.as_default():
68+
with self._train_writer.as_default():
6869
val = self.model.footprint.val
6970
mask = self.model.footprint.mask
7071
summary_footprint_max(val, mask, stage, step=0)
72+
super().on_train_end(logs)

src/hotaru/train/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ def call_common(self, val):
3535
footprint_penalty = self.footprint_penalty()
3636
spike_penalty = self.spike_penalty()
3737
me = loss + footprint_penalty + spike_penalty
38-
self.add_metric(K.sqrt(variance), 'mean', 'sigma')
39-
self.add_metric(me, 'mean', 'score')
38+
self.add_metric(K.sqrt(variance), 'sigma')
39+
self.add_metric(me, 'score')
4040
return loss, footprint_penalty, spike_penalty
4141

4242
def fit_common(self, callback, log_dir=None, stage=None, callbacks=None,

src/hotaru/train/spike.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
13
import tensorflow.keras.backend as K
24
import tensorflow as tf
35
import numpy as np
@@ -19,7 +21,7 @@ def call(self, inputs):
1921
spike = self.spike(inputs)
2022
calcium = self.variance.spike_to_calcium(spike)
2123
loss, footprint_penalty, spike_penalty = self.call_common(calcium)
22-
self.add_metric(footprint_penalty, 'mean', 'penalty')
24+
self.add_metric(footprint_penalty, 'penalty')
2325
return loss
2426

2527
def fit(self, footprint, lr, batch, verbose, **kwargs):
@@ -50,18 +52,18 @@ def __init__(self, stage, *args, **kwargs):
5052
super().__init__(*args, **kwargs)
5153
self.stage = stage
5254

53-
def on_epoch_end(self, epoch, logs=None):
54-
super().on_epoch_end(epoch, logs)
55+
def set_model(self, model):
56+
super().set_model(model)
57+
self._train_dir = os.path.join(self._log_write_dir, 'spike')
5558

59+
def on_epoch_end(self, epoch, logs=None):
5660
stage = self.stage
57-
writer = self._get_writer('spike')
58-
with writer.as_default():
61+
with self._train_writer.as_default():
5962
summary_stat(self.model.spike.val, stage, step=epoch)
63+
super().on_epoch_end(epoch, logs)
6064

6165
def on_train_end(self, logs=None):
62-
super().on_train_end(logs)
63-
6466
stage = self.stage
65-
writer = self._get_writer('spike')
66-
with writer.as_default():
67+
with self._train_writer.as_default():
6768
summary_spike(self.model.spike.val, stage, step=0)
69+
super().on_train_end(logs)

0 commit comments

Comments
 (0)