55#include < tuple>
66#include < unordered_map>
77#include < set>
8-
8+ # include < thread >
99#include " absl/container/flat_hash_map.h"
1010#include " absl/strings/string_view.h"
1111#include " xla/hlo/ir/hlo_module.h"
@@ -17,10 +17,19 @@ namespace xla {
1717 using CpModelBuilder = operations_research::sat::CpModelBuilder;
1818 using IntervalVar = operations_research::sat::IntervalVar;
1919 namespace reorder {
20- const uint32_t ksolveTimeout = 30 ; // 30s
20+ const uint32_t ksolveTimeout = 60 ; // 30s
2121
2222 static const int kChannelNumber = 2 ;
23- bool solve_debug=false ;
23+ int get_horizon (int max_time){
24+ // scale
25+ return max_time*1.2 ;
26+ }
27+ bool solve_debug=true ;
28+ // get cpu number of current machine
29+ int get_cpu_number (){
30+ // return 8;
31+ return std::thread::hardware_concurrency ();
32+ }
2433 }
2534 enum class NodeType { kCompute = 0 , kCommunication = 1 };
2635
@@ -59,7 +68,7 @@ class LPNode{
5968template <typename ElementType>
6069class LPContainer {
6170 public:
62-
71+ // create a LPContainer with inner_element, cost and type
6372 LPContainer (ElementType inner_element, CostType cost, NodeType type)
6473 : inner_element_(inner_element), cost_(cost), type_(type) {
6574 uuid_ = reinterpret_cast <uintptr_t >(this );
@@ -71,15 +80,20 @@ class LPContainer{
7180 CostType GetCost () const { return cost_; }
7281 void SetStart (CostType start) { startat_ = start; }
7382 CostType GetStart () { return startat_; }
83+ // Get the type of the container: compute or communication
7484 bool IsComputation () const { return type_ == NodeType::kCompute ; }
7585 bool IsCommunication () const { return type_ == NodeType::kCommunication ; }
7686 NodeType GetType () const { return type_; }
77- bool HasValue () const { return inner_element_ != nullptr ; }
78- ElementType GetValue () const { return inner_element_; }
87+
88+ const bool HasValue () { return inner_element_ != nullptr ; }
89+ const std::vector<ElementType> GetValues () { return std::vector<ElementType>{inner_element_}; }
90+ // Add a dep of this container, cost is the cost of the edge; this Container will be executed after dep
7991 void AddDep (LPContainer* dep, CostType cost);
92+ // Get all deps of the container
8093 const std::vector<std::tuple<LPContainer*, CostType>> GetDeps () const {
8194 return deps_;
8295 }
96+ // when a container is frozen, it can not be add deps
8397 void Freeze () { frozen_ = true ; }
8498
8599 private:
@@ -97,32 +111,59 @@ class LPContainer{
97111// LPContainerDAG is a graph of container, it can be used to store the DAG of container
98112// be used as a atomic unit of LPContainer
99113template <typename ElementType>
100- class LPContainerDAG {
114+ class LPContainerDAG : public LPContainer <ElementType> {
101115 // we can use InstructionDAG to get memory effect order
102116 public:
103117 // maintain a DAG of inner elements
104118 struct DAGEdge {
105- LPContainerDAG * from;
106- LPContainerDAG * to;
119+ LPContainer<ElementType> * from;
120+ LPContainer<ElementType> * to;
107121 CostType cost;
108122 };
109- // create a LPContainerDAG with
110- LPContainerDAG (ElementType inner_element, CostType cost, NodeType type): cost_(cost), type_(type){
111- inner_elements.push_back (LPContainer<ElementType>(inner_element, cost, type));
112- };
113- bool IsIn (LPContainerDAG<ElementType> *a){
114- return users_.find (a) != users_.end ();
123+ // create a LPContainerDAG with one element
124+ LPContainerDAG (ElementType inner_element, CostType cost, NodeType type): LPContainer<ElementType>(inner_element,cost,type){
125+ // TODO: there should not create element?
126+ auto ele = new LPContainer<ElementType>(inner_element, cost, type);
127+ inner_elements.push_back (ele);
115128 };
129+ bool IsIn (LPContainer<ElementType>* a);
116130 // which container can be put together:1. they have the same type 2. they have dep between them
117- static bool CanFused (LPContainerDAG<ElementType>* a, LPContainerDAG<ElementType>* b){
131+ // static bool CanFused(LPContainerDAG<ElementType>* a, LPContainerDAG<ElementType>* b);
118132
119- };
120- // AddChild
133+ // override LPContainer
134+ const std::string GetName (){
135+ std::string name = " LPContainerDAG{" ;
136+ for (auto ele: inner_elements){
137+ name += ele->GetName ();
138+ name+=" \n " ;
139+ }
140+ name+=" }" ;
141+ return name;
142+ }
143+ const int UUID () { return inner_elements[0 ]->UUID (); }
144+ const bool HasValue () { return inner_elements.size ()>0 ;}
145+ const std::vector<ElementType> GetValues () {
146+ std::vector<ElementType> values;
147+ for (auto ele: inner_elements){
148+ for (auto inst:ele->GetValues ()){
149+ values.push_back (inst);
150+ }
151+ }
152+ return values;
153+ }
154+ // AddChild, child should maintain the deps before
155+ void AddToDAG (LPContainer<ElementType>* child);
156+ const std::vector<LPContainer<ElementType>*> GetInnerElements () const {
157+ return inner_elements;
158+ }
159+ // merge other LPContainerDAG to this LPContainerDAG,then destroy other LPContainerDAG
160+ Status MergeFrom (LPContainerDAG<ElementType>* other);
121161 private:
122162
123- std::set<ElementType> users_;
124- std::set<ElementType> operands_;
125- std::vector<LPContainer<ElementType>> inner_elements;
163+ std::set<LPContainer<ElementType>*> operands_;
164+ std::vector<LPContainer<ElementType>*> inner_elements;
165+ // maintain edges between inner_elements
166+ std::vector<DAGEdge> edges_;
126167 CostType cost_;
127168 CostType startat_;
128169 NodeType type_;
0 commit comments