Skip to content
33 changes: 29 additions & 4 deletions lightguide/blast.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from collections import deque

import logging
from copy import deepcopy
Expand Down Expand Up @@ -330,6 +331,8 @@ def follow_phase(
window_size: int | tuple[int, int] = 50,
threshold: float = 5e-1,
max_shift: int = 20,
stack: bool = False,
no_of_stacks: int = 5,
) -> tuple[np.ndarray, list[datetime], np.ndarray]:
"""Follow a phase pick through a Blast.

Expand All @@ -340,7 +343,7 @@ def follow_phase(
2. Calculate normalized cross correlate with downwards neighbor.
3. Evaluate maximum x-correlation in allowed window (max_shift).
4. Update template trace and go to 2.

4a. if stack=True: stack templates for correlation to stabilize
5. Repeat for upward neighbors.

Args:
Expand All @@ -353,6 +356,12 @@ def follow_phase(
Defaults to 5e-1.
max_shift (int, optional): Maximum allowed shift in samples for
neighboring picks. Defaults to 20.
stack (bool): If True - (a default number of 5) templates will be stacked and used
as correlation template. Stacking close to the initial template is limited to
the distance to the initial tamplate. I.e. the correlation of a trace 3 traces
next to the initial template will only use a stacked template of 3 traces
(initial trace an trace 1 and 2 next to it), altough the no_of_stacks is set higher.
no_of_stacks (int): Numbers of traces to stack to define the template.

Returns:
tuple[np.ndarray, list[datetime], np.ndarray]: Tuple of channel number,
Expand All @@ -375,12 +384,21 @@ def follow_phase(
def prepare_template(data: np.ndarray) -> np.ndarray:
return data * template_taper

# def stack_n_straces(data: np.ndarray,stack_traces) -> np.ndarray:

# return stacked_data

def correlate(data: np.ndarray, direction: Literal[1, -1] = 1) -> None:
template = root_template.copy()
index = root_idx
template_deque = deque([np.array(template)])

index = root_idx
for ichannel, trace in enumerate(data):
template = prepare_template(template)
# check if stacking is activated
if stack and len(template_deque) > 2:
template = prepare_template(template_stack)

norm = np.sqrt(np.sum(template**2)) * np.sqrt(np.sum(trace**2))
correlation = np.correlate(trace, template, mode="same")
correlation = np.abs(correlation / norm)
Expand All @@ -395,7 +413,7 @@ def correlate(data: np.ndarray, direction: Literal[1, -1] = 1) -> None:
phase_correlation = correlation[phase_idx]
phase_time = self._sample_to_time(int(phase_idx))

if phase_correlation < threshold:
if phase_correlation > threshold:
continue

# Avoid the edges
Expand All @@ -409,14 +427,21 @@ def correlate(data: np.ndarray, direction: Literal[1, -1] = 1) -> None:
template = trace[
phase_idx - window_size[0] : phase_idx + window_size[1] + 1
].copy()

# stacking
if len(template_deque) <= no_of_stacks:
template_deque.append(template)
if len(template_deque) == no_of_stacks + 1:
template_deque.popleft()
template_stack = np.sum(template_deque, axis=0) / len(template_deque)

index = phase_idx

correlate(self.data[pick_channel:])
correlate(self.data[: pick_channel - 1][::-1], direction=-1)

pick_channels = np.array(pick_channels) + self.start_channel
pick_correlations = np.array(pick_correlations)

return pick_channels, pick_times, pick_correlations

def taper(self, alpha: float = 0.05) -> None:
Expand Down