Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
66 commits
Select commit Hold shift + click to select a range
e9603b6
feat: information extraction
Gorgeous-Patrick Sep 28, 2025
886bae6
feat: JacPIM static analysis ctx
Gorgeous-Patrick Sep 28, 2025
6a07ae2
feat: get networkx.
Gorgeous-Patrick Sep 28, 2025
6fe07f2
fix: use edge archetype
Gorgeous-Patrick Sep 28, 2025
523203f
feat: visit sequence generation
Gorgeous-Patrick Sep 28, 2025
0f7c549
feat: temporal trace graph analysis
Gorgeous-Patrick Sep 28, 2025
6634f4d
feat: add new jacpim starting point
Gorgeous-Patrick Sep 28, 2025
381e399
feat: basic plotting
Gorgeous-Patrick Sep 28, 2025
4c1334f
fix: ignore pngs
Gorgeous-Patrick Sep 28, 2025
984f7ba
feat: ttt finished
Gorgeous-Patrick Sep 28, 2025
ddbd133
feat: plot ttg
Gorgeous-Patrick Sep 28, 2025
38819ca
feat: use MultiDiGraph
Gorgeous-Patrick Sep 28, 2025
f6d82a7
feat: a better temporal tracing debugged.
Gorgeous-Patrick Sep 29, 2025
5604d72
fix: remove ttt print
Gorgeous-Patrick Sep 29, 2025
d7d0c37
feat: partitioning partially finished.
Gorgeous-Patrick Sep 29, 2025
6d8753e
feat: new simulation framework
Gorgeous-Patrick Sep 29, 2025
1d21596
feat: partially done jacpim simulation code gen
Gorgeous-Patrick Sep 30, 2025
c8472ba
feat: CPU runtime scheduler running
Gorgeous-Patrick Sep 30, 2025
ee6396b
par visit
Gorgeous-Patrick Sep 30, 2025
6198ac0
feat: CPU ctx
Gorgeous-Patrick Oct 1, 2025
675be21
feat: remove redundant node distribution
Gorgeous-Patrick Oct 1, 2025
9e73373
half snapshotting system
Gorgeous-Patrick Oct 1, 2025
4040b9c
feat: memory layout generation
Gorgeous-Patrick Oct 2, 2025
2a727f2
feat: memory dump running
Gorgeous-Patrick Oct 2, 2025
dc5abfb
save the results
Gorgeous-Patrick Oct 2, 2025
60ecab5
feat: do not clear the active walkers before running the recording
Gorgeous-Patrick Oct 2, 2025
c4a3ee1
save the memory dumps
Gorgeous-Patrick Oct 2, 2025
a3b5da3
fix: save only once
Gorgeous-Patrick Oct 2, 2025
233658a
remove useless file
Gorgeous-Patrick Oct 2, 2025
8086bcb
feat: context generation
Gorgeous-Patrick Oct 2, 2025
951ed6e
code gen done
Gorgeous-Patrick Oct 2, 2025
1933434
ignore task.c
Gorgeous-Patrick Oct 2, 2025
f0a8f8c
Metadata code gen
Gorgeous-Patrick Oct 2, 2025
26b98bd
fmt
Gorgeous-Patrick Oct 2, 2025
3c99444
buggy and not finished
Gorgeous-Patrick Oct 2, 2025
fc6fe84
print debug info
Gorgeous-Patrick Oct 6, 2025
398c26d
system running
Gorgeous-Patrick Oct 7, 2025
d3929f1
fix: add offset to node position
Gorgeous-Patrick Oct 7, 2025
e151aa7
generating correct programs
Gorgeous-Patrick Oct 7, 2025
6e666e8
System running
Gorgeous-Patrick Oct 7, 2025
0927512
Instruction count
Gorgeous-Patrick Oct 7, 2025
5efc56f
removed acquire (correctness to be verified)
Gorgeous-Patrick Oct 9, 2025
7ec4e2a
plotting
Gorgeous-Patrick Oct 16, 2025
483ba57
binary search running
Gorgeous-Patrick Oct 19, 2025
a6c24e7
walker stop
Gorgeous-Patrick Oct 20, 2025
c110895
overhead included
Gorgeous-Patrick Oct 21, 2025
4ed7d8b
include overhead
Gorgeous-Patrick Oct 21, 2025
da57424
Overhead subtraction working
Gorgeous-Patrick Oct 21, 2025
4431ee2
enable multithreading
Gorgeous-Patrick Oct 21, 2025
7f105c8
I think it makes sense now.
Gorgeous-Patrick Oct 21, 2025
9773346
Name changed
Gorgeous-Patrick Oct 21, 2025
dd18306
save
Gorgeous-Patrick Oct 23, 2025
a2a3662
save
Gorgeous-Patrick Oct 23, 2025
e7c1fa5
more testcases
Gorgeous-Patrick Oct 24, 2025
6bea956
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 24, 2025
97b35dd
connected component working well now
Gorgeous-Patrick Oct 24, 2025
09550d5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 24, 2025
fb472e1
feat: CPU DPU Transfer time estimation
Gorgeous-Patrick Oct 25, 2025
14f6502
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 25, 2025
eaa379f
centralized testcase for stat plot scripts
Gorgeous-Patrick Oct 29, 2025
1ca708a
fix: add node_id
Gorgeous-Patrick Oct 29, 2025
8e16595
data mapper new
Gorgeous-Patrick Oct 30, 2025
00fcbf2
save
Gorgeous-Patrick Nov 1, 2025
28358de
walker trace verified
Gorgeous-Patrick Nov 1, 2025
10a19bf
feat: new memory dumper and verification
Gorgeous-Patrick Nov 1, 2025
375b6a6
feat: more memory dump verification
Gorgeous-Patrick Nov 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,7 @@ mydatabase/
*.dir
*.log
.env
*.png
input_bins/
task.c
.ex*
5 changes: 5 additions & 0 deletions jac/jaclang/runtimelib/jacpim_mapping_analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Mapping phase for JacPIM."""

from .mapping_ctx import JacPIMMappingCtx

__all__ = ["JacPIMMappingCtx"]
156 changes: 156 additions & 0 deletions jac/jaclang/runtimelib/jacpim_mapping_analysis/data_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
"""Graph Partitioning binding."""

import random

from jaclang.runtimelib.archetype import NodeArchetype
from jaclang.runtimelib.jacpim_static_analysis.info_extract import (
get_node_info_from_node_arch,
)
from jaclang.runtimelib.jacpim_static_analysis.static_ctx import JacPIMStaticCtx

import networkx as nx
import os

DPU_SIZE_LIMIT = 1024
DPU_NUM = 50
RESERVED_SIZE = 128
MAX_PARTITION_SIZE = DPU_SIZE_LIMIT - RESERVED_SIZE


class NodeDistribution:
"""Node to DPU mapping management."""

def __init__(self) -> None:
"""Initialize to an empty partitioning."""
self.node_to_partition: dict[int, int] = {}
self.partition_availability: list[int] = [0] * DPU_NUM

def add_node(self, node: int, partition: int, node_size: int) -> None:
"""Add a single node to a partition."""
self.node_to_partition[node] = partition
self.partition_availability[partition] += node_size
assert self.partition_availability[partition] <= MAX_PARTITION_SIZE

def node_assigned(self, node: int) -> bool:
"""Check whether a node has been assigned to a DPU."""
return node in self.node_to_partition

def available_partitions(self, node_size: int) -> list[int]:
"""Get a list of available partition IDs."""
return [
i
for i in range(DPU_NUM)
if self.partition_availability[i] + node_size <= MAX_PARTITION_SIZE
]

def get_dpu_data_amount(self) -> list[int]:
"""Get the data amount of each DPU core."""
return self.partition_availability

def get_partition(self) -> dict[int, int]:
"""Get the partitioning."""
return self.node_to_partition


class RoundRobinPartitioner:
"""Round Robin JacPIM Partitioner."""

def _remove_duplicates(self, lst: list[tuple[int, int]]) -> list[tuple[int, int]]:
"""Remove duplicates while preserving order."""
seen = set()
result = []
for depth, item in lst:
if item not in seen:
seen.add(item)
result.append((depth, item))
return result

def _dfs_round_robin_on_node(
self,
node_distribution: NodeDistribution,
ttg: nx.MultiDiGraph,
start_node_idx: int,
offset: int,
) -> None:
"""Run a basic dfs to get the partitioning in DFS order."""
stack: list[tuple[int, int]] = [(0, start_node_idx)]
visited: set[int] = set()
print(f"Starting DFS from node {start_node_idx}")
while len(stack) > 0:
depth, node = stack.pop(0)
next_nodes = ttg.edges(node, keys=True, data=True)
next_nodes = [
(depth + 1, next_node[1])
for next_node in next_nodes
if not (next_node[3].get("ttg_attr").is_parallel_edge)
if (next_node[3].get("ttg_attr").timestamp == depth)
if next_node[1] not in visited
if next_node[1] != node
if next_node[1] not in stack
]
next_nodes = self._remove_duplicates(next_nodes)

next_nodes_idx = [n[1] for n in next_nodes]
# print(next_nodes)
visited |= set(next_nodes_idx)
stack += next_nodes
node_size = get_node_info_from_node_arch(
JacPIMStaticCtx.get_all_nodes()[node]
).node_size_bytes
if node_distribution.node_assigned(node):
continue
partitions = node_distribution.available_partitions(node_size)
if len(partitions) == 0:
raise RuntimeError("No available partitions.")
partition = partitions[offset % len(partitions)]
node_distribution.add_node(node, partition, node_size)
print(f"Visited {len(visited)} nodes.")

def __init__(self, ttg: nx.MultiDiGraph, start_nodes: list[NodeArchetype]) -> None:
"""Get the partitioning done."""
self.node_distribution = NodeDistribution()
for idx, start_node in enumerate(start_nodes):
start_node_idx = JacPIMStaticCtx.get_all_nodes().index(start_node)
self._dfs_round_robin_on_node(
self.node_distribution, ttg, start_node_idx, offset=idx
)
for node_idx in range(len(JacPIMStaticCtx.get_all_nodes())):
if not self.node_distribution.node_assigned(node_idx):
node_size = get_node_info_from_node_arch(
JacPIMStaticCtx.get_all_nodes()[node_idx]
).node_size_bytes
partitions = self.node_distribution.available_partitions(node_size)
if len(partitions) == 0:
raise RuntimeError("No available partitions.")
if os.environ.get("MAPPING_STRATEGY") == "FIRST_FIT":
partition = min(partitions)
else:
partition = random.choice(partitions)
self.node_distribution.add_node(node_idx, partition, node_size)

def get_data_partitioning(self) -> dict[int, int]:
"""Retrieve the partitioning."""
return self.node_distribution.get_partition()


class RandomPartitioner:
"""Random JacPIM Partitioner (baseline)."""

def __init__(self, ttg: nx.MultiDiGraph, _: list[NodeArchetype]) -> None:
"""Get the partitioning done."""
self.node_distribution = NodeDistribution()
# self._dfs_round_robin_on_node(self.node_distribution, ttg, start_node_idx, 0)
for node in ttg.nodes():
if self.node_distribution.node_assigned(node):
continue
node_size = get_node_info_from_node_arch(
JacPIMStaticCtx.get_all_nodes()[node]
).node_size_bytes
partition = random.choice(
self.node_distribution.available_partitions(node_size)
)
self.node_distribution.add_node(node, partition, node_size)

def get_data_partitioning(self) -> dict[int, int]:
"""Retrieve the partitioning."""
return self.node_distribution.get_partition()
80 changes: 80 additions & 0 deletions jac/jaclang/runtimelib/jacpim_mapping_analysis/mapping_ctx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""JacPIM Mapping Phase global context."""

import os

import jaclang.compiler.unitree as uni
from jaclang.runtimelib.archetype import NodeArchetype, WalkerArchetype
from jaclang.runtimelib.jacpim_mapping_analysis.data_mapper import (
RandomPartitioner,
RoundRobinPartitioner,
)
from jaclang.runtimelib.jacpim_mapping_analysis.temporal_trace_graph import (
get_access_pattern_single_walker,
get_ttg_from_ttt,
)
from jaclang.runtimelib.jacpim_static_analysis import JacPIMStaticCtx
from jaclang.runtimelib.jacpim_static_analysis.info_extract import extract_name

import networkx


def get_walker_code(walker: WalkerArchetype) -> uni.Archetype:
"""Get the walker type code from walker instance."""
for walker_code in JacPIMStaticCtx.get_jac_program().mod.get_all_sub_nodes(
uni.Archetype
):
if walker_code.get_all_sub_nodes(uni.Name)[0].value == extract_name(walker):
return walker_code
raise ValueError(f"Walker code for {walker} not found in program.")


class JacPIMMappingCtx:
"""JacPIM Mapping Phase global context."""

mapping: dict[NodeArchetype, int] | None
ttg: networkx.MultiDiGraph | None
partitioning: dict[int, int] | None

@classmethod
def setter(
cls, nodes_and_walkers: list[tuple[NodeArchetype, WalkerArchetype]]
) -> None:
"""Set all the values in the context."""
static_ctx = JacPIMStaticCtx
cls.ttg = get_ttg_from_ttt(
[
get_access_pattern_single_walker(
start_node, static_ctx.get_networkx(), get_walker_code(walker)
)
for start_node, walker in nodes_and_walkers
]
)
cls.partitioning = None
mapping_method = os.environ.get("MAPPING")
if mapping_method is None:
raise RuntimeError("Mapping method not specified")
elif mapping_method == "JACPIM":
cls.partitioning = RoundRobinPartitioner(
cls.ttg, [start_node for start_node, _ in nodes_and_walkers]
).get_data_partitioning()
elif mapping_method == "RANDOM":
cls.partitioning = RandomPartitioner(
cls.ttg, [start_node for start_node, _ in nodes_and_walkers]
).get_data_partitioning()
else:
raise RuntimeError("Mapping method undefined")

@classmethod
def get_ttg(cls) -> networkx.MultiDiGraph:
"""Read the Temporal Trace Graph."""
if cls.ttg is None:
raise RuntimeError("TTG is None!")
return cls.ttg

@classmethod
def get_partitioning(cls) -> dict[int, int]:
"""Get the partitioning."""
if cls.partitioning is None:
raise RuntimeError("Partitioning not set.")
else:
return cls.partitioning
51 changes: 51 additions & 0 deletions jac/jaclang/runtimelib/jacpim_mapping_analysis/plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Plotting diagrams."""

from jaclang.runtimelib.jacpim_mapping_analysis.mapping_ctx import JacPIMMappingCtx
from jaclang.runtimelib.jacpim_static_analysis.info_extract import (
get_node_info_from_node_arch,
)

import matplotlib.pyplot as plt

import matplotlib.pyplot as plt

import networkx as nx


def plot_ttg(graph: nx.MultiDiGraph, pos: dict, filename: str) -> None:
"""Plot and save one graph."""
print(graph)
plt.figure()
colors = ["red", "blue", "green", "orange", "purple", "brown", "pink", "gray"]
partitioning = JacPIMMappingCtx.get_partitioning()
node_colors = [colors[partitioning[n] % len(colors)] for n in graph.nodes()]

nx.draw_networkx_nodes(graph, pos, node_color=node_colors, node_size=100)

display_names = {
n: str(get_node_info_from_node_arch(graph.nodes[n]["archetype"]).display_name)
for n in graph.nodes()
}
assert all(
"ttg_attr" in graph.edges[n] for n in graph.edges
), "Edge attribute 'ttg_attr' missing in some edges."
# All Edges have timestamp
assert all(
hasattr(graph.edges[n]["ttg_attr"], "timestamp") for n in graph.edges
), "Edge attribute 'timestamp' missing in some edges."

# All Edges have timestamp not empty
assert all(
graph.edges[n]["ttg_attr"].timestamp is not None for n in graph.edges
), "Edge attribute 'timestamp' is None in some edges."
assert all( len(str(graph.edges[n]["ttg_attr"].timestamp)) > 0 for n in graph.edges
), "Edge attribute 'timestamp' is empty in some edges."
edge_labels = {
n: ", ".join(str(graph.edges[n]["ttg_attr"].timestamp)) for n in graph.edges
}

nx.draw_networkx_labels(graph, pos, display_names, font_size=10)
nx.draw_networkx_edges(graph, pos)
nx.draw_networkx_edge_labels(graph, pos, edge_labels, font_size=10)
plt.savefig(filename, dpi=300)
plt.close()
Loading