Skip to content

Commit 4b86ff2

Browse files
Merge pull request jax-ml#25097 from jburnim:jburnim_pallas_interpret_mode
PiperOrigin-RevId: 724073443
2 parents 840192d + 1c82484 commit 4b86ff2

File tree

7 files changed

+2040
-6
lines changed

7 files changed

+2040
-6
lines changed

Diff for: jax/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,7 @@ pytype_strict_library(
652652
"//jax/_src/pallas",
653653
"//jax/_src/pallas/mosaic:core",
654654
"//jax/_src/pallas/mosaic:helpers",
655+
"//jax/_src/pallas/mosaic:interpret",
655656
"//jax/_src/pallas/mosaic:lowering",
656657
"//jax/_src/pallas/mosaic:pallas_call_registration", # build_cleaner: keep
657658
"//jax/_src/pallas/mosaic:pipeline",

Diff for: jax/_src/pallas/mosaic/BUILD

+12
Original file line numberDiff line numberDiff line change
@@ -148,3 +148,15 @@ py_library(
148148
"//jax/_src/pallas",
149149
],
150150
)
151+
152+
py_library(
153+
name = "interpret",
154+
srcs = ["interpret.py"],
155+
deps = [
156+
":core",
157+
":primitives",
158+
"//jax",
159+
"//jax/_src/lib",
160+
"//jax/_src/pallas",
161+
] + py_deps("numpy"),
162+
)

0 commit comments

Comments
 (0)