Add a script to run a single matmul configuration with custom MatmulParams #3918

Merged
rdspring1 merged 7 commits into main from single_problem_matmul on Feb 21, 2025

Conversation

rdspring1
Collaborator

This PR adds a script to run a single matmul configuration with custom MatmulParams. It profiles the nvFuser kernel and compares its runtime against the nvjet kernel runtime. The nvjet kernel runtimes are stored in a JSON file generated by the Python benchmarks; a rough sketch of this comparison appears below.

  • Update Python bindings to support constructing matmul options.
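
As a rough sketch of the comparison described above, assuming a hypothetical JSON layout where each problem key maps to an nvjet runtime in milliseconds (the real nvjet_pybench.json format is produced by the Python benchmarks and may differ):

    import json

    def relative_performance(nvjet_json_path, problem_key, nvfuser_time_ms):
        # Hypothetical JSON layout: {"m_n_k_layout": runtime_in_ms, ...}.
        # The actual file written by the Python benchmarks may differ.
        with open(nvjet_json_path) as f:
            nvjet_times = json.load(f)
        nvjet_time_ms = nvjet_times[problem_key]
        # A ratio above 1.0 means nvFuser is slower than nvjet for this problem.
        return nvfuser_time_ms / nvjet_time_ms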

@rdspring1
Collaborator Author

!test


github-actions bot commented Feb 18, 2025

Review updated until commit f269852

Description

  • Add script for profiling single matmul configuration

  • Update Python bindings for MatmulParams

  • Implement custom scheduler for matmul

  • Validate nvFuser against PyTorch matmul (a small validation sketch follows below)
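
A minimal sketch of the kind of check the last bullet describes, assuming a and b are the input matrices and nvf_output is the nvFuser result; the tolerances and any transpose handling for non-NN layouts are assumptions, not the script's actual code.

    import torch

    def validate(nvf_output, a, b):
        # Illustrative check against eager PyTorch matmul; profile_matmul.py
        # may use different tolerances or handle transposed layouts here.
        eager_output = torch.matmul(a, b)
        return torch.allclose(nvf_output, eager_output, rtol=1e-2, atol=1e-2)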


Changes walkthrough 📝

Relevant files

Enhancement

python_bindings.cpp: Update Python bindings for MatmulParams
csrc/python_frontend/python_bindings.cpp

  • Add constructors for GemmTile, MatMulTileOptions, CircularBufferOptions, SupportedVectorization, ClusterDims, and MmaMacroEncode
  • Update class definitions to use py::class_ instead of DEFINECLASS macro

  +23/-4

profile_matmul.py: Add script for profiling a single matmul configuration
doc/dev/python_scheduling/profile_matmul.py

  • Add script to run and profile a single matmul configuration
  • Implement functions for estimating matmul size, getting kernel time, and defining the matmul fusion
  • Implement a custom scheduler with custom parameters
  • Add a main function to parse arguments and run profiling

  +209/-0
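
As a rough illustration of how the new constructors might be used from the Python frontend, here is a hypothetical snippet; the import path, the MatmulParams attribute names (for example tile_sizes), and the tile values are assumptions based on the classes listed above rather than an excerpt from the PR.

    # Hypothetical usage of the newly bound matmul option classes.
    # Import path and attribute names are assumptions; see the PR's
    # doc/dev/python_scheduling/profile_matmul.py for the real API.
    from nvfuser import GemmTile, MatMulTileOptions, MatmulParams  # assumed path

    params = MatmulParams()
    # Illustrative CTA and warp tile sizes (m, n, k).
    params.tile_sizes = MatMulTileOptions(
        GemmTile(128, 128, 64),  # cta_tile
        GemmTile(64, 64, 64),    # warp_tile
    )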

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Initialization Order

    The order of initialization for MatmulParams classes should be consistent with the class definitions to avoid potential issues.

        .def(py::init<int64_t, int64_t, int64_t>())
        .PARAM(GemmTile, m)
        .PARAM(GemmTile, n)
        .PARAM(GemmTile, k)
        .TOSTRINGTOPLEVEL(GemmTile);
    
    DEFINECLASS(MatMulTileOptions)
        .def(py::init<GemmTile, GemmTile>())
        .PARAM(MatMulTileOptions, cta_tile)
        .PARAM(MatMulTileOptions, warp_tile)
        .TOSTRINGTOPLEVEL(MatMulTileOptions);
    
    py::class_<MatmulParams::CircularBufferOptions>(
    Missing Definitions

    The DEFINECLASS macro is used instead of py::class_ for some classes, which might lead to missing bindings or incorrect behavior.

        .def(py::init<int64_t, int64_t, int64_t>())
        .PARAM(GemmTile, m)
        .PARAM(GemmTile, n)
        .PARAM(GemmTile, k)
        .TOSTRINGTOPLEVEL(GemmTile);
    
    DEFINECLASS(MatMulTileOptions)
        .def(py::init<GemmTile, GemmTile>())
        .PARAM(MatMulTileOptions, cta_tile)
        .PARAM(MatMulTileOptions, warp_tile)
        .TOSTRINGTOPLEVEL(MatMulTileOptions);
    
    py::class_<MatmulParams::CircularBufferOptions>(
    Error Handling

    The error handling in test_matmul_nvf could be improved to provide more informative messages or handle specific exceptions.

    try:
        nvf_outputs = scheduled_fd.execute([a, b], profile=True)
    except Exception as e:
        if verbose:
            print(e)
        return -1
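
    One possible refinement along the lines of this suggestion is to separate expected failures from unexpected ones; the exception types below are illustrative assumptions, not nvFuser's actual exception hierarchy, and scheduled_fd, a, b, and verbose are taken from the quoted snippet.

        # Sketch of more targeted error handling (exception types are assumed).
        try:
            nvf_outputs = scheduled_fd.execute([a, b], profile=True)
        except RuntimeError as e:
            # Likely a scheduling or compilation failure for this configuration.
            if verbose:
                print(f"nvFuser execution failed: {e}")
            return -1
        except Exception as e:
            # Unexpected errors should be surfaced rather than silently skipped.
            print(f"Unexpected error while profiling: {e}")
            raise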

    Collaborator

    @jacobhinkle jacobhinkle left a comment


    Looks good! Thanks for adding this. Comments are pretty minor I think.

    Comment on lines 82 to 83
    for shape in [[m, k], [n, k], [m, n]]:
        total_in_gbs += _estimate_size(shape, dtype)

    This is the bare minimum size required to compute the GEMM. If validation is enabled, there will be more than that because of additional outputs: one for the eager output plus the intermediates required for torch.allclose. So you might want to add a fudge factor.
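
    A minimal sketch of how such a fudge factor could be folded into the estimate; _estimate_size is reimplemented here for illustration and may not match the helper in profile_matmul.py.

        import math

        def _estimate_size(shape, dtype):
            # Illustrative: tensor footprint in gigabytes for a torch dtype.
            return math.prod(shape) * dtype.itemsize / (1024**3)

        def estimate_problem_gbs(m, n, k, dtype, validate=False):
            # A, B, and the nvFuser output are the bare minimum for the GEMM.
            total_in_gbs = 0.0
            for shape in [[m, k], [n, k], [m, n]]:
                total_in_gbs += _estimate_size(shape, dtype)
            if validate:
                # The eager output plus torch.allclose intermediates need extra
                # memory; pad with roughly two more output-sized buffers.
                total_in_gbs += 2 * _estimate_size([m, n], dtype)
            return total_in_gbs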

    @rdspring1 rdspring1 force-pushed the single_problem_matmul branch from 4654545 to 6f56b47 on February 20, 2025 23:29
    @rdspring1
    Collaborator Author

    @jacobhinkle I added the following argparse interface.

    usage: profile_matmul.py [-h] [--verbose] [--validate] m n k {NN,NT,TN,TT}
    
    Run through a combination of matmul parameters and compare relative performance against nvjet for a single problem.
    
    positional arguments:
      m              The size of M dimension
      n              The size of N dimension
      k              The size of K dimension
      {NN,NT,TN,TT}  The layout for matmul problem.
    
    options:
      -h, --help     show this help message and exit
      --verbose      Print matmul parameters and exceptions.
      --validate     Validate nvfuser against pytorch matmul.
    
    How to run the script: NVFUSER_ENABLE=fuse_matmul NVFUSER_DISABLE=matmul_expr_eval python single_matmul.py nvjet_pybench.json 1752 4720 584 NN --verbose --validate
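
    For reference, a minimal argparse sketch that reproduces the usage text above; the actual wiring in profile_matmul.py may differ (for example, the run command above also passes an nvjet JSON file that the usage text does not list), so treat this as an illustration.

        import argparse

        def parse_args():
            parser = argparse.ArgumentParser(
                description=(
                    "Run through a combination of matmul parameters and compare "
                    "relative performance against nvjet for a single problem."
                )
            )
            parser.add_argument("m", type=int, help="The size of M dimension")
            parser.add_argument("n", type=int, help="The size of N dimension")
            parser.add_argument("k", type=int, help="The size of K dimension")
            parser.add_argument(
                "layout",
                choices=["NN", "NT", "TN", "TT"],
                help="The layout for matmul problem.",
            )
            parser.add_argument(
                "--verbose",
                action="store_true",
                help="Print matmul parameters and exceptions.",
            )
            parser.add_argument(
                "--validate",
                action="store_true",
                help="Validate nvfuser against pytorch matmul.",
            )
            return parser.parse_args()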

    @rdspring1
    Collaborator Author

    !build

    @rdspring1 rdspring1 merged commit 2b5ea2a into main Feb 21, 2025
    16 checks passed
    @rdspring1 rdspring1 deleted the single_problem_matmul branch February 21, 2025 03:55