DriverTrac/venv/lib/python3.12/site-packages/jax/experimental/mosaic/gpu/equations.py

1007 lines
33 KiB
Python

# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines expressions and equations over layouts."""
# mypy has been causing more problems than it solves here. Disable it for these
# files. We have pytype checks anyway.
# mypy: ignore-errors
from __future__ import annotations
import abc
from collections.abc import Sequence
import dataclasses
import math
from typing import Any, Callable, assert_never, final
from . import fragmented_array as fa
from . import launch_context as lc
from . import layouts as layouts_lib
from . import inference_utils
from . import tcgen05
VariableKey = Any
@dataclasses.dataclass(frozen=True)
class Variable:
"""A variable is an abstract identifier.
`key` is supposed to be hashable.
"""
key: VariableKey
def __str__(self):
return f"V({self.key})"
class Constant(abc.ABC):
"""A constant is a known layout."""
@dataclasses.dataclass(frozen=True)
class RegisterLayout(Constant):
"""Wraps a known register layout."""
value: fa.FragmentedLayout
def __str__(self):
return f"C({self.value})"
@dataclasses.dataclass(frozen=True)
class TMEMLayout(Constant):
"""Wraps a known TMEM layout."""
value: tcgen05.TMEMLayout
def __str__(self):
return f"C({self.value})"
@dataclasses.dataclass(frozen=True)
class SMEMTiling(Constant):
"""Wraps a known SMEM Tile Transform.
If an SMEM reference may, in principle, have transforms but should not be
tiled, then `value` is `None`.
"""
value: lc.TileTransform | None
def __str__(self):
return f"C({self.value})"
@dataclasses.dataclass(frozen=True)
class LeastReplicated:
expressions: tuple[Expression, ...]
def __post_init__(self):
assert len(self.expressions) >= 1
@dataclasses.dataclass(frozen=True)
class MostReplicated:
expressions: tuple[Expression, ...]
def __post_init__(self):
assert len(self.expressions) >= 1
@dataclasses.dataclass(frozen=True)
class Reduce:
expression: Expression
axes: tuple[int, ...]
def __str__(self):
return f"Reduce([{self.axes}], {self.expression})"
@dataclasses.dataclass(frozen=True)
class BroadcastInDim:
expression: Expression
axes: tuple[int, ...]
shape: tuple[int, ...]
@dataclasses.dataclass(frozen=True)
class Reshape:
expression: Expression
source_shape: tuple[int, ...]
target_shape: tuple[int, ...]
@dataclasses.dataclass(frozen=True)
class Transpose:
expression: Expression
def __str__(self):
return f"T({self.expression})"
Expression = (
Variable
| Constant
| LeastReplicated
| MostReplicated
| Reduce
| BroadcastInDim
| Reshape
| Transpose
)
def reduce_replicated_expression(
input_expr: LeastReplicated | MostReplicated,
assignments: dict[Variable, Constant],
reducer: Callable[[fa.FragmentedLayout, fa.FragmentedLayout], fa.FragmentedLayout | None]
) -> Expression | Unsatisfiable:
assert input_expr.expressions
new_expressions: list[Expression] = []
# Use a set to eliminate duplicates, but preserve the order.
seen: set[Expression] = set()
for expr in input_expr.expressions:
reduced_expr = reduce_expression(expr, assignments)
if isinstance(reduced_expr, Unsatisfiable):
return Unsatisfiable()
if reduced_expr in seen:
continue
new_expressions.append(reduced_expr)
seen.add(reduced_expr)
if len(new_expressions) == 1:
return new_expressions[0]
consts = []
unknowns = []
for e in new_expressions:
if not isinstance(e, Constant):
unknowns.append(e)
continue
if not isinstance(e, RegisterLayout):
raise ValueError(
f"Reduction of non-register layout constant is not supported: {e}"
)
consts.append(e)
if consts:
const_red, *consts = consts
red = const_red
for cst in consts:
red_value = reducer(red.value, cst.value)
if red_value is None:
# The layouts are not compatible up to replication, this expression
# cannot be simplified.
return Unsatisfiable()
red = RegisterLayout(red_value)
else:
red = None
constructor = type(input_expr)
if red is not None:
if unknowns:
return constructor((red, *unknowns))
return red
return constructor(tuple(unknowns))
def reduce_broadcast_expression(
broadcast: BroadcastInDim, assignments: dict[Variable, Constant]
) -> Expression | Unsatisfiable:
def _check_shape_broadcast(shape: tuple[int, ...]) -> bool:
for axis, s in zip(broadcast.axes, shape, strict=True):
if broadcast.shape[axis] != s:
return False
return True
reduced_expr = reduce_expression(broadcast.expression, assignments)
match reduced_expr:
case Unsatisfiable():
return Unsatisfiable()
case RegisterLayout(value=layout):
match layout:
case fa.WGSplatFragLayout(shape=shape):
if not _check_shape_broadcast(shape):
return Unsatisfiable()
return RegisterLayout(fa.WGSplatFragLayout(shape=broadcast.shape))
case _:
return BroadcastInDim(
expression=reduced_expr,
axes=broadcast.axes,
shape=broadcast.shape,
)
case _:
return BroadcastInDim(
expression=reduced_expr, axes=broadcast.axes, shape=broadcast.shape
)
def reduce_reshape_expression(
reshape: Reshape, assignments: dict[Variable, Constant]
) -> Expression | Unsatisfiable:
reduced_expr = reduce_expression(reshape.expression, assignments)
match reduced_expr:
case Unsatisfiable():
return Unsatisfiable()
case RegisterLayout(value=layout):
match layout:
case fa.WGSplatFragLayout(shape=shape):
assert math.prod(shape) == math.prod(reshape.target_shape)
return RegisterLayout(
fa.WGSplatFragLayout(shape=reshape.target_shape)
)
case fa.WGStridedFragLayout(shape=shape, vec_size=vec_size):
assert math.prod(shape) == math.prod(reshape.target_shape)
return RegisterLayout(
fa.WGStridedFragLayout(
shape=reshape.target_shape, vec_size=vec_size
)
)
case fa.TiledLayout() as tiled_layout:
tile_shape = tiled_layout.base_tile_shape
if len(reshape.target_shape) < len(tile_shape):
return dataclasses.replace(reshape, expression=reduced_expr)
# Even if the new shape is not perfectly tilable, it is possible that
# we may be able to reshape the tiling itself in a way that is
# compatible with the new shape. We do not handle this case at the
# moment.
for ts, s in zip(tile_shape, reshape.source_shape[-len(tile_shape):], strict=True):
if s % ts != 0:
return dataclasses.replace(reshape, expression=reduced_expr)
# If minor tiled dimensions are modified, then reshaping is likely to
# not be a no-op since the strides between tiles will change,
# potentially mapping different elements to lanes and warps. We don't
# attempt to handle this case at the moment.
num_minor_tiled_dims = len(tile_shape) - 1
source_minor_tiled_dims = reshape.source_shape[-num_minor_tiled_dims:]
target_minor_tiled_dims = reshape.target_shape[-num_minor_tiled_dims:]
major_tiled_dim = tile_shape[0]
if (source_minor_tiled_dims != target_minor_tiled_dims or
reshape.target_shape[-len(tile_shape)] % major_tiled_dim != 0):
return dataclasses.replace(reshape, expression=reduced_expr)
# At this point, we now that only non-tiled dimensions and/or the
# majormost tiled dimensions may have changed. We also know that the
# majormost tiled dimension is still tilable in the new shape.
# Therefore, we can return the tiled layout as is.
return RegisterLayout(tiled_layout)
case _:
return dataclasses.replace(reshape, expression=reduced_expr) # pytype: disable=bad-return-type
def reduce_transpose_expression(
transpose: Transpose, assignments: dict[Variable, Constant]
) -> Expression | Unsatisfiable:
reduced_expr = reduce_expression(transpose.expression, assignments)
match reduced_expr:
case Unsatisfiable():
return Unsatisfiable()
case SMEMTiling(value=tile_transform):
if tile_transform is None:
return SMEMTiling(None)
tiling = tile_transform.tiling
if len(tiling) != 2:
raise NotImplementedError(
f"Only 2D tilings are supported, got {len(tiling)}"
)
return SMEMTiling(lc.TileTransform(tiling[::-1]))
case _:
return Transpose(expression=reduced_expr)
def reduce_expression(
expr: Expression, assignments: dict[Variable, Constant]
) -> Expression | Unsatisfiable:
"""Reduces an expression as much as is possible given a set of known variable assignments."""
match expr:
case Constant():
return expr
case Variable():
return assignments.get(expr, expr)
case MostReplicated():
return reduce_replicated_expression(
expr, assignments, layouts_lib.join_layouts
)
case LeastReplicated():
return reduce_replicated_expression(
expr, assignments, layouts_lib.meet_layouts
)
case Reduce(expression=expr, axes=axes):
reduced_expr = reduce_expression(expr, assignments)
match reduced_expr:
case Unsatisfiable():
return Unsatisfiable()
case RegisterLayout(value=layout) if isinstance(layout, fa.TiledLayout):
return RegisterLayout(layout.reduce(axes))
case _:
return Reduce(expression=reduced_expr, axes=axes)
case BroadcastInDim():
return reduce_broadcast_expression(expr, assignments)
case Reshape():
return reduce_reshape_expression(expr, assignments)
case Transpose():
return reduce_transpose_expression(expr, assignments)
case _:
assert_never(expr)
_SUPPORTED_TILED_RELAYOUTS = frozenset([
# Transposed layouts.
(fa.WGMMA_LAYOUT, fa.WGMMA_TRANSPOSED_LAYOUT),
(fa.WGMMA_TRANSPOSED_LAYOUT, fa.WGMMA_LAYOUT),
(fa.TCGEN05_LAYOUT, fa.TCGEN05_TRANSPOSED_LAYOUT),
(fa.TCGEN05_TRANSPOSED_LAYOUT, fa.TCGEN05_LAYOUT),
# "Conversion-optimized" layouts.
(fa.WGMMA_LAYOUT_UPCAST_2X, fa.WGMMA_LAYOUT),
(fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT_UPCAST_2X),
(fa.WGMMA_LAYOUT_UPCAST_4X, fa.WGMMA_LAYOUT),
])
@dataclasses.dataclass(frozen=True)
class Relayout:
"""States that `source` must be relayout-able to `target`.
Relayout-ability here is not defined as a fundamental property of layouts, but
rather a reflection of our implementation. For instance, when evaluating this
constraint, we will return `False` systematically if a relayout exists but we
do not ever plan to support it.
Modeling this constraint this way is helpful, in order to allow pruning
inefficient solutions when attempting to solve an equation system.
"""
source: Expression
target: Expression
def holds(self) -> bool | None:
"""Returns whether the relayout constraint holds.
Returns `None` if the constraint can't be checked.
"""
source = self.source
target = self.target
# Fast path for syntactically identical expressions.
if source == target:
return True
if not isinstance(source, RegisterLayout) or not isinstance(
target, RegisterLayout
):
return None
source_layout, target_layout = source.value, target.value
match source_layout, target_layout:
case fa.WGSplatFragLayout() as splat, fa.WGStridedFragLayout() as strided:
return splat.shape == strided.shape
case fa.WGSplatFragLayout(), fa.TiledLayout():
return layouts_lib.splat_is_compatible_with_tiled(
source_layout, target_layout
)
case fa.TiledLayout(), fa.TiledLayout():
return (source_layout, target_layout) in _SUPPORTED_TILED_RELAYOUTS
case _:
return False
def __str__(self):
return f"Relayout({self.source}{self.target})"
@dataclasses.dataclass(frozen=True)
class IsTransferable:
"""States that `source` layout must be transferable across memory spaces to `target` layout."""
source: Expression
target: Expression
# TODO(allanrenucci): Can this be derived from the layouts?
shape: tuple[int, ...]
def supported_tmem_transfers(
self, packing: int
) -> set[tuple[tcgen05.TMEMLayout, fa.FragmentedLayout]]:
"""Returns the set of supported TMEM <-> Register transfers."""
assert len(self.shape) == 2
columns = self.shape[1]
tmem_default_layout = tcgen05.tmem_default_layout(packing)
return {
(tmem_default_layout, fa.TCGEN05_LAYOUT),
(tmem_default_layout, fa.TMEM_NATIVE_LAYOUT),
(tcgen05.tmem_half_lane_layout(columns, packing), fa.WGMMA_LAYOUT),
(
tcgen05.tmem_m64_collective_layout(columns, packing),
tcgen05.fa_m64_collective_layout(columns),
),
}
def _is_valid_tmem_transfer(
self, tmem_layout: tcgen05.TMEMLayout, reg_layout: fa.FragmentedLayout
) -> bool:
packing = tmem_layout.vector_length
return (tmem_layout, reg_layout) in self.supported_tmem_transfers(packing)
def _is_valid_smem_transfer(
self,
smem_layout: lc.TileTransform | None,
reg_layout: fa.FragmentedLayout,
) -> bool:
# TODO(b/447079781): This is way too restrictive. We need to make it more
# precise by:
# - Consider whether the op is annotated with optimized copies or not.
# - If copies do not have to be optimized, always return True.
# - If copies have to be optimized, determine if the transfer is optimal by
# calling fragmented_array.plan_tiled_transfer.
if inference_utils.is_mma_layout(reg_layout):
return smem_layout is not None and len(smem_layout.tiling) == 2
return smem_layout is None
def holds(self) -> bool | None:
"""Returns whether the constraint holds.
Returns `None` if the constraint can't be checked.
"""
source = self.source
target = self.target
if isinstance(source, TMEMLayout) and isinstance(target, RegisterLayout):
return self._is_valid_tmem_transfer(source.value, target.value)
if isinstance(target, TMEMLayout) and isinstance(source, RegisterLayout):
return self._is_valid_tmem_transfer(target.value, source.value)
if isinstance(source, TMEMLayout) and isinstance(target, TMEMLayout):
return source == target
if isinstance(source, SMEMTiling) and isinstance(target, RegisterLayout):
return self._is_valid_smem_transfer(source.value, target.value)
if isinstance(target, SMEMTiling) and isinstance(source, RegisterLayout):
return self._is_valid_smem_transfer(target.value, source.value)
if isinstance(target, Constant) and isinstance(source, Constant):
source_type = type(source).__name__
target_type = type(target).__name__
raise NotImplementedError(f"Unsupported transfer: {source_type} -> {target_type}")
return None
def __str__(self):
return f"IsTransferable({self.source}{self.target})"
@dataclasses.dataclass(frozen=True)
class NotOfType:
"""States that `expr` is not an instance of `type`."""
expr: Expression
type: type[fa.FragmentedLayout]
def holds(self) -> bool | None:
"""Whether the distinctiveness constraint holds.
Returns `None` if the constraint can't be checked.
"""
if not isinstance(self.expr, Constant):
return None
if not isinstance(self.expr, RegisterLayout):
return True
return not isinstance(self.expr.value, self.type)
def __str__(self):
return f"type({self.expr}) ≠ {self.type.__name__}"
@dataclasses.dataclass(frozen=True)
class Divides:
"""States that the `expr` tiling is a divisor of `tiling_multiple`.
That is to say that, for each tiled dimension in `expr`, the dimension must
divide its corresponding dimension in `tiling_multiple` starting from the
tail.
If `tiling_multiple` contains more dimensions than `expr`, then
the extra dimensions in `tiling_multiple` are ignored for the purposes of the
check.
`expr` is not allowed to contain more dimensions than `tiling_multiple`, and
this constraint therefore also constrains the rank of `expr`.
"""
expr: Expression
tiling_multiple: tuple[int, ...]
def holds(self) -> bool | None:
match self.expr:
case SMEMTiling(value=None):
# If there is no tiling, then this holds trivially.
return True
case SMEMTiling(value=lc.TileTransform(tiling=t)):
tiling = t
case RegisterLayout(value=fa.TiledLayout() as layout):
tiling = layout.base_tile_shape
case TMEMLayout(value):
tiling = value.base_tile_shape
case _:
return None
if len(tiling) > len(self.tiling_multiple):
# The rank of the tiling is larger than the rank of the constraint. This
# is not allowed.
return False
for size, multiple in zip(reversed(tiling), reversed(self.tiling_multiple)):
if multiple % size:
return False
return True
def __str__(self):
return f"{self.tiling_multiple} % {self.expr} == 0"
Constraint = Relayout | NotOfType | IsTransferable | Divides
def reduce_constraint(
constraint: Constraint, assignments: dict[Variable, Constant]
) -> Constraint | Tautological | Unsatisfiable:
"""Reduces a constraint."""
new_constraint: Constraint
match constraint:
case Relayout(source=source, target=target):
source_red = reduce_expression(source, assignments)
target_red = reduce_expression(target, assignments)
if isinstance(source_red, Unsatisfiable) or isinstance(
target_red, Unsatisfiable
):
return Unsatisfiable()
new_constraint = Relayout(source_red, target_red)
case NotOfType(expr=expr, type=type):
expr_red = reduce_expression(expr, assignments)
if isinstance(expr_red, Unsatisfiable):
return Unsatisfiable()
new_constraint = NotOfType(expr_red, type)
case IsTransferable(source=source, target=target, shape=shape):
source_red = reduce_expression(source, assignments)
target_red = reduce_expression(target, assignments)
if isinstance(source_red, Unsatisfiable) or isinstance(target_red, Unsatisfiable):
return Unsatisfiable()
new_constraint = IsTransferable(source_red, target_red, shape)
case Divides(expr=expr, tiling_multiple=tiling_multiple):
expr_red = reduce_expression(expr, assignments)
if isinstance(expr_red, Unsatisfiable):
return Unsatisfiable()
new_constraint = Divides(expr_red, tiling_multiple)
case _ as never:
assert_never(never)
constraint_holds = new_constraint.holds()
if constraint_holds is None:
return new_constraint
return Tautological() if constraint_holds else Unsatisfiable()
@dataclasses.dataclass(frozen=True)
class Equation:
lhs: Expression
rhs: Expression
def __str__(self):
return f"{self.lhs} == {self.rhs}"
def reduce_equation(
eq: Equation, assignments: dict[Variable, Constant]
) -> Solution:
"""Reduces an equation.
Args:
eq: the equation to reduce.
assignments: a set of known variable assignments.
Returns:
A Solution object representing the result of the evaluation. That is:
- Unsatisfiable(): if the equation is unsatisfiable.
- Tautological(): if the equation is tautological.
- Satisfiable(): if the equation is satisfiable by assigning a value to
a variable.
- Unknown(): if the equation contains remaining unknown variables.
"""
lhs = reduce_expression(eq.lhs, assignments)
rhs = reduce_expression(eq.rhs, assignments)
match (lhs, rhs):
case (Variable(), Constant()):
return SatisfiedBy((lhs, rhs))
case (Constant(), Variable()):
return SatisfiedBy((rhs, lhs))
case (Constant(), Constant()) if lhs != rhs:
return Unsatisfiable()
case _ if isinstance(lhs, Unsatisfiable) or isinstance(rhs, Unsatisfiable):
return Unsatisfiable()
case _ if lhs == rhs:
return Tautological()
case _:
# This is covered above. Add a check here to appease the type checker.
assert not isinstance(lhs, Unsatisfiable) and not isinstance(rhs, Unsatisfiable)
return Unknown(Equation(lhs, rhs))
@dataclasses.dataclass
class EquationSystem:
"""An equation system contains a set of equations and assignments.
Assignments assign constant values to variables in the system (bound
variables). Equations describe relationships between variables, and can be
used to determine assignments for unknown (free) variables.
Constraints are used to check predicates that must hold for the assignments to
be valid.
"""
assignments: dict[Variable, Constant] = dataclasses.field(
default_factory=dict
)
equations: list[Equation] = dataclasses.field(default_factory=list)
constraints: Sequence[Constraint] = dataclasses.field(default_factory=list)
def unknowns(self) -> list[Variable]:
"""Returns the list of free variables in the system."""
seen_variables: set[Variable] = set()
free_variables: list[Variable] = []
def extract_variables(expr: Expression) -> None:
match expr:
case Variable():
if expr not in seen_variables and expr not in self.assignments:
seen_variables.add(expr)
free_variables.append(expr)
case Constant():
...
case MostReplicated(expressions=expressions):
for e in expressions:
extract_variables(e)
case LeastReplicated(expressions=expressions):
for e in expressions:
extract_variables(e)
case Reduce(expression=e):
extract_variables(e)
case BroadcastInDim(expression=e):
extract_variables(e)
case Reshape(expression=e):
extract_variables(e)
case Transpose(expression=e):
extract_variables(e)
case _:
assert_never(expr)
for equation in self.equations:
extract_variables(equation.lhs)
extract_variables(equation.rhs)
for constraint in self.constraints:
match constraint:
case Relayout(source=source, target=target):
extract_variables(source)
extract_variables(target)
case NotOfType(expr=expr):
extract_variables(expr)
case IsTransferable(source=source, target=target, shape=_):
extract_variables(source)
extract_variables(target)
case Divides(expr=expr):
extract_variables(expr)
case _ as never:
assert_never(never)
return free_variables
def __and__(self, other: EquationSystem) -> EquationSystem | Unsatisfiable:
for variable, assignment in self.assignments.items():
if variable in other.assignments and assignment != other.assignments[variable]:
return Unsatisfiable()
return EquationSystem(
assignments=self.assignments | other.assignments,
equations=self.equations + other.equations,
constraints=[*self.constraints, *other.constraints],
)
def __str__(self):
r = "EquationSystem\n"
r += " assignments:\n"
for assignment, constant in self.assignments.items():
r += f" {assignment}{constant}\n"
r += " equations:\n"
for equation in self.equations:
r += f" {equation}\n"
r += " constraints:\n"
for constraint in self.constraints:
r += f" {constraint}\n"
return r
@final
class Unsatisfiable:
def __and__(self, other: EquationSystem | Unsatisfiable) -> Unsatisfiable:
return self
@dataclasses.dataclass(frozen=True)
class SatisfiedBy:
assignment: tuple[Variable, Constant]
@dataclasses.dataclass(frozen=True)
class Unknown:
equation: Equation
class Tautological:
...
def non_splat_variables(
constraints: Sequence[Constraint],
) -> set[Variable]:
"""Returns a all vars distinct from a splat."""
vars: set[Variable] = set()
for constraint in constraints:
match constraint:
case NotOfType(expr=Variable() as var, type=fa.WGSplatFragLayout):
assert isinstance(var, Variable) # make pytype happy
vars.add(var)
return vars
# The result of reducing an equation---and by extension, a system of
# equations. An equation can either be unsatisfiable (i.e. there exists no
# assignment for which it holds), satisfied by an assignment, unknown (i.e.
# still undetermined), or tautological (i.e. the equation is guaranteed to
# hold for any assignment).
Solution = Unsatisfiable | SatisfiedBy | Unknown | Tautological
def _has_relayout_of_non_splat_to_splat(constraints: Sequence[Constraint]) -> bool:
"""Returns whether the constraints imply a non-splat to splat relayout.
Such relayouts are impossible and this helps shortcut the search.
If this function returns False, this doesn't necessarily mean that there are
no non-splat to splat relayouts, just that this is not known yet.
"""
non_splat = non_splat_variables(constraints)
if not non_splat:
return False
def is_constant_splat(e) -> bool:
return isinstance(e, RegisterLayout) and isinstance(
e.value, fa.WGSplatFragLayout
)
for constraint in constraints:
match constraint:
case Relayout(source=source, target=target):
if source in non_splat and is_constant_splat(target):
return True
case _:
pass
return False
def saturate_distinct_from_splat(
equation_system: EquationSystem,
) -> EquationSystem | Unsatisfiable:
"""Adds transitive NotOfType constraints for all non-splat variables.
Given `n` variables `l0`, ... `l{n-1}`, and a set of relayouts
`{ Relayout(l{i}, l{i+1}) : 0 <= i < n }`, if we also know that
`l{0}` is not splat, then we can automatically deduce that none of
`l0`, ..., `l{n-1}` are splat either.
This helps us quickly conclude that a system is unsatisfiable in cases where
a non-splat variable is transitively relaid out into a splat layout.
"""
non_splat = non_splat_variables(equation_system.constraints)
new_constraints: list[Constraint] = []
new_non_splat_found = len(non_splat) > 0
while new_non_splat_found:
new_non_splat_found = False
for constraint in equation_system.constraints:
match constraint:
case Relayout(source=source, target=target):
if (
isinstance(target, Variable)
and source in non_splat
and target not in non_splat
):
new_non_splat_found = True
non_splat.add(target)
new_constraints.append(NotOfType(target, fa.WGSplatFragLayout))
case _:
pass
return equation_system & EquationSystem(constraints=new_constraints)
def compute_transitively_equal_vars(
system: EquationSystem,
) -> dict[Variable, list[Variable]]:
"""Computes all transitively equal variables in an equation system.
The output dictionary maps each variable that appears in equations in the
equation system to all the variables it is transitively equal to.
"""
# The equality relations between variables form a graph where variables are
# nodes and an equation `v1 == v2` forms an edge. All variables in a
# connected component are transitively equal. We use a Union-Find data
# structure with path compression to efficiently find these connected
# components (i.e., equivalence classes).
parent: dict[Variable, Variable] = {}
def find(v: Variable) -> Variable:
if v not in parent:
parent[v] = v
if parent[v] != v:
parent[v] = find(parent[v])
return parent[v]
def union(v1: Variable, v2: Variable):
root1 = find(v1)
root2 = find(v2)
if root1 != root2:
parent[root2] = root1
all_vars: set[Variable] = set()
for eq in system.equations:
if isinstance(eq.lhs, Variable) and isinstance(eq.rhs, Variable):
all_vars.add(eq.lhs)
all_vars.add(eq.rhs)
union(eq.lhs, eq.rhs)
# Group variables by their component representative.
components: dict[Variable, list[Variable]] = {}
for v in sorted(all_vars, key=str):
root = find(v)
components.setdefault(root, []).append(v)
equal_vars: dict[Variable, list[Variable]] = {}
for component_vars in components.values():
for v in component_vars:
equal_vars[v] = [other for other in component_vars if other != v]
return equal_vars
def saturate_divides_constraints_for_equal_vars(
system: EquationSystem,
) -> EquationSystem:
"""Saturates Divides constraints between all transitively equal vars.
"""
equal_vars = compute_transitively_equal_vars(system)
new_constraints: list[Constraint] = []
for constraint in system.constraints:
new_constraints.append(constraint)
match constraint:
case Divides(expr=expr, tiling_multiple=tiling_multiple):
if isinstance(expr, Variable):
for equal_var in equal_vars.get(expr, []):
new_constraints.append(Divides(equal_var, tiling_multiple))
case _:
pass
new_constraints = merge_divides_constraints(new_constraints)
return dataclasses.replace(system, constraints=new_constraints)
# TODO(bchetioui): clean up API.
def merge_divides_constraints(constraints: Sequence[Constraint]) -> list[Constraint]:
"""Merges Divides constraints that can be merged."""
result: list[Constraint] = []
var_to_tiling_multiples : dict[Variable, tuple[int, ...]] = {}
for constraint in constraints:
match constraint:
case Divides(expr=Variable() as v, tiling_multiple=tiling_multiple):
assert isinstance(v, Variable) # make pytype happy
if (previous_tiling_multiple := var_to_tiling_multiples.get(v)) is None:
var_to_tiling_multiples[v] = tiling_multiple
continue
# If the two tuples are of different lengths, the larger tuple will
# be truncated (removing initial multiples) to the length of the
# smaller tuple. This preserves the semantics of the Divides constraints
# where a tiling's rank cannot exceed the size of tiling_multiple.
min_len = min(len(tiling_multiple), len(previous_tiling_multiple))
new_tiling_multiple = []
if min_len > 0:
for x, y in zip(tiling_multiple[-min_len:], previous_tiling_multiple[-min_len:], strict=True):
new_tiling_multiple.append(math.gcd(x, y))
var_to_tiling_multiples[v] = tuple(new_tiling_multiple)
case _:
result.append(constraint)
for expr, tiling_multiple in var_to_tiling_multiples.items():
result.append(Divides(expr, tiling_multiple))
return result
def _reduce_system_once(
equation_system: EquationSystem,
) -> EquationSystem | Unsatisfiable | None:
"""Performs one reduction step over each equation in an equation system.
Returns:
- Unsatisfiable(): if the equation system is unsatisfiable.
- A new equation system if any equation was reduced.
- None: if the equation system is not known unsatisfiable, but hasn't been
reduced.
"""
changed = False
assignments: dict[Variable, Constant] = {}
equations: list[Equation] = []
for equation in equation_system.equations:
match reduce_equation(equation, equation_system.assignments):
case Unsatisfiable():
return Unsatisfiable()
case Tautological():
changed = True
case SatisfiedBy() as result:
variable, expression = result.assignment
if variable in assignments and assignments[variable] != expression:
return Unsatisfiable()
assignments[variable] = expression
changed = True
case Unknown(equation=reduced_equation):
equations.append(reduced_equation)
changed |= reduced_equation != equation
case _ as never:
assert_never(never)
assignments |= equation_system.assignments
constraints: list[Constraint] = []
for constraint in equation_system.constraints:
match reduce_constraint(constraint, assignments):
case Unsatisfiable():
return Unsatisfiable()
case Tautological():
changed = True
case _ as new_constraint:
changed |= new_constraint != constraint
constraints.append(new_constraint)
new_constraints = merge_divides_constraints(constraints)
changed |= len(new_constraints) != len(constraints)
constraints = new_constraints
# Shortcut for a specific case of unsatisfiability. This shortcut
# drastically reduces the size of the search space.
if _has_relayout_of_non_splat_to_splat(constraints):
return Unsatisfiable()
if changed:
return EquationSystem(
assignments=assignments | equation_system.assignments,
equations=equations,
constraints=constraints,
)
return None
def reduce(equation_system: EquationSystem) -> EquationSystem | Unsatisfiable:
"""Reduces an equation system until it can no longer be reduced.
Returns:
- Unsatisfiable(): if the equation system is unsatisfiable.
- The maximally reduced equation system otherwise.
"""
while True:
match _reduce_system_once(equation_system):
case None:
break
case Unsatisfiable():
return Unsatisfiable()
case EquationSystem() as new_system:
equation_system = new_system
case _ as never:
assert_never(never)
return equation_system