Skip to content

Commit

Permalink
fix no prox bug (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
tk2lab authored Nov 4, 2020
1 parent 04caf74 commit 1b46170
Show file tree
Hide file tree
Showing 16 changed files with 331 additions and 278 deletions.
292 changes: 159 additions & 133 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "hotaru"
version = "3.2.4"
version = "3.3.0"
description = "High performance Optimizer to extract spike Timing And cell location from calcium imaging data via lineaR impUlse"
license = "GPL-3.0-only"
authors = ["TAKEKAWA Takashi <[email protected]>"]
Expand All @@ -10,7 +10,7 @@ keywords = ["Calcium Imaging", "Spike Detection", "Cell Extraction"]

[tool.poetry.dependencies]
python = "^3.7"
tensorflow = "^2.2.1"
tensorflow = "^2.3.1"
tifffile = "^2020.5.11"
matplotlib = "^3.2.1"
cleo = "^0.8.1"
Expand Down
233 changes: 128 additions & 105 deletions requirements.txt

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion src/hotaru/console/clean.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class CleanCommand(Command):
'''

options = [
_option('job-dir', 'j', ''),
_option('job-dir', 'j', 'target directory'),
option('force', 'f', ''),
]

Expand Down
2 changes: 1 addition & 1 deletion src/hotaru/console/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class DataCommand(Command):
'''

options = [
_option('job-dir'),
_option('job-dir', 'j', 'target directory'),
option('force', 'f'),
]

Expand Down
10 changes: 5 additions & 5 deletions src/hotaru/console/footprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class FootprintCommand(Command):
'''

options = [
_option('job-dir'),
_option('job-dir', 'j', 'target directory'),
option('force', 'f'),
]

Expand All @@ -33,10 +33,10 @@ def create(self, data, prev, curr, logs, la, bx, bt):
spike = load_numpy(f'{prev}-spike')
nk = spike.shape[0]
lu = self.status.params['lu']
with self.application.strategy.scope():
elems = self.get_model(data, tau, nk)
model = FootprintModel(*elems)
model.compile()
#with self.application.strategy.scope():
elems = self.get_model(data, tau, nk)
model = FootprintModel(*elems)
model.compile()
model.set_penalty(la, lu, bx, bt)
model.fit(
spike, stage=curr[-3:],
Expand Down
2 changes: 1 addition & 1 deletion src/hotaru/console/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class HistoryCommand(Command):
'''

options = [
_option('job-dir'),
_option('job-dir', 'j', 'target directory'),
]

def handle(self):
Expand Down
2 changes: 1 addition & 1 deletion src/hotaru/console/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class OutputCommand(Command):
'''

options = [
_option('job-dir'),
_option('job-dir', 'j', 'target directory'),
]

def handle(self):
Expand Down
2 changes: 1 addition & 1 deletion src/hotaru/console/peak.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class PeakCommand(Command):
'''

options = [
_option('job-dir'),
_option('job-dir', 'j', 'target directory'),
option('force', 'f'),
]

Expand Down
2 changes: 1 addition & 1 deletion src/hotaru/console/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class RunCommand(Command):
'''

options = [
_option('job-dir'),
_option('job-dir', 'j', 'target directory'),
_option('goal'),
]

Expand Down
2 changes: 1 addition & 1 deletion src/hotaru/console/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class SegmentCommand(Command):
'''

options = [
_option('job-dir'),
_option('job-dir', 'j', 'target directory'),
option('force', 'f'),
]

Expand Down
10 changes: 5 additions & 5 deletions src/hotaru/console/spike.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ class SpikeCommand(Command):

name = 'spike'
options = [
_option('job-dir'),
_option('job-dir', 'j', 'target directory'),
option('force', 'f'),
]

Expand All @@ -30,10 +30,10 @@ def create(self, data, prev, curr, logs, tau, lu, bx, bt):
segment = load_numpy(f'{prev}-segment')
nk = segment.shape[0]
la = self.status.params['la']
with self.application.strategy.scope():
elems = self.get_model(data, tau, nk)
model = SpikeModel(*elems)
model.compile()
#with self.application.strategy.scope():
elems = self.get_model(data, tau, nk)
model = SpikeModel(*elems)
model.compile()
model.set_penalty(la, lu, bx, bt)
model.fit(
segment, stage=curr[-3:],
Expand Down
2 changes: 1 addition & 1 deletion src/hotaru/console/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class TestCommand(Command):
'''

options = [
_option('job-dir', 'j', ''),
_option('job-dir', 'j', 'target directory'),
]

def handle(self):
Expand Down
20 changes: 11 additions & 9 deletions src/hotaru/train/footprint.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

import tensorflow.keras.backend as K
import tensorflow as tf
import numpy as np
Expand All @@ -19,7 +21,7 @@ def __init__(self, footprint, spike, variance, **kwargs):
def call(self, inputs):
footprint = self.footprint(inputs)
loss, footprint_penalty, spike_penalty = self.call_common(footprint)
self.add_metric(footprint_penalty, 'mean', 'penalty')
self.add_metric(footprint_penalty, 'penalty')
return loss

def fit(self, spike, lr, batch, verbose, **kwargs):
Expand Down Expand Up @@ -50,21 +52,21 @@ def __init__(self, stage, *args, **kwargs):
super().__init__(*args, **kwargs)
self.stage = stage

def on_epoch_end(self, epoch, logs=None):
super().on_epoch_end(logs)
def set_model(self, model):
super().set_model(model)
self._train_dir = os.path.join(self._log_write_dir, 'footprint')

def on_epoch_end(self, epoch, logs=None):
stage = self.stage
writer = self._get_writer('footprint')
with writer.as_default():
with self._train_writer.as_default():
val = self.model.footprint.val
summary_stat(val, stage, step=epoch)
super().on_epoch_end(logs)

def on_train_end(self, logs=None):
super().on_train_end(logs)

stage = self.stage
writer = self._get_writer('footprint')
with writer.as_default():
with self._train_writer.as_default():
val = self.model.footprint.val
mask = self.model.footprint.mask
summary_footprint_max(val, mask, stage, step=0)
super().on_train_end(logs)
4 changes: 2 additions & 2 deletions src/hotaru/train/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def call_common(self, val):
footprint_penalty = self.footprint_penalty()
spike_penalty = self.spike_penalty()
me = loss + footprint_penalty + spike_penalty
self.add_metric(K.sqrt(variance), 'mean', 'sigma')
self.add_metric(me, 'mean', 'score')
self.add_metric(K.sqrt(variance), 'sigma')
self.add_metric(me, 'score')
return loss, footprint_penalty, spike_penalty

def fit_common(self, callback, log_dir=None, stage=None, callbacks=None,
Expand Down
20 changes: 11 additions & 9 deletions src/hotaru/train/spike.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import os

import tensorflow.keras.backend as K
import tensorflow as tf
import numpy as np
Expand All @@ -19,7 +21,7 @@ def call(self, inputs):
spike = self.spike(inputs)
calcium = self.variance.spike_to_calcium(spike)
loss, footprint_penalty, spike_penalty = self.call_common(calcium)
self.add_metric(footprint_penalty, 'mean', 'penalty')
self.add_metric(footprint_penalty, 'penalty')
return loss

def fit(self, footprint, lr, batch, verbose, **kwargs):
Expand Down Expand Up @@ -50,18 +52,18 @@ def __init__(self, stage, *args, **kwargs):
super().__init__(*args, **kwargs)
self.stage = stage

def on_epoch_end(self, epoch, logs=None):
super().on_epoch_end(epoch, logs)
def set_model(self, model):
super().set_model(model)
self._train_dir = os.path.join(self._log_write_dir, 'spike')

def on_epoch_end(self, epoch, logs=None):
stage = self.stage
writer = self._get_writer('spike')
with writer.as_default():
with self._train_writer.as_default():
summary_stat(self.model.spike.val, stage, step=epoch)
super().on_epoch_end(epoch, logs)

def on_train_end(self, logs=None):
super().on_train_end(logs)

stage = self.stage
writer = self._get_writer('spike')
with writer.as_default():
with self._train_writer.as_default():
summary_spike(self.model.spike.val, stage, step=0)
super().on_train_end(logs)

0 comments on commit 1b46170

Please sign in to comment.