Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions loopy/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,7 @@ def map_constant(self, expr: object) -> bool:

def map_variable(self, expr: p.Variable) -> bool:
if expr.name == self.vec_iname:
# Technically, this is doable. But we're not going there.
raise UnvectorizableError()

return True
# A single variable is always a scalar.
return False

Expand Down
15 changes: 15 additions & 0 deletions loopy/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,21 @@
s: str


@p.expr_dataclass()
class TypedLiteral(Literal):
"""A literal to be used during code generation which we know the type of.

.. note::

Only used in the output of
:mod:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
similar mappers). Not for use in Loopy source representation.
"""

s: str
dtype: ToLoopyTypeConvertible


@p.expr_dataclass()
class ArrayLiteral(LoopyExpressionBase):
"""An array literal.
Expand Down Expand Up @@ -1740,7 +1755,7 @@
# pstate.expect(_colon):
pstate.advance()
subscript = self.parse_expression(pstate, _PREC_UNARY)
return SubArrayRef(swept_inames, subscript)

Check warning on line 1758 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest with Intel CL

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1758 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest with Intel CL

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1758 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest with Intel CL

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1758 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest with Intel CL

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1758 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest with Intel CL

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1758 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest with Intel CL

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1758 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest (ubuntu-latest)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1758 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest (ubuntu-latest)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1758 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest (ubuntu-latest)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1758 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest (ubuntu-latest)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1758 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest (ubuntu-latest)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1758 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest (ubuntu-latest)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1758 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest (macos-latest)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1758 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest (macos-latest)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1758 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest (macos-latest)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1758 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest (macos-latest)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1758 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest (macos-latest)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1758 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest (macos-latest)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1758 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest without arg check

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1758 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest without arg check

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1758 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest without arg check

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1758 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest without arg check

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1758 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest without arg check

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1758 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest without arg check

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1758 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest Twice (for cache behavior)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1758 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest Twice (for cache behavior)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1758 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest Twice (for cache behavior)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1758 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest Twice (for cache behavior)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1758 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest Twice (for cache behavior)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.

Check warning on line 1758 in loopy/symbolic.py

View workflow job for this annotation

GitHub Actions / Conda Pytest Twice (for cache behavior)

swept_inames passed to SubArrayRef was not a tuple. This is deprecated and will stop working in 2025. Pass a tuple instead.
else:
pstate = rollback_pstate
return super().parse_prefix(rollback_pstate)
Expand Down
17 changes: 9 additions & 8 deletions loopy/target/c/codegen/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from loopy.expression import dtype_to_type_context
from loopy.target.c import CExpression
from loopy.type_inference import TypeInferenceMapper, TypeReader
from loopy.types import LoopyType
from loopy.types import LoopyType, to_loopy_type
from loopy.typing import Expression, is_integer


Expand Down Expand Up @@ -435,7 +435,7 @@ def map_type_cast(self, expr: TypeCast, type_context: str):
return self.rec(expr.child, type_context, expr.type)

def map_constant(self, expr, type_context):
from loopy.symbolic import Literal
from loopy.symbolic import TypedLiteral

if isinstance(expr, (complex, np.complexfloating)):
real = self.rec(expr.real, type_context)
Expand All @@ -462,10 +462,10 @@ def map_constant(self, expr, type_context):

# FIXME: This assumes a 32-bit architecture.
if isinstance(expr, np.float32):
return Literal(repr(float(expr))+"f")
return TypedLiteral(repr(float(expr))+"f", to_loopy_type(np.float32))

elif isinstance(expr, np.float64):
return Literal(repr(float(expr)))
return TypedLiteral(repr(float(expr)), to_loopy_type(np.float64))

# Disabled for now, possibly should be a subtarget.
# elif isinstance(expr, np.float128):
Expand All @@ -478,18 +478,19 @@ def map_constant(self, expr, type_context):
suffix += "u"
if iinfo.max > (2**31-1):
suffix += "l"
return Literal(repr(int(expr))+suffix)
return TypedLiteral(repr(int(expr))+suffix, to_loopy_type(iinfo.dtype))
elif isinstance(expr, np.bool_):
return Literal("true") if expr else Literal("false")
return TypedLiteral("true", to_loopy_type(np.bool_)) if expr \
else TypedLiteral("false", to_loopy_type(np.bool_))
else:
raise LoopyError("do not know how to generate code for "
"constant of numpy type '%s'" % type(expr).__name__)

elif np.isfinite(expr):
if type_context == "f":
return Literal(repr(float(expr))+"f")
return TypedLiteral(repr(float(expr))+"f", to_loopy_type(np.float32))
elif type_context == "d":
return Literal(repr(float(expr)))
return TypedLiteral(repr(float(expr)), to_loopy_type(np.float64))
elif type_context in ["i", "b"]:
return int(expr)
else:
Expand Down
7 changes: 4 additions & 3 deletions loopy/target/ispc.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,14 @@
CoefficientCollector,
CombineMapper,
GroupHardwareAxisIndex,
Literal,
LocalHardwareAxisIndex,
SubstitutionMapper,
TypedLiteral,
flatten,
)
from loopy.target.c import CFamilyASTBuilder, CFamilyTarget
from loopy.target.c.codegen.expression import ExpressionToCExpressionMapper
from loopy.types import to_loopy_type


if TYPE_CHECKING:
Expand Down Expand Up @@ -125,10 +126,10 @@ def map_constant(self, expr, type_context):
raise NotImplementedError("complex numbers in ispc")
else:
if type_context == "f":
return Literal(repr(float(expr)))
return TypedLiteral(repr(float(expr)), to_loopy_type(np.float32))
elif type_context == "d":
# Keepin' the good ideas flowin' since '66.
return Literal(repr(float(expr))+"d")
return TypedLiteral(repr(float(expr))+"d", to_loopy_type(np.float64))
elif type_context in ["i", "b"]:
return expr
else:
Expand Down
94 changes: 93 additions & 1 deletion loopy/target/opencl.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
THE SOFTWARE.
"""

from contextlib import suppress
from typing import TYPE_CHECKING, Literal, Sequence

import numpy as np
Expand All @@ -46,6 +47,7 @@

from loopy.codegen import CodeGenerationState
from loopy.codegen.result import CodeGenerationResult
from loopy.kernel import LoopKernel


# {{{ dtype registry wrappers
Expand Down Expand Up @@ -456,7 +458,8 @@ def get_opencl_callables():

# {{{ symbol mangler

def opencl_symbol_mangler(kernel, name):
def opencl_symbol_mangler(kernel: LoopKernel,
name: str) -> tuple[NumpyType, str] | None:
# FIXME: should be more picky about exact names
if name.startswith("FLT_"):
return NumpyType(np.dtype(np.float32)), name
Expand Down Expand Up @@ -540,11 +543,32 @@ def opencl_preamble_generator(preamble_info):
class ExpressionToOpenCLCExpressionMapper(ExpressionToCExpressionMapper):

def wrap_in_typecast(self, actual_type, needed_dtype, s):

if needed_dtype.dtype.kind == "b" and actual_type.dtype.kind == "f":
# CL does not perform implicit conversion from float-type to a bool.
from pymbolic.primitives import Comparison
return Comparison(s, "!=", 0)

if needed_dtype == actual_type:
return s

registry = self.codegen_state.ast_builder.target.get_dtype_registry()
if self.codegen_state.target.is_vector_dtype(needed_dtype):
# OpenCL does not let you do explicit vector type casts between vector
# types. Instead you need to call their function which is of the form
# <desttype> convert_<desttype><n>(src) where n
# is the number of elements in the vector which is the same as in src.
# https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_C.html#explicit-casts

# We infer the data type of (s) before we recurse down into (s) to convert
# to a CExpression. With vectorization, we can change the actual type of (s)
# from a scalar type to a vector type. So we are going to recompute the
# actual type.
type_of_s = self.infer_type(s)
if self.codegen_state.target.is_vector_dtype(type_of_s):
cast = var("convert_%s" % registry.dtype_to_ctype(needed_dtype))
return cast(s)

return super().wrap_in_typecast(actual_type, needed_dtype, s)

def map_group_hw_index(self, expr, type_context):
Expand All @@ -553,6 +577,74 @@ def map_group_hw_index(self, expr, type_context):
def map_local_hw_index(self, expr, type_context):
return var("lid")(expr.axis)

def map_variable(self, expr, type_context):

if self.codegen_state.vectorization_info:
if self.codegen_state.vectorization_info.iname == expr.name:
# This needs to be converted into a vector literal.
from loopy.symbolic import TypedLiteral
vector_length = self.codegen_state.vectorization_info.length
index_type = self.codegen_state.kernel.index_dtype
vector_type = self.codegen_state.target.vector_dtype(index_type,
vector_length)
typename = self.codegen_state.target.dtype_to_typename(vector_type)
vector_literal = f"(({typename})" + " (" + \
",".join([f"{i}" for i in range(vector_length)]) + "))"
return TypedLiteral(vector_literal, vector_type)

# return Literal(vector_literal)
return super().map_variable(expr, type_context)

def map_if(self, expr, type_context):
from loopy.types import to_loopy_type
result_type = self.infer_type(expr)
conditional_needed_loopy_type = to_loopy_type(np.bool_)
if self.codegen_state.vectorization_info:
from loopy.codegen import UnvectorizableError
from loopy.expression import VectorizabilityChecker
checker = VectorizabilityChecker(self.codegen_state.kernel,
self.codegen_state.vectorization_info.iname,
self.codegen_state.vectorization_info.length)

with suppress(UnvectorizableError):
# We know there is an expression in codegen which can be vectorized.
# We are checking if this is one of the them. If it is not, then we can
# just continue with scalar code generation for this expression.
is_vector = checker(expr)

if is_vector:
"""
We could have a vector literal here which may need to be
converted to an appropriate size. The OpenCL specification states
that for ( c ? a : b) a, b, and c must have the same
number of elements and bits and that c must be an integral type.
https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_C.html#table-builtin-relational
"""
index_type = to_loopy_type(self.codegen_state.kernel.index_dtype)
types = {8: to_loopy_type(np.int64), 4: to_loopy_type(np.int32),
2: to_loopy_type(np.int16), 1: to_loopy_type(np.int8)}
length = self.codegen_state.vectorization_info.length
if self.codegen_state.target.is_vector_dtype(result_type):
if (index_type.itemsize != result_type.itemsize and
(result_type.itemsize // length) in types):
index_type = types[result_type.itemsize]
else:
raise ValueError("Types incompatible")
else:
# We know result is going to be a vector.
if (index_type.itemsize != result_type.itemsize and
result_type.itemsize in types):
index_type = types[result_type.itemsize]
vector_type = self.codegen_state.target.vector_dtype(index_type,
length)
conditional_needed_loopy_type = to_loopy_type(vector_type)

return type(expr)(
self.rec(expr.condition, type_context,
conditional_needed_loopy_type),
self.rec(expr.then, type_context, result_type),
self.rec(expr.else_, type_context, result_type),
)
# }}}


Expand Down
33 changes: 29 additions & 4 deletions loopy/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
SubArrayRef,
SubstitutionRuleExpander,
SubstitutionRuleMappingContext,
TypedLiteral,
parse_tagged_name,
)
from loopy.translation_unit import (
Expand Down Expand Up @@ -365,6 +366,9 @@ def map_quotient(self, expr):
else:
return self.combine([n_dtype_set, d_dtype_set])

def map_typed_literal(self, expr: TypedLiteral):
return [expr.dtype]

def map_constant(self, expr):
if isinstance(expr, np.generic):
return [NumpyType(np.dtype(type(expr)))]
Expand Down Expand Up @@ -540,19 +544,40 @@ def map_lookup(self, expr):
dtype = field[0]
return [NumpyType(dtype)]

def is_vector_dtype(self, dtype):
target = self.kernel.target

return target.is_vector_dtype(dtype)

def map_comparison(self, expr):
self(expr.left, return_tuple=False, return_dtype_set=False)
self(expr.right, return_tuple=False, return_dtype_set=False)
left = self(expr.left, return_tuple=False, return_dtype_set=False)
right = self(expr.right, return_tuple=False, return_dtype_set=False)
# We need to return a vector type if we either of the sides is a vector.

vector_output = []
for dtype in (left, right):
if self.is_vector_dtype(dtype):
vector_output.append(dtype)
if vector_output:
return vector_output
return [NumpyType(np.dtype(np.bool_))]

def map_logical_not(self, expr):
self.rec(expr.child)
child = self.rec(expr.child)
if self.is_vector_dtype(child):
return child

return [NumpyType(np.dtype(np.bool_))]

def map_logical_and(self, expr):
output_type = []
for child in expr.children:
self.rec(child)
type_to_check = self.rec(child)
if self.is_vector_dtype(type_to_check):
output_type.append(type_to_check)

if output_type:
return output_type

return [NumpyType(np.dtype(np.bool_))]

Expand Down
29 changes: 29 additions & 0 deletions test/test_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,35 @@ def test_float3():
assert "float3" in device_code


def test_cl_vectorize_index_variable(ctx_factory):
knl = lp.make_kernel(
"{ [i]: 0<=i<n }",
"""
b[i] = a[i]*3 if i < 32 else sin(a[i])
""")

knl = lp.split_array_axis(knl, "a,b", 0, 4)
knl = lp.split_iname(knl, "i", 4)
knl = lp.tag_inames(knl, {"i_inner": "vec", "i_outer": "for"})
knl = lp.tag_array_axes(knl, "a,b", "c,vec")
knl = lp.set_options(knl, write_code=True)
knl = lp.assume(knl, "n % 4 = 0 and n>0")

rng = np.random.default_rng(seed=12)
a = rng.normal(size=(16, 4))
ctx = ctx_factory()
queue = cl.CommandQueue(ctx)
knl = lp.add_and_infer_dtypes(knl, {"a": np.float64, "n": np.int64})
_evt, (result,) = knl(queue, a=a, n=a.size)

i = np.arange(16)
j = np.arange(4)
ind = 4*i[:, None] + j
result_ref = np.where(ind < 32, a*3, np.sin(a))

assert np.allclose(result, result_ref)


if __name__ == "__main__":
import sys
if len(sys.argv) > 1:
Expand Down
Loading