Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[submodule "triton"]
path = triton
url = https://github.com/triton-lang/triton.git
path = triton
url = https://github.com/triton-lang/triton.git
[submodule "triton_shared"]
path = triton_shared
url = https://github.com/microsoft/triton-shared
path = triton_shared
url = https://github.com/facebookincubator/triton-shared.git
90 changes: 90 additions & 0 deletions ci/apply_patches.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#!/bin/bash
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause.
# For more license information:
# https://github.com/qualcomm/hexagon-mlir/LICENSE.txt
#
# Strict mode: -E lets ERR traps fire inside functions and subshells, -e aborts
# on any unhandled failure, -u errors on unset variables, -x traces every
# command into the CI log, and pipefail fails a pipeline when any stage fails.
set -Eeuox pipefail

# If the CI image accidentally has /etc/gitconfig as a *directory*, Git will fail
# when reading the system config. We detect that here and warn that this script
# will ignore the system gitconfig by using GIT_CONFIG_NOSYSTEM=1 on all git calls.
if [ -d /etc/gitconfig ]; then
echo "WARNING: Detected /etc/gitconfig is a DIRECTORY; system gitconfig will be ignored via GIT_CONFIG_NOSYSTEM=1 for all git operations in this script."
fi

# Derive the repo root from this script's own location so the script works
# regardless of the caller's current working directory.
SCRIPT_DIR="$(readlink -f "$(dirname "$0")")"
HEXAGON_MLIR_ROOT="$(readlink -f "$SCRIPT_DIR/../")"
TRITON_ROOT="$HEXAGON_MLIR_ROOT/triton"

# Apply a patch if it isn't already applied (stateless; no marker file).
apply_patch_if_needed() {
  # Apply a patch only when the working tree does not already contain it.
  #   $1 - repository directory to patch
  #   $2 - path to the patch file
  # Missing patch files produce a warning and are skipped. A patch that
  # neither applies cleanly nor reverse-applies (i.e. the tree is in a
  # conflicting state) aborts the whole script with exit 1.
  local target_repo="$1"
  local patch_path="$2"
  local rc=0

  if [ ! -f "$patch_path" ]; then
    echo "WARNING: Patch file not found at $patch_path"
    return 0
  fi

  echo "Checking/applying patch: $patch_path in $target_repo"
  pushd "$target_repo" >/dev/null

  if GIT_CONFIG_NOSYSTEM=1 git apply --reverse --check "$patch_path" >/dev/null 2>&1; then
    # Reverse-apply succeeding means the patch content is already present.
    echo "Patch already applied (reverse-check passed): $patch_path — skipping."
  elif GIT_CONFIG_NOSYSTEM=1 git apply "$patch_path"; then
    # git apply itself is the source of truth; avoids TOCTOU races.
    echo "Patch applied successfully: $patch_path"
  else
    # Neither forward nor reverse apply works: inconsistent state / conflicts.
    rc=1
    echo "ERROR: Patch neither applies nor is already applied: $patch_path"
    echo "----- git apply --check (verbose) output -----"
    GIT_CONFIG_NOSYSTEM=1 git apply --check -v "$patch_path" || true
    echo "----------------------------------------------"
  fi

  popd >/dev/null
  [ "$rc" -eq 0 ] || exit 1
  return 0
}

# -----------------------------------------------------------------------------
# Apply patches (drop the marker basename arg; keep the order if patches depend on one another)
# -----------------------------------------------------------------------------
# Each apply_patch_if_needed call below is idempotent: an already-applied patch
# is detected via reverse-check and skipped, while a conflicting patch aborts
# the whole script with exit 1.

# -----------------------------------------------------------------------------
# triton_shared patches
# -----------------------------------------------------------------------------
# Triton shared patch to update the API for compatibility with the latest LLVM
TRITON_SHARED_API_UPDATE_PATCH_FILE="$HEXAGON_MLIR_ROOT/third_party_software/patches/triton_shared/triton_shared_3_6_triton.patch"
apply_patch_if_needed "$HEXAGON_MLIR_ROOT/triton_shared" "$TRITON_SHARED_API_UPDATE_PATCH_FILE"

# Triton shared patch on Pointer Analysis
TRITON_SHARED_POINTER_ANALYSIS_PATCH_FILE="$HEXAGON_MLIR_ROOT/third_party_software/patches/triton_shared/triton_shared_ptr_analysis.patch"
apply_patch_if_needed "$HEXAGON_MLIR_ROOT/triton_shared" "$TRITON_SHARED_POINTER_ANALYSIS_PATCH_FILE"

# Triton shared patch on split pointers
TRITON_SHARED_SPLIT_DIM_PATCH_FILE="$HEXAGON_MLIR_ROOT/third_party_software/patches/triton_shared/triton_shared_split_dim.patch"
apply_patch_if_needed "$HEXAGON_MLIR_ROOT/triton_shared" "$TRITON_SHARED_SPLIT_DIM_PATCH_FILE"

# Triton shared patch to handle canonicalization pattern of Max with NaN propagation
TRITON_SHARED_MAX_NAN_PATCH_FILE="$HEXAGON_MLIR_ROOT/third_party_software/patches/triton_shared/triton_shared_max_nan.patch"
apply_patch_if_needed "$HEXAGON_MLIR_ROOT/triton_shared" "$TRITON_SHARED_MAX_NAN_PATCH_FILE"

# -----------------------------------------------------------------------------
# Triton patches (build third-party backends + NVVM ReductionKind compatibility)
# -----------------------------------------------------------------------------
# Triton patch to get around NVVM ReductionKind compatibility issue
TRITON_NVVM_COMPATIBILITY_PATCH_FILE="$HEXAGON_MLIR_ROOT/third_party_software/patches/triton/nvvm_reduction_kind_compatibility.patch"
apply_patch_if_needed "$TRITON_ROOT" "$TRITON_NVVM_COMPATIBILITY_PATCH_FILE"

# Add libdevice sigmoid support to triton
TRITON_LIBDEVICE_SIGMOID_PATCH_FILE="$HEXAGON_MLIR_ROOT/third_party_software/patches/triton/libdevice_sigmoid.patch"
apply_patch_if_needed "$TRITON_ROOT" "$TRITON_LIBDEVICE_SIGMOID_PATCH_FILE"
2 changes: 1 addition & 1 deletion ci/hexagon-mlir-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ pybind11
scipy
lit
wheel
transformers==4.52.4
transformers==4.52.4
58 changes: 39 additions & 19 deletions ci/setup_submodules.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,49 @@
# SPDX-License-Identifier: BSD-3-Clause.
# For more license information:
# https://github.com/qualcomm/hexagon-mlir/LICENSE.txt
#


set -euo pipefail

REPO_ROOT="$(git rev-parse --show-toplevel)"
echo "Configuring git for submodules"
echo "Configuring git submodules"
cd "${REPO_ROOT}"
# Add submodules if missing
if [ ! -d "triton" ]; then
git submodule add --force https://github.com/triton-lang/triton.git triton
cd triton
echo "Applying qcom specific patches to triton"
git checkout e44bd1c83c1c3e8deac7c4f02683cfb3cc395c8b
git apply "${REPO_ROOT}/third_party_software/patches/triton/third_party_triton.patch"
fi

cd "${REPO_ROOT}"
if [ ! -d "triton_shared" ]; then
git submodule add --force https://github.com/microsoft/triton-shared triton_shared
cd triton_shared
git checkout 2b728ad97bc02af821a0805b09075838911d4c19
echo "Applying qcom specific patches to triton_shared"
git apply "${REPO_ROOT}/third_party_software/patches/triton_shared/max_with_nan_propagation.patch"
git apply "${REPO_ROOT}/third_party_software/patches/triton_shared/tt_shared_split_dim.patch"
fi
# Ensure existing submodules are initialized
git submodule update --init

add_and_checkout() {
  # Ensure a git submodule exists and is pinned to an exact commit.
  #   $1 - submodule directory name, relative to the repo root
  #   $2 - remote URL, used only when the submodule must first be added
  #   $3 - commit SHA to check out
  local module_name="$1"
  local module_url="$2"
  local pinned_commit="$3"

  cd "${REPO_ROOT}"
  if [ ! -d "${REPO_ROOT}/${module_name}" ]; then
    echo "Adding submodule ${module_name}"
    git submodule add --force "${module_url}" "${module_name}"
  fi

  echo "Checking out ${module_name} at ${pinned_commit}"
  cd "${REPO_ROOT}/${module_name}"
  # Fetch first so the pinned commit is reachable even after upstream moves.
  git fetch origin
  git checkout "${pinned_commit}"
}

# Pin both submodules to the exact upstream revisions the qcom patches were
# written against.
add_and_checkout \
  triton \
  https://github.com/triton-lang/triton.git \
  df38505e451a1541555379bcf378be9e8c00545c

add_and_checkout \
  triton_shared \
  https://github.com/facebookincubator/triton-shared.git \
  0614763d270ec0eacba9d5d8283cdff6bedb03c8

cd "${REPO_ROOT}"
# apply_patches.sh patches BOTH checkouts (triton and triton_shared), so the
# log message names both; the old message mentioned only triton_shared.
echo "Applying qcom specific patches to triton and triton_shared"
bash "${REPO_ROOT}/ci/apply_patches.sh" || {
  echo "ERROR: Failed while applying patches"
  exit 1
}

echo "Submodules triton and triton_shared initialized and patched successfully."
72 changes: 58 additions & 14 deletions qcom_hexagon_backend/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import os
import re
import tempfile
from dataclasses import replace
from pathlib import Path
import subprocess
from typing import Any, Dict, no_type_check
Expand All @@ -30,21 +31,36 @@
# mypy compiler.py hexagon_executor.py hexagon_launcher_base.py torch_mlir_hexagon_launcher.py triton_hexagon_launcher.py --follow-untyped-imports --check-untyped-defs


def _get_triton_shared_opt_path() -> str:
path = os.getenv("TRITON_SHARED_OPT_PATH", "")
def _get_triton_shared_opt_path(device_type: str) -> str:
path = os.getenv(
"TRITON_SHARED_OPT_PATH",
"",
)
if path == "":
raise Exception("TRITON_SHARED_OPT_PATH is not set.")
return path

bin_path = Path(path).resolve()

if not bin_path.exists() or not bin_path.is_file():
raise FileNotFoundError(
f"Could not find 'triton-shared-opt' at expected location: {bin_path}"
)
if not os.access(bin_path, os.X_OK):
raise PermissionError(
f"'triton-shared-opt' exists but is not executable: {bin_path}"
)

return str(bin_path)

def ttir_to_ttsharedir(mod):

def ttir_to_ttsharedir(mod: str, options):
# Get Triton-MLIR as string
ttir_code = str(mod)
with tempfile.TemporaryDirectory() as tmpdir:
src_path = os.path.join(tmpdir, "tt.mlir")
dst_path = os.path.join(tmpdir, "ttshared.mlir")
Path(src_path).write_text(ttir_code)
triton_shared_opt_path = _get_triton_shared_opt_path()
triton_shared_opt_path = _get_triton_shared_opt_path(options.device_type)
subprocess.check_call(
[
triton_shared_opt_path,
Expand Down Expand Up @@ -90,6 +106,10 @@ def ttsharedir_to_obj(mod: str, options, metadata={}) -> bytes:
options_map = {k: str(v) for k, v in (options.__dict__).items()}
# TODO: Move setting benchmarking iterations when additional stage for shared object creation is part of compilation pipeline.
metadata["iterations"] = options_map["iterations"]
metadata["scratch"] = options_map["scratch"]
metadata["enableMultiThreading"] = options_map["enableMultiThreading"]
metadata["enableThreadedDispatch"] = options_map["enableThreadedDispatch"]
metadata["enableLWP"] = options_map["enableLWP"]

# TODO: The lowering pipeline needs to be refactored similar to other Triton backends to
# have a dynamic pipeline filtered by options with each pass represented by a pybind function.
Expand Down Expand Up @@ -138,13 +158,31 @@ def hash(self):
return f"{version}-{self.target}"

def parse_options(self, opts) -> Any:
    """Build a HexagonOptions instance from a raw options mapping.

    Keys in ``opts`` that are not declared fields of HexagonOptions are
    ignored. When an external VTCM scratch pool is requested
    (``scratch > 0``), related flags are forced at compile time so the
    compiled IR and the generated wrapper stay consistent:
      - enableMultiThreading off,
      - enableConvertToHexagonmem off, preventing VTCMPool from
        concurrently allocating VTCM alongside the external pool,
      - enableVTCMTiling on,
      - enableThreadedDispatch on, so instances run in parallel on real
        qurt hardware threads (handled at wrapper generation time).
    Users do not need to set these flags manually when scratch > 0.

    NOTE(review): an earlier comment also mentioned disabling
    enableHexagonmemCopyToDMA, but the code never sets that field —
    confirm whether it should be included here.
    """
    assert self.target.backend == "hexagon"
    known_fields = HexagonOptions.__dataclass_fields__.keys()
    selected = {name: opts[name] for name in known_fields if name in opts}
    hexagon_opts = HexagonOptions(**selected)

    if hexagon_opts.scratch > 0:
        hexagon_opts = replace(
            hexagon_opts,
            enableMultiThreading=False,
            enableConvertToHexagonmem=False,
            enableVTCMTiling=True,
            enableThreadedDispatch=True,
        )

    return hexagon_opts

@staticmethod
def make_ttir(mod, metadata, opt):
Expand All @@ -161,13 +199,13 @@ def make_ttir(mod, metadata, opt):
passes.common.add_symbol_dce(pm)
passes.ttir.add_loop_unroll(pm)
passes.common.add_cse(pm)
pm.run(mod)
pm.run(mod, "make_ttir")
return mod

# May need to add num_warps
def add_stages(self, stages, options, language):
stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
stages["ttsharedir"] = lambda src, metadata: ttir_to_ttsharedir(src)
stages["ttsharedir"] = lambda src, metadata: ttir_to_ttsharedir(src, options)
if options.htp_kernel_gen:
if options.target_artifact == "llir":
stages["llir"] = lambda src, metadata: ttsharedir_to_llir(
Expand All @@ -190,6 +228,8 @@ def add_stages(self, stages, options, language):
src, options, metadata
)
else: # Default compilation pipeline
assert options.device_type == "hexagon"

stages["o"] = lambda src, metadata: ttsharedir_to_obj(
src, options, metadata
)
Expand Down Expand Up @@ -240,6 +280,10 @@ def pack_metadata(self, metadata):
metadata.name,
metadata.return_types,
metadata.iterations,
metadata.scratch,
metadata.enableMultiThreading,
metadata.enableThreadedDispatch,
metadata.enableLWP,
)

def get_module_map(self) -> Dict[str, ModuleType]:
Expand Down
36 changes: 35 additions & 1 deletion qcom_hexagon_backend/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ def __call__(self, *args, **kwargs):
"_mlir_ciface_" if len(return_profs) > 0 else ""
) + pack_metadata[6]
iterations = pack_metadata[8]
compiled_scratch = pack_metadata[9] if len(pack_metadata) > 9 else None
compiled_enable_multithreading = (
pack_metadata[10] if len(pack_metadata) > 10 else None
)
compiled_enable_threaded_dispatch = (
pack_metadata[11] if len(pack_metadata) > 11 else None
)
compiled_enable_lwp = pack_metadata[12] if len(pack_metadata) > 12 else None
num_fixed_args = 9
inputs_with_constants = list(args[num_fixed_args:])
inputs = [
Expand All @@ -65,7 +73,17 @@ def __call__(self, *args, **kwargs):
"""
)
self.launcher._exec_kernel(
kernel_llir, iterations, func_name, inputs, return_profs, launch_grid
kernel_llir,
iterations,
func_name,
inputs,
return_profs,
launch_grid,
compiled_scratch=compiled_scratch,
compiled_enable_multithreading=compiled_enable_multithreading,
compiled_enable_threaded_dispatch=compiled_enable_threaded_dispatch,
compiled_enable_lwp=compiled_enable_lwp,
runtime_options=kwargs,
)
# TODO: There seems to be no way to propogate the call returns upward, because
# - The call result is not used by the caller
Expand Down Expand Up @@ -132,3 +150,19 @@ def get_active_torch_device(self):

def get_current_stream(self, device):
    """Return the current stream for ``device``.

    This driver does not use a stream abstraction, so the result is
    unconditionally None regardless of the device argument.
    """
    return None

def get_device_interface(self):
    """Return the torch device-interface module backing this driver (CPU)."""
    # Imported lazily so merely constructing the driver does not require torch.
    from torch import cpu as cpu_interface

    return cpu_interface

def get_empty_cache_for_benchmark(self):
    """Allocate a 256MB CPU int32 tensor used as a cache-flush buffer for
    benchmarking."""
    # Imported lazily, matching the rest of this driver.
    import torch

    cache_bytes = 256 * 1024 * 1024  # 256MB cache
    # int32 elements are 4 bytes each; divide to keep the total byte size.
    return torch.empty(cache_bytes // 4, dtype=torch.int, device="cpu")

def clear_cache(self, cache):
    """Reset the benchmark cache tensor to all zeros, in place."""
    cache.fill_(0)
Loading
Loading