Skip to content

[SYCL][Graph] Breadth-first schedule #182

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: sycl-graph-develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions sycl/source/detail/graph_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,6 @@ bool check_for_requirement(sycl::detail::AccessorImplHost *Req,
}
} // anonymous namespace

/// Builds the execution schedule from the roots of the associated graph.
///
/// Does nothing when MSchedule already holds entries; otherwise every root in
/// MGraphImpl->MRoots is asked to topologically sort itself into MSchedule.
void exec_graph_impl::schedule() {
  // A schedule that already has entries is kept untouched.
  if (!MSchedule.empty())
    return;

  for (auto Root : MGraphImpl->MRoots)
    Root->topology_sort(Root, MSchedule);
}

std::shared_ptr<node_impl> graph_impl::add_subgraph_nodes(
const std::list<std::shared_ptr<node_impl>> &NodeList) {
// Find all input and output nodes from the node list
Expand All @@ -104,6 +96,15 @@ std::shared_ptr<node_impl> graph_impl::add_subgraph_nodes(
return this->add(Outputs);
}

std::list<std::shared_ptr<node_impl>> graph_impl::compute_schedule() {
exec_order_recompute();
std::list<std::shared_ptr<node_impl>> Sched;
for (auto &Next : MExecOrder) {
Sched.push_back(Next.second);
}
return Sched;
};

/// Registers a node in the collection of root nodes (MRoots).
/// @param Root Node to record as a root.
void graph_impl::add_root(const std::shared_ptr<node_impl> &Root) {
  MRoots.emplace(Root);
}
Expand Down Expand Up @@ -571,7 +572,6 @@ command_graph<graph_state::executable>::command_graph(

void command_graph<graph_state::executable>::finalize_impl() {
// Create PI command-buffers for each device in the finalized context
impl->schedule();

auto Context = impl->get_context();
for (auto Device : impl->get_context().get_devices()) {
Expand Down
83 changes: 60 additions & 23 deletions sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
#include <functional>
#include <list>
#include <set>
#include <optional>
#include <map>

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
Expand Down Expand Up @@ -75,22 +77,29 @@ class node_impl {
std::unique_ptr<sycl::detail::CG> &&CommandGroup)
: MCGType(CGType), MCommandGroup(std::move(CommandGroup)) {}

/// Recursively add nodes to execution stack.
/// @param NodeImpl Node to schedule.
/// @param Schedule Execution ordering to add node to.
void topology_sort(std::shared_ptr<node_impl> NodeImpl,
std::list<std::shared_ptr<node_impl>> &Schedule) {
for (auto Next : MSuccessors) {
// Check if we've already scheduled this node
if (std::find(Schedule.begin(), Schedule.end(), Next) == Schedule.end())
Next->topology_sort(Next, Schedule);
}
// We don't need to schedule empty nodes as they are only used when
// calculating dependencies
if (!NodeImpl->is_empty())
Schedule.push_front(NodeImpl);
}
private:
/// Depth of this node in a containing graph.
///
/// The first call to graph.exec_order_recompute computes & caches the value.
/// It will likely become stale whenever the containing graph is changed and
/// a single value will be inadequate if this node is added to multiple graphs.
/// Caching is dangerous but recomputing takes O(graph_size) worst-case time.
std::optional<int> MDepth;

public:
/// Gets the depth of this node in its containing graph.
///
/// The depth is one more than the greatest depth found among this node's
/// predecessors; a node without (live) predecessors has depth 0. The value is
/// computed on the first call, stored in MDepth, and returned unchanged by
/// every later call.
/// @return the depth of this node in its containing graph.
int get_depth() {
  if (!MDepth.has_value()) {
    int MaxDepthFound = -1;
    for (auto &P : MPredecessors) {
      // lock() returns an empty pointer once the predecessor has been
      // destroyed; such a predecessor is ignored instead of dereferenced.
      if (auto Pred = P.lock())
        MaxDepthFound = std::max(MaxDepthFound, Pred->get_depth());
    }
    MDepth = MaxDepthFound + 1;
  }
  return *MDepth;
}

/// Checks if this node has a given requirement.
/// @param Requirement Requirement to lookup.
/// @return True if \p Requirement is present in node, false otherwise.
Expand Down Expand Up @@ -180,7 +189,7 @@ class node_impl {
}
};

/// Class representing implementation details of command_graph<modifiable>.
/// Class representing implementation details of modifiable command_graph.
class graph_impl {
public:
/// Constructor.
Expand All @@ -190,6 +199,39 @@ class graph_impl {
: MContext(SyclContext), MDevice(SyclDevice), MRecordingQueues(),
MEventsMap() {}

private:
/// A sorted multimap capturing a breadth-first execution/submission order.
///
/// The SortKey is the depth in the graph for the node_impl in the value.
/// Depth is the length of the longest dependence chain to any root node.
std::multimap<int, std::shared_ptr<node_impl>> MExecOrder;

/// Depth-first recursion from V to build the execution order.
/// @param V Starting node for depth-first recursion.
void exec_order_recompute(node_impl &V) {
// depth-first recursion to access all nodes that succeed this node
for (auto &S : V.MSuccessors) {
exec_order_recompute(*S.get());
}
// insert this into execution order based on its depth in the graph
MExecOrder.insert(std::pair(V.get_depth(), &V));
};

/// Recomputes the submission/execution order for this whole graph.
void exec_order_recompute() {
MExecOrder.clear();
// for all root nodes ...
for (auto &Root : MRoots) {
// ... recurse towards all exit nodes
exec_order_recompute(*Root);
}
};

public:
/// Recomputes the submission/execution order then schedules all nodes.
/// @return A list of shared pointers to nodes in linear scheduling order.
std::list<std::shared_ptr<node_impl>> compute_schedule();

/// Insert node into list of root nodes.
/// @param Root Node to add to list of root nodes.
void add_root(const std::shared_ptr<node_impl> &Root);
Expand Down Expand Up @@ -313,17 +355,15 @@ class exec_graph_impl {
/// @param GraphImpl Modifiable graph implementation to create with.
exec_graph_impl(sycl::context Context,
const std::shared_ptr<graph_impl> &GraphImpl)
: MSchedule(), MGraphImpl(GraphImpl), MPiCommandBuffers(),
: MSchedule(GraphImpl->compute_schedule()),
MPiCommandBuffers(),
MPiSyncPoints(), MContext(Context) {}

/// Destructor.
///
/// Releases any PI command-buffers the object has created.
~exec_graph_impl();

/// Add nodes to MSchedule.
void schedule();

/// Called by handler::ext_oneapi_command_graph() to schedule graph for
/// execution.
/// @param Queue Command-queue to schedule execution on.
Expand Down Expand Up @@ -378,9 +418,6 @@ class exec_graph_impl {

/// Execution schedule of nodes in the graph.
std::list<std::shared_ptr<node_impl>> MSchedule;
/// Pointer to the modifiable graph impl associated with this executable
/// graph.
std::shared_ptr<graph_impl> MGraphImpl;
/// Map of devices to command buffers.
std::unordered_map<sycl::device, RT::PiExtCommandBuffer> MPiCommandBuffers;
/// Map of nodes in the exec graph to the sync point representing their
Expand Down