Skip to content

[WIP] hack to investigate parallelization settings#24147

Open
xadupre wants to merge 1 commit intomicrosoft:mainfrom
xadupre:treedoc
Open

[WIP] hack to investigate parallelization settings#24147
xadupre wants to merge 1 commit intomicrosoft:mainfrom
xadupre:treedoc

Conversation

@xadupre
Copy link
Copy Markdown
Member

@xadupre xadupre commented Mar 24, 2025

Description

The parallelization settings for TreeEnsemble are fixed (https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h#L38). This a hack to modify them before creating the session.

Motivation and Context

The fixed settings should be updated for different processor and sometimes for different sizes of datasets.

The following function takes a model with a TreeEnsemble (ai.onnx.ml==3) and updates the parallelization settings.

def transform_model(
    model: onnx.ModelProto,
    parallel_tree: int,
    parallel_tree_N: int,
    parallel_N: int,
    first=-556,
):
    """
    Modifies the graph.
    Attributes is unused ``nodes_hitrates_as_tensor`` by the runtime so we use
    that field to specify parallelization settings.

    :param model: model proto serialized, the function makes a copy
    :param parallel_tree: see https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h#L38
    :param parallel_tree_N: see https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h#L39
    :param parallel_N: see https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h#L40
    :param first: -556 or -555, -556 makes onnxruntime prints out the parallelization
        settings to make sure these are the expected values
    :return: ModelProto
    """
    onx = onnx.ModelProto()
    onx.ParseFromString(model.SerializeToString())
    new_nodes = []
    for node in onx.graph.node:
        if node.op_type.startswith("TreeEnsemble"):
            new_atts = []
            for att in node.attribute:
                if att.name.startswith("nodes_hitrates"):
                    continue
                new_atts.append(att)
            new_atts.append(
                onnx.helper.make_attribute(
                    "nodes_hitrates_as_tensor",
                    onnx.numpy_helper.from_array(
                        numpy.array(
                            [first, parallel_tree, parallel_tree_N, parallel_N],
                            dtype=numpy.float32,
                        ),
                        name="nodes_hitrates_as_tensor",
                    ),
                )
            )
            del node.attribute[:]
            node.attribute.extend(new_atts)
            new_nodes.append(node)
            continue
        new_nodes.append(node)
    del onx.graph.node[:]
    onx.graph.node.extend(new_nodes)
    del onx.opset_import[:]
    onx.opset_import.extend(
        [onnx.helper.make_opsetid("", 18), onnx.helper.make_opsetid("ai.onnx.ml", 3)]
    )
    return onx

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant