Skip to content

add PluggableDeviceLibrary #402

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ doctest = false

# Prevent downloading or building TensorFlow when building docs on docs.rs.
[package.metadata.docs.rs]
features = ["private-docs-rs", "tensorflow_unstable", "ndarray", "eager"]
features = ["private-docs-rs", "tensorflow_unstable", "ndarray", "eager", "experimental"]

[dependencies]
libc = "0.2.132"
Expand All @@ -40,6 +40,7 @@ serial_test = "0.9.0"

[features]
default = ["tensorflow-sys"]
experimental = ["tensorflow-sys/experimental"]
tensorflow_gpu = ["tensorflow-sys/tensorflow_gpu"]
tensorflow_unstable = []
tensorflow_runtime_linking = ["tensorflow-sys-runtime"]
Expand Down
5 changes: 5 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ use tensorflow_sys as tf;
#[cfg(feature = "tensorflow_runtime_linking")]
use tensorflow_sys_runtime as tf;

#[cfg(feature = "experimental")]
mod pluggable_device;
#[cfg(feature = "experimental")]
pub use pluggable_device::*;

////////////////////////

/// Will panic if `msg` contains an embedded 0 byte.
Expand Down
53 changes: 53 additions & 0 deletions src/pluggable_device.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
use crate::{Result, Status};
use std::ffi::CString;
use tensorflow_sys as tf;

/// PluggableDeviceLibrary handler.
#[derive(Debug)]
pub struct PluggableDeviceLibrary {
inner: *mut tf::TF_Library,
}

impl PluggableDeviceLibrary {
/// Load the library specified by library_filename and register the pluggable
/// device and related kernels present in that library. This function is not
/// supported on embedded on mobile and embedded platforms and will fail if
/// called.
///
/// Pass "library_filename" to a platform-specific mechanism for dynamically
/// loading a library. The rules for determining the exact location of the
/// library are platform-specific and are not documented here.
pub fn load(library_filename: &str) -> Result<PluggableDeviceLibrary> {
let status = Status::new();
let library_filename = CString::new(library_filename)?;
let lib_handle =
unsafe { tf::TF_LoadPluggableDeviceLibrary(library_filename.as_ptr(), status.inner) };
status.into_result()?;

Ok(PluggableDeviceLibrary { inner: lib_handle })
}
}

impl Drop for PluggableDeviceLibrary {
/// Frees the memory associated with the library handle.
/// Does NOT unload the library.
fn drop(&mut self) {
unsafe {
tf::TF_DeletePluggableDeviceLibraryHandle(self.inner);
}
}
}

#[cfg(test)]
mod tests {
use super::*;

#[ignore]
#[test]
fn load_pluggable_device_library() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might as well leave the test out, since there's no reasonable way to run it (as far as I know). We have similar issues with TF_LoadLibrary.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, right. I tried some, but none of them succeeded. There is too little information on this API.

Is there anything else I can do with this pull request?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In #387 they solved it (at least for macOS) by installing tensorflow-metal and loading that. If you want to leave this test out, I can merge this first and ask them to rebase their changes on top of this, so we'd at least have a test that runs in the CI, even if not on everyone's local machine.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adjusting the PRs. If that order is acceptable, please do so.

A while ago, I was trying to see if the C-API would work with a plugin for Windows, and latest plugin version (currently build from source only) seemed to work with TF 2.12. I'm thinking of trying to see if I can support that as well (apart from this PR) when the version of tensorflow here is updated.

https://github.com/microsoft/tensorflow-directml-plugin

let library_filename = "path-to-library";
let pluggable_divice_library = PluggableDeviceLibrary::load(library_filename);
dbg!(&pluggable_divice_library);
assert!((pluggable_divice_library.is_ok()));
}
}
1 change: 1 addition & 0 deletions tensorflow-sys/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ zip = "0.6.4"
[features]
tensorflow_gpu = []
eager = []
experimental = []
# This is for testing purposes; users should not use this.
examples_system_alloc = []
private-docs-rs = [] # DO NOT RELY ON THIS
16 changes: 12 additions & 4 deletions tensorflow-sys/generate_bindgen_rs.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,22 @@ if ! which bindgen > /dev/null; then
exit 1
fi

include_dir="$HOME/git/tensorflow"
include_dir="../../tensorflow"

# Export C-API
bindgen_options_c_api="--allowlist-function TF_.+ --allowlist-type TF_.+ --allowlist-var TF_.+ --size_t-is-usize --default-enum-style=rust --generate-inline-functions"
cmd="bindgen ${bindgen_options_c_api} ${include_dir}/tensorflow/c/c_api.h --output src/c_api.rs -- -I ${include_dir}"
cmd="bindgen ${bindgen_options_c_api} ${include_dir}/tensorflow/c/c_api.h --output src/c_api.rs -- -I ${include_dir}"
echo ${cmd}
${cmd}

bindgen_options_eager="--allowlist-function TFE_.+ --allowlist-type TFE_.+ --allowlist-var TFE_.+ --blocklist-type TF_.+ --size_t-is-usize --default-enum-style=rust --generate-inline-functions"
cmd="bindgen ${bindgen_options_eager} ${include_dir}/tensorflow/c/eager/c_api.h --output src/eager/c_api.rs -- -I ${include_dir}"
# Export PluggableDeviceLibrary from C-API experimental
bindgen_options_c_api_experimental="--allowlist-function TF_.+PluggableDeviceLibrary.* --blocklist-type TF_.+ --size_t-is-usize"
cmd="bindgen ${bindgen_options_c_api_experimental} ${include_dir}/tensorflow/c/c_api_experimental.h --output src/c_api_experimental.rs -- -I ${include_dir}"
echo ${cmd}
${cmd}

# Export Eager C-API
bindgen_options_eager="--allowlist-function TFE_.+ --allowlist-type TFE_.+ --allowlist-var TFE_.+ --blocklist-type TF_.+ --size_t-is-usize --default-enum-style=rust --generate-inline-functions --no-layout-tests"
cmd="bindgen ${bindgen_options_eager} ${include_dir}/tensorflow/c/eager/c_api.h --output src/eager/c_api.rs -- -I ${include_dir}"
echo ${cmd}
${cmd}
11 changes: 11 additions & 0 deletions tensorflow-sys/src/c_api_experimental.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
/* automatically generated by rust-bindgen 0.59.1 */

extern "C" {
pub fn TF_LoadPluggableDeviceLibrary(
library_filename: *const ::std::os::raw::c_char,
status: *mut TF_Status,
) -> *mut TF_Library;
}
extern "C" {
pub fn TF_DeletePluggableDeviceLibraryHandle(lib_handle: *mut TF_Library);
}
2 changes: 2 additions & 0 deletions tensorflow-sys/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ mod eager;
#[cfg(feature = "eager")]
pub use eager::*;
include!("c_api.rs");
#[cfg(feature = "experimental")]
include!("c_api_experimental.rs");

pub use crate::TF_AttrType::*;
pub use crate::TF_Code::*;
Expand Down
4 changes: 2 additions & 2 deletions test-all
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ run cargo run --example regression
run cargo run --example xor
run cargo run --features tensorflow_unstable --example expressions
run cargo run --features eager --example mobilenetv3
run cargo doc -vv --features tensorflow_unstable,ndarray,eager
run cargo doc -vv --features tensorflow_unstable,ndarray,eager,private-docs-rs
run cargo doc -vv --features experimental,tensorflow_unstable,ndarray,eager
run cargo doc -vv --features experimental,tensorflow_unstable,ndarray,eager,private-docs-rs
# TODO(#66): Re-enable: (cd tensorflow-sys && cargo test -vv -j 1)
(cd tensorflow-sys && run cargo run --example multiplication)
(cd tensorflow-sys && run cargo run --example tf_version)
Expand Down