Skip to content

Commit c0d7559

Browse files
committed
add dag
1 parent 91ac2fe commit c0d7559

File tree

7 files changed

+356
-40
lines changed

7 files changed

+356
-40
lines changed

xla/hlo/experimental/auto_reorder/auto_reorder.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,10 @@ AutoReorderPass::ScheduleComputation(HloComputation* computation) {
130130
std::vector<HloInstruction*> new_schedule;
131131
auto sorted_nodes = solver_->GetSortedNodes();
132132
for (auto node : sorted_nodes) {
133-
new_schedule.push_back(
134-
const_cast<xla::HloInstruction*>(node->GetValue()));
133+
auto insts = node->GetValues();
134+
for (auto inst : insts) {
135+
new_schedule.push_back(const_cast<xla::HloInstruction*>(inst));
136+
}
135137
}
136138
return new_schedule;
137139
}

xla/hlo/experimental/auto_reorder/auto_reorder_solver.cc

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@
55
// Prefix for out-of-class member definitions of LinearProgramScheduler:
// `LPSchedulerFunc(T)::Name(...)` expands to the template header, the return
// type T and the qualified class name.
// The guard tests the macro name only; #ifndef accepts a single identifier, so
// a parameter list after it is ill-formed extra tokens.
#ifndef LPSchedulerFunc
#define LPSchedulerFunc(return_type) template <typename ContainerType,typename ElementType> return_type LinearProgramScheduler<ContainerType,ElementType>
#endif
8+
9+
// Prefix for out-of-class member definitions of LPContainerDAG:
// `LPContainerDAGFunc(T)::Name(...)` expands to the template header, the
// return type T and the qualified class name.
// The guard tests the macro name only; #ifndef accepts a single identifier, so
// a parameter list after it is ill-formed extra tokens.
#ifndef LPContainerDAGFunc
#define LPContainerDAGFunc(return_type) template <typename ElementType> return_type LPContainerDAG<ElementType>
#endif
12+
813
namespace xla {
914
using IntVar = operations_research::sat::IntVar;
1015
using CpModelBuilder = operations_research::sat::CpModelBuilder;
@@ -100,7 +105,7 @@ LPSchedulerFunc(tsl::Status)::Solve() {
100105
max_execution_time += cost;
101106
}
102107
}
103-
SetHorizon(max_execution_time * reorder::kChannelNumber);
108+
SetHorizon(reorder::get_horizon(max_execution_time));
104109

105110
for (auto node : nodes_) {
106111
VLOG(3) << "Add to scheduler" << node->GetName();
@@ -141,7 +146,7 @@ LPSchedulerFunc(tsl::Status)::Solve() {
141146
parameters.set_log_to_stdout(true);
142147
parameters.set_log_search_progress(true);
143148
}
144-
parameters.set_num_search_workers(8);
149+
parameters.set_num_search_workers(reorder::get_cpu_number());
145150
const operations_research::sat::CpSolverResponse response =
146151
operations_research::sat::SolveWithParameters(cp_model_.Build(),
147152
parameters);
@@ -294,6 +299,75 @@ LPSchedulerFunc(std::vector<ContainerType*>)::GetSortedNodes() {
294299
[this](ContainerType* a, ContainerType* b) { return a->GetStart() < b->GetStart(); });
295300
return sorted_nodes;
296301
}
302+
303+
// Returns true when `a` is currently a member of operands_.
LPContainerDAGFunc(bool)::IsIn(LPContainer<ElementType>* a) {
  const bool is_operand = operands_.count(a) != 0;
  return is_operand;
}
306+
// Appends `child` to inner_elements and updates operands_: `child` itself is
// removed from operands_, and every dep returned by child->GetDeps() is
// inserted into operands_.
// NOTE(review): a dep that is already one of inner_elements is still inserted
// into operands_ — confirm that callers never add a child whose dep is already
// inside this DAG.
LPContainerDAGFunc(void)::AddToDAG(LPContainer<ElementType>* child) {
  inner_elements.push_back(child);
  // Erasing by key does nothing when `child` was never an operand.
  operands_.erase(child);
  const auto child_deps = child->GetDeps();
  for (const auto& dep_and_cost : child_deps) {
    // Only the dep pointer (tuple slot 0) is recorded; the edge cost in
    // slot 1 is not stored here.
    operands_.insert(std::get<0>(dep_and_cost));
  }
}
317+
// Merges the inner elements of `other` into this DAG.
//
// Step 1 builds an index keyed by dep UUID: for every inner element of this
// DAG and every (dep, cost) it reports through GetDeps(), the pair
// <element, cost> is stored under dep->UUID().
// Step 2 walks the inner elements of `other`. Each such child must be present
// in the index (some inner element of this DAG depends on it); otherwise the
// TF_RET_CHECK fails and its error status is returned. Edges are appended to
// edges_ and the child is then added through AddToDAG.
//
// Returns an OK status when every child of `other` has been merged.
//
// NOTE(review): edges are only created for dependents that are themselves keys
// of the index, and they use the cost stored under the dependent's UUID rather
// than the child's. This mirrors the existing logic — confirm it is intended,
// since the inline comment only says "create edge between element and child".
// NOTE(review): the index is keyed by `int`; confirm UUID() values fit.
LPContainerDAGFunc(Status)::MergeFrom(LPContainerDAG<ElementType>* other) {
  using ElementCost = std::tuple<LPContainer<ElementType>*, CostType>;

  // {dep UUID: [<element, cost>, <element, cost>, ...]} for the inner elements
  // of this DAG, so that outer edges can become inner edges after the merge.
  std::unordered_map<int, std::vector<ElementCost>> dep_operands2element;
  for (LPContainer<ElementType>* element : GetInnerElements()) {
    for (const auto& dep_pair : element->GetDeps()) {
      auto dep = std::get<0>(dep_pair);
      auto cost = std::get<1>(dep_pair);
      // operator[] creates the empty bucket on first use.
      dep_operands2element[dep->UUID()].push_back(
          std::make_tuple(element, cost));
    }
  }

  for (auto child : other->GetInnerElements()) {
    // The child must be a known dep of this DAG. The check uses find() so a
    // missing child is reported instead of silently inserting an empty bucket.
    const auto dependents = dep_operands2element.find(child->UUID());
    TF_RET_CHECK(dependents != dep_operands2element.end())
        << "child is not in dep_operands2element";

    for (const auto& dep_pair : dependents->second) {
      auto dep = std::get<0>(dep_pair);
      const auto nested = dep_operands2element.find(dep->UUID());
      if (nested == dep_operands2element.end()) {
        continue;
      }
      for (const auto& element_pair : nested->second) {
        // create edge between element and child
        DAGEdge edge;
        edge.from = std::get<0>(element_pair);
        edge.to = child;
        edge.cost = std::get<1>(element_pair);
        edges_.push_back(edge);
      }
    }

    AddToDAG(child);
  }

  // A default-constructed Status is the OK status; without this return the
  // function would fall off the end of a non-void function.
  return Status();
}
297366
template class LPContainer<const HloInstruction*>;
298367
template class LinearProgramScheduler<LPContainer<const HloInstruction*>, const HloInstruction*>;
368+
369+
370+
template class LPContainerDAG<const HloInstruction*>;
371+
// template class LinearProgramScheduler<LPContainerDAG<const HloInstruction*>, const HloInstruction*>;
372+
299373
} // namespace xla

xla/hlo/experimental/auto_reorder/auto_reorder_solver.h

Lines changed: 62 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#include <tuple>
66
#include <unordered_map>
77
#include <set>
8-
8+
#include <thread>
99
#include "absl/container/flat_hash_map.h"
1010
#include "absl/strings/string_view.h"
1111
#include "xla/hlo/ir/hlo_module.h"
@@ -17,10 +17,19 @@ namespace xla {
1717
using CpModelBuilder = operations_research::sat::CpModelBuilder;
1818
using IntervalVar = operations_research::sat::IntervalVar;
1919
namespace reorder{
20-
const uint32_t ksolveTimeout = 30; // 30s
20+
const uint32_t ksolveTimeout = 60; // 60s
2121

2222
static const int kChannelNumber = 2;
23-
bool solve_debug=false;
23+
// Scales `max_time` by 1.2 and returns the result truncated towards zero.
// Solve() passes the returned value to SetHorizon().
// Declared inline because this definition lives in a header that is included
// from several translation units.
inline int get_horizon(int max_time) {
  constexpr double kHorizonScale = 1.2;
  // Explicit cast: the product is a double and is intentionally truncated.
  return static_cast<int>(max_time * kHorizonScale);
}
27+
// Debug switch for the solver; presumably gates the extra solver logging in
// Solve() — confirm against its use site.
// Declared inline so that including this header from several translation
// units yields a single definition instead of duplicate symbols.
inline bool solve_debug = true;
28+
//get cpu number of current machine
29+
// Returns the number of hardware threads reported for the current machine,
// and never less than 1. Solve() uses the result as the solver worker count.
// Declared inline because this definition lives in a header that is included
// from several translation units.
inline int get_cpu_number() {
  // hardware_concurrency() is only a hint and returns 0 when the value cannot
  // be determined.
  const unsigned int detected = std::thread::hardware_concurrency();
  if (detected == 0) {
    return 1;
  }
  return static_cast<int>(detected);
}
2433
}
2534
enum class NodeType { kCompute = 0, kCommunication = 1 };
2635

@@ -59,7 +68,7 @@ class LPNode{
5968
template <typename ElementType>
6069
class LPContainer{
6170
public:
62-
71+
//create a LPContainer with inner_element, cost and type
6372
LPContainer(ElementType inner_element, CostType cost, NodeType type)
6473
: inner_element_(inner_element), cost_(cost), type_(type) {
6574
uuid_ = reinterpret_cast<uintptr_t>(this);
@@ -71,15 +80,20 @@ class LPContainer{
7180
CostType GetCost() const { return cost_; }
7281
void SetStart(CostType start) { startat_ = start; }
7382
CostType GetStart() { return startat_; }
83+
// Get the type of the container: compute or communication
7484
bool IsComputation() const { return type_ == NodeType::kCompute; }
7585
bool IsCommunication() const { return type_ == NodeType::kCommunication; }
7686
NodeType GetType() const { return type_; }
77-
bool HasValue() const { return inner_element_ != nullptr; }
78-
ElementType GetValue() const { return inner_element_; }
87+
88+
const bool HasValue() { return inner_element_ != nullptr; }
89+
const std::vector<ElementType> GetValues() { return std::vector<ElementType>{inner_element_}; }
90+
// Add a dep of this container, cost is the cost of the edge; this Container will be executed after dep
7991
void AddDep(LPContainer* dep, CostType cost);
92+
// Get all deps of the container
8093
const std::vector<std::tuple<LPContainer*, CostType>> GetDeps() const {
8194
return deps_;
8295
}
96+
// Once a container is frozen, no more deps can be added to it
8397
void Freeze() { frozen_ = true; }
8498

8599
private:
@@ -97,32 +111,59 @@ class LPContainer{
97111
// LPContainerDAG is a graph of container, it can be used to store the DAG of container
98112
// be used as a atomic unit of LPContainer
99113
template <typename ElementType>
100-
class LPContainerDAG{
114+
class LPContainerDAG: public LPContainer<ElementType>{
101115
//we can use InstructionDAG to get memory effect order
102116
public:
103117
// maintain a DAG of inner elements
104118
struct DAGEdge{
105-
LPContainerDAG* from;
106-
LPContainerDAG* to;
119+
LPContainer<ElementType>* from;
120+
LPContainer<ElementType>* to;
107121
CostType cost;
108122
};
109-
//create a LPContainerDAG with
110-
LPContainerDAG(ElementType inner_element, CostType cost, NodeType type): cost_(cost), type_(type){
111-
inner_elements.push_back(LPContainer<ElementType>(inner_element, cost, type));
112-
};
113-
bool IsIn(LPContainerDAG<ElementType> *a){
114-
return users_.find(a) != users_.end();
123+
//create a LPContainerDAG with one element
124+
LPContainerDAG(ElementType inner_element, CostType cost, NodeType type): LPContainer<ElementType>(inner_element,cost,type){
125+
//TODO: there should not create element?
126+
auto ele = new LPContainer<ElementType>(inner_element, cost, type);
127+
inner_elements.push_back(ele);
115128
};
129+
bool IsIn(LPContainer<ElementType>* a);
116130
//which container can be put together:1. they have the same type 2. they have dep between them
117-
static bool CanFused(LPContainerDAG<ElementType>* a, LPContainerDAG<ElementType>* b){
131+
// static bool CanFused(LPContainerDAG<ElementType>* a, LPContainerDAG<ElementType>* b);
118132

119-
};
120-
// AddChild
133+
//override LPContainer
134+
const std::string GetName(){
135+
std::string name = "LPContainerDAG{";
136+
for(auto ele: inner_elements){
137+
name += ele->GetName();
138+
name+="\n";
139+
}
140+
name+="}";
141+
return name;
142+
}
143+
const int UUID() { return inner_elements[0]->UUID(); }
144+
const bool HasValue() { return inner_elements.size()>0;}
145+
const std::vector<ElementType> GetValues() {
146+
std::vector<ElementType> values;
147+
for(auto ele: inner_elements){
148+
for(auto inst:ele->GetValues()){
149+
values.push_back(inst);
150+
}
151+
}
152+
return values;
153+
}
154+
// AddChild, child should maintain the deps before
155+
void AddToDAG(LPContainer<ElementType>* child);
156+
const std::vector<LPContainer<ElementType>*> GetInnerElements() const{
157+
return inner_elements;
158+
}
159+
//merge other LPContainerDAG to this LPContainerDAG,then destroy other LPContainerDAG
160+
Status MergeFrom(LPContainerDAG<ElementType>* other);
121161
private:
122162

123-
std::set<ElementType> users_;
124-
std::set<ElementType> operands_;
125-
std::vector<LPContainer<ElementType>> inner_elements;
163+
std::set<LPContainer<ElementType>*> operands_;
164+
std::vector<LPContainer<ElementType>*> inner_elements;
165+
//maintain edges between inner_elements
166+
std::vector<DAGEdge> edges_;
126167
CostType cost_;
127168
CostType startat_;
128169
NodeType type_;

xla/hlo/experimental/auto_reorder/auto_reorder_test.cc

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ ENTRY %elementwise {
436436
insts2cost.push_back(std::make_tuple(ar_done, 1));
437437

438438
insts_list.push_back(ar_done);
439-
edge2cost.push_back(std::make_tuple(ar_done, cost_gen()));
439+
edge2cost.push_back(std::make_tuple(ar_done, cost_gen()+50));
440440
not_used_insts.insert(ar_done);
441441
}
442442
}
@@ -861,9 +861,7 @@ TEST_F(AutoReorderingTest, ReorderScheduleComputation) {
861861
std::unique_ptr<LatencyEstimator> latency_estimator;
862862
int pointer_size_ = 4;
863863
Backend& test_backend = backend();
864-
const se::DeviceDescription& gpu_device_info =
865-
test_backend.default_stream_executor()->GetDeviceDescription();
866-
// auto gpu_device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
864+
auto gpu_device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
867865

868866
VLOG(2) << "threads_per_block_limit:"
869867
<< gpu_device_info.threads_per_block_limit() << " threads_per_warp"
@@ -914,8 +912,7 @@ TEST_F(AutoReorderingTest, ReorderPass) {
914912
EXPECT_TRUE(st.ok());
915913
int pointer_size_ = 4;
916914
Backend& test_backend = backend();
917-
const se::DeviceDescription& gpu_device_info =
918-
test_backend.default_stream_executor()->GetDeviceDescription();
915+
auto gpu_device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
919916
const int64_t scheduler_mem_limit = xla::gpu::GetSchedulerMemoryLimit(
920917
hlo_module.get(), gpu_device_info, pointer_size_);
921918
SchedulerConfig config = GetSchedulerConfig(scheduler_mem_limit);
@@ -954,8 +951,7 @@ TEST_F(AutoReorderingTest, ReorderPassWithDefaultEstimator) {
954951
EXPECT_TRUE(st.ok());
955952
int pointer_size_ = 4;
956953
Backend& test_backend = backend();
957-
const se::DeviceDescription& gpu_device_info =
958-
test_backend.default_stream_executor()->GetDeviceDescription();
954+
auto gpu_device_info = TestGpuDeviceInfo::RTXA6000DeviceInfo();
959955
const int64_t scheduler_mem_limit = xla::gpu::GetSchedulerMemoryLimit(
960956
hlo_module.get(), gpu_device_info, pointer_size_);
961957
SchedulerConfig config = GetSchedulerConfig(scheduler_mem_limit);
@@ -972,8 +968,8 @@ TEST_F(AutoReorderingTest, ReorderPassWithDefaultEstimator) {
972968
EXPECT_TRUE(status.ok());
973969
}
974970
TEST_F(AutoReorderingTest, ReorderPassWithRandom) {
971+
// GTEST_SKIP() << "Skipping single test";
975972
std::srand(kRandomSeed);
976-
// communication rate from 0.05 to 0.95,step is 0.05
977973
auto hlo_module = CreateNewUnverifiedModule();
978974
auto gpu_latency_estimator = std::make_unique<SavedInstLatencyEstimator>();
979975
SchedulerConfig sched_config = GetDefaultSchedConfig();
@@ -1027,12 +1023,20 @@ TEST_F(AutoReorderingTest, ReorderPassWithRandom) {
10271023
}
10281024
// skip this test
10291025
TEST_F(AutoReorderingTest, ReorderPassDataAnalyse) {
1030-
// GTEST_SKIP() << "Skipping single test";
1026+
GTEST_SKIP() << "Skipping single test";
10311027
std::srand(kRandomSeed);
10321028
auto gen = std::mt19937{kRandomSeed};
1033-
int repeat_time = 3;
1034-
uint32_t nnodes = 50;
1035-
std::vector<float> communication_rates = {0.1,0.15,0.2,0.25,0.3,0.65,0.7,0.75,0.8,0.85};
1029+
int repeat_time = 1;
1030+
uint32_t nnodes = 100;
1031+
std::vector<float> communication_rates;
1032+
// = {
1033+
// 0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9
1034+
// };
1035+
for (float current=0.1; current < 0.9; current+=0.05)
1036+
{
1037+
communication_rates.push_back(current);
1038+
}
1039+
10361040
// communication rate starts at 0.1 and increases in steps of 0.05 while it is below 0.9
10371041
std::ofstream csv_out("/tmp/test_ret.csv");
10381042
csv_out<<"exp_id,nnodes,communication_rate,auto_reorder_cost,post_order_cost,xla_hiding_order_cost,xla_hiding_solve_time,auto_reorder_solve_time"<<std::endl;
@@ -1050,6 +1054,8 @@ TEST_F(AutoReorderingTest, ReorderPassDataAnalyse) {
10501054
/*communication rate*/ communication_rate,
10511055
/* gen */gen);
10521056
EXPECT_TRUE(st.ok());
1057+
// auto latency_estimator = create_latency_estimator();
1058+
10531059
auto gpu_latency_estimator2 = gpu_latency_estimator->clone();
10541060
auto gpu_latency_estimator3 = gpu_latency_estimator->clone();
10551061
// run AutoReorder for compare

xla/service/gpu/gpu_hlo_schedule.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ limitations under the License.
4444
#include "xla/hlo/ir/hlo_schedule.h"
4545
#include "xla/hlo/utils/hlo_query.h"
4646
#include "xla/service/buffer_value.h"
47+
#include "xla/hlo/experimental/auto_reorder/auto_reorder.h"
4748
#include "xla/service/gpu/backend_configs.pb.h"
4849
#include "xla/service/gpu/cublas_cudnn.h"
4950
#include "xla/service/gpu/gpu_schedule_postprocessing.h"

0 commit comments

Comments
 (0)