diff --git a/phlex/core/declared_provider.cpp b/phlex/core/declared_provider.cpp index da4e98209..6dcff66e3 100644 --- a/phlex/core/declared_provider.cpp +++ b/phlex/core/declared_provider.cpp @@ -1,7 +1,8 @@ #include "phlex/core/declared_provider.hpp" +#include "phlex/model/full_product_spec.hpp" namespace phlex::experimental { - declared_provider::declared_provider(algorithm_name name, product_query output_product) : + declared_provider::declared_provider(algorithm_name name, full_product_spec output_product) : name_{std::move(name)}, output_product_{std::move(output_product)} { } @@ -10,10 +11,10 @@ namespace phlex::experimental { std::string declared_provider::full_name() const { return name_.full(); } - product_query const& declared_provider::output_product() const noexcept + full_product_spec const& declared_provider::output_product() const noexcept { return output_product_; } - identifier const& declared_provider::layer() const noexcept { return output_product_.layer; } + identifier const& declared_provider::layer() const noexcept { return output_product_.layer(); } } diff --git a/phlex/core/declared_provider.hpp b/phlex/core/declared_provider.hpp index 0d1fbd5c7..cb149a312 100644 --- a/phlex/core/declared_provider.hpp +++ b/phlex/core/declared_provider.hpp @@ -8,6 +8,7 @@ #include "phlex/core/message.hpp" #include "phlex/model/algorithm_name.hpp" #include "phlex/model/data_cell_index.hpp" +#include "phlex/model/full_product_spec.hpp" #include "phlex/model/product_specification.hpp" #include "phlex/model/product_store.hpp" #include "phlex/utilities/simple_ptr_map.hpp" @@ -26,11 +27,11 @@ namespace phlex::experimental { class PHLEX_CORE_EXPORT declared_provider { public: - declared_provider(algorithm_name name, product_query output_product); + declared_provider(algorithm_name name, full_product_spec output_product); virtual ~declared_provider(); std::string full_name() const; - product_query const& output_product() const noexcept; + full_product_spec const& output_product() const noexcept; identifier const& layer() const noexcept; virtual tbb::flow::receiver* input_port() = 0; @@ -39,7 +40,7 @@ namespace phlex::experimental { private: algorithm_name name_; - product_query output_product_; + full_product_spec output_product_; }; using declared_provider_ptr = std::unique_ptr; @@ -56,11 +57,9 @@ namespace phlex::experimental { std::size_t concurrency, tbb::flow::graph& g, AlgorithmBits alg, - product_query output) : + full_product_spec output) : declared_provider{std::move(name), output}, - output_{algorithm_name::create(std::string_view(identifier(output.creator))), - output.suffix.value_or(identifier("")), - output.type}, + output_{output.spec()}, provider_{g, concurrency, [this, ft = alg.release_algorithm()](index_message const& index_msg, auto& output) { diff --git a/phlex/core/product_query.cpp b/phlex/core/product_query.cpp index e91422f3e..d0a04aeca 100644 --- a/phlex/core/product_query.cpp +++ b/phlex/core/product_query.cpp @@ -44,6 +44,22 @@ namespace phlex { return true; } + // Check if a full_product_spec satisfies this query + bool product_query::match(experimental::full_product_spec const& spec) const + { + using experimental::identifier; + if (!match(spec.spec())) { + return false; + } + if (identifier(layer) != spec.layer()) { + return false; + } + if (stage && stage != spec.stage()) { + return false; + } + return true; + } + std::string product_query::to_string() const { if (suffix) { diff --git a/phlex/core/product_query.hpp b/phlex/core/product_query.hpp index 1dfa2c8b5..fda7620af 100644 --- a/phlex/core/product_query.hpp +++ b/phlex/core/product_query.hpp @@ -3,6 +3,7 @@ #include "phlex/phlex_core_export.hpp" +#include "phlex/model/full_product_spec.hpp" #include "phlex/model/identifier.hpp" #include "phlex/model/product_specification.hpp" #include "phlex/model/product_store.hpp" @@ -80,10 +81,28 @@ namespace phlex { // Check if a product_specification satisfies this query bool match(experimental::product_specification const& spec) const; + // Check if a full_product_spec satisfies this query + bool match(experimental::full_product_spec const& spec) const; + std::string to_string() const; bool operator==(product_query const& rhs) const; std::strong_ordering operator<=>(product_query const& rhs) const; + + // Transitional automatic conversion operator so I don't have to rewrite all the tests + // The deprecated annotation is commented out because we have -Werror on + // [[deprecated( + // "Generation of a full_product_spec from a product_query is only transitionally supported")]] + operator experimental::full_product_spec() + { + return experimental::full_product_spec{ + experimental::product_specification{experimental::algorithm_name::create(std::string_view( + experimental::identifier(this->creator))), + this->suffix.value(), + this->type}, + experimental::identifier(this->layer), + this->stage.value_or("")}; + } }; inline std::string format_as(product_query const& q) { return q.to_string(); } diff --git a/phlex/core/registration_api.hpp b/phlex/core/registration_api.hpp index e63632333..d05c689c9 100644 --- a/phlex/core/registration_api.hpp +++ b/phlex/core/registration_api.hpp @@ -133,12 +133,12 @@ namespace phlex::experimental { { } - auto output_product(product_query output) + auto output_product(full_product_spec output) { using return_type = return_type; using provider_type = provider_node; - output.type = make_type_id(); + output.set_type(make_type_id()); registrar_.set_creator([this, output = std::move(output)]( auto /* predicates */, auto /* output_product_suffixes */) { diff --git a/phlex/model/full_product_spec.hpp b/phlex/model/full_product_spec.hpp new file mode 100644 index 000000000..658174033 --- /dev/null +++ b/phlex/model/full_product_spec.hpp @@ -0,0 +1,54 @@ +#ifndef PHLEX_MODEL_FULL_PRODUCT_SPEC_HPP +#define PHLEX_MODEL_FULL_PRODUCT_SPEC_HPP + +#include "phlex/phlex_model_export.hpp" + +#include "phlex/model/identifier.hpp" +#include "phlex/model/product_specification.hpp" +#include "phlex/model/type_id.hpp" + +#include "boost/container_hash/hash.hpp" +#include "fmt/format.h" + +namespace phlex::experimental { + class PHLEX_MODEL_EXPORT full_product_spec { + public: + full_product_spec(product_specification spec, identifier layer, identifier stage) : + spec_{std::move(spec)}, layer_{std::move(layer)}, stage_{std::move(stage)} + { + } + + bool operator==(full_product_spec const&) const noexcept = default; + + product_specification const& spec() const noexcept { return spec_; } + algorithm_name const& creator() const noexcept { return spec_.qualifier(); } + identifier const& suffix() const noexcept { return spec_.suffix(); } + type_id type() const noexcept { return spec_.type(); } + void set_type(type_id&& type) { spec_.set_type(std::move(type)); } + identifier const& layer() const noexcept { return layer_; } + identifier const& stage() const noexcept { return stage_; } + std::size_t hash() const noexcept + { + std::size_t result = creator().plugin().hash(); + boost::hash_combine(result, creator().algorithm().hash()); + boost::hash_combine(result, suffix().hash()); + boost::hash_combine(result, layer().hash()); + boost::hash_combine(result, stage().hash()); + boost::hash_combine(result, type()); + return result; + } + + std::string to_string() const + { + return fmt::format( + "{}:{}/{} ϵ {}", creator().plugin(), creator().algorithm(), suffix(), layer()); + } + + private: + product_specification spec_; + identifier layer_; + identifier stage_; + }; +} + +#endif // PHLEX_MODEL_FULL_PRODUCT_SPEC_HPP diff --git a/plugins/python/src/modulewrap.cpp b/plugins/python/src/modulewrap.cpp index a43f726b1..50d8564d8 100644 --- a/plugins/python/src/modulewrap.cpp +++ b/plugins/python/src/modulewrap.cpp @@ -1,6 +1,7 @@ #include "wrap.hpp" #include "phlex/model/data_cell_index.hpp" +#include "phlex/model/full_product_spec.hpp" #include #include @@ -1188,6 +1189,12 @@ static PyObject* sc_provide(py_phlex_source* src, PyObject* args, PyObject* kwds throw std::runtime_error("output specification error: " + msg); } } + auto ops = full_product_spec( + product_specification(algorithm_name::create(std::string_view(identifier(opq->creator))), + opq->suffix.value(), + opq->type), + identifier(opq->layer), + opq->stage.value_or(""_id)); // insert provider node (TODO: as in transform and observe, we'll leak the // callable for now, until there's a proper shutdown procedure) @@ -1197,47 +1204,47 @@ static PyObject* sc_provide(py_phlex_source* src, PyObject* args, PyObject* kwds std::string const& out_type = output_types[0]; if (out_type == "bool") { auto* pyc = new provider_cb_bool{callable}; - src->ph_source->provide(functor_name, *pyc).output_product(opq.value()); + src->ph_source->provide(functor_name, *pyc).output_product(ops); } else if (out_type == "int32_t") { auto* pyc = new provider_cb_int{callable}; - src->ph_source->provide(functor_name, *pyc).output_product(opq.value()); + src->ph_source->provide(functor_name, *pyc).output_product(ops); } else if (out_type == "uint32_t") { auto* pyc = new provider_cb_uint{callable}; - src->ph_source->provide(functor_name, *pyc).output_product(opq.value()); + src->ph_source->provide(functor_name, *pyc).output_product(ops); } else if (out_type == "int64_t") { auto* pyc = new provider_cb_long{callable}; - src->ph_source->provide(functor_name, *pyc).output_product(opq.value()); + src->ph_source->provide(functor_name, *pyc).output_product(ops); } else if (out_type == "uint64_t") { auto* pyc = new provider_cb_ulong{callable}; - src->ph_source->provide(functor_name, *pyc).output_product(opq.value()); + src->ph_source->provide(functor_name, *pyc).output_product(ops); } else if (out_type == "float") { auto* pyc = new provider_cb_float{callable}; - src->ph_source->provide(functor_name, *pyc).output_product(opq.value()); + src->ph_source->provide(functor_name, *pyc).output_product(ops); } else if (out_type == "double") { auto* pyc = new provider_cb_double{callable}; - src->ph_source->provide(functor_name, *pyc).output_product(opq.value()); + src->ph_source->provide(functor_name, *pyc).output_product(ops); } else if (out_type.compare(0, 7, "ndarray") == 0 || out_type.compare(0, 4, "list") == 0) { // TODO: just like for input types, these are hard-coded, but should be handled by // an IDL instead. std::string_view dtype{out_type.begin() + out_type.rfind('['), out_type.end()}; if (dtype == "[int32_t]") { auto* pyc = new provider_cb_vint{callable}; - src->ph_source->provide(functor_name, *pyc).output_product(opq.value()); + src->ph_source->provide(functor_name, *pyc).output_product(ops); } else if (dtype == "[uint32_t]") { auto* pyc = new provider_cb_vuint{callable}; - src->ph_source->provide(functor_name, *pyc).output_product(opq.value()); + src->ph_source->provide(functor_name, *pyc).output_product(ops); } else if (dtype == "[int64_t]") { auto* pyc = new provider_cb_vlong{callable}; - src->ph_source->provide(functor_name, *pyc).output_product(opq.value()); + src->ph_source->provide(functor_name, *pyc).output_product(ops); } else if (dtype == "[uint64_t]") { auto* pyc = new provider_cb_vulong{callable}; - src->ph_source->provide(functor_name, *pyc).output_product(opq.value()); + src->ph_source->provide(functor_name, *pyc).output_product(ops); } else if (dtype == "[float]") { auto* pyc = new provider_cb_vfloat{callable}; - src->ph_source->provide(functor_name, *pyc).output_product(opq.value()); + src->ph_source->provide(functor_name, *pyc).output_product(ops); } else if (dtype == "[double]") { auto* pyc = new provider_cb_vdouble{callable}; - src->ph_source->provide(functor_name, *pyc).output_product(opq.value()); + src->ph_source->provide(functor_name, *pyc).output_product(ops); } else { PyErr_Format(PyExc_TypeError, "unsupported collection output type \"%s\"", out_type.c_str()); return nullptr;