Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ReversBits op added for transfer dialect #43

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion xdsl_smt/dialects/transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,11 @@ class NegOp(UnaryOp):
name = "transfer.neg"


@irdl_op_definition
class ReverseBitsOp(UnaryOp):
name = "transfer.reverse_bits"


class BinOp(IRDLOperation, InferResultTypeInterface, ABC):
T = Annotated[TransIntegerType | IntegerType, ConstraintVar("T")]

Expand Down Expand Up @@ -380,6 +385,23 @@ def __init__(self, shape: list[Attribute] | ArrayAttr[Attribute]) -> None:
super().__init__([shape])


@irdl_attr_definition
class TupleType(ParametrizedAttribute, TypeAttribute):
name = "transfer.tuple"
fields: ParameterDef[ArrayAttr[Attribute]]

def get_num_fields(self) -> int:
return len(self.fields.data)

def get_fields(self):
return [i for i in self.fields.data]

def __init__(self, shape: list[Attribute] | ArrayAttr[Attribute]) -> None:
if isinstance(shape, list):
shape = ArrayAttr(shape)
super().__init__([shape])
Hatsunespica marked this conversation as resolved.
Show resolved Hide resolved


@irdl_op_definition
class GetOp(IRDLOperation, InferResultTypeInterface):
name = "transfer.get"
Expand Down Expand Up @@ -583,6 +605,7 @@ def infer_result_type(
IntersectsOp,
AddPoisonOp,
RemovePoisonOp,
ReverseBitsOp,
],
[TransIntegerType, AbstractValueType],
[TransIntegerType, AbstractValueType, TupleType],
)
107 changes: 105 additions & 2 deletions xdsl_smt/semantics/transfer_semantics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass

Hatsunespica marked this conversation as resolved.
Show resolved Hide resolved
from xdsl.pattern_rewriter import (
PatternRewriter,
)
Expand Down Expand Up @@ -30,6 +31,7 @@
count_rzeros,
count_lones,
count_rones,
reverse_bits,
)


Expand All @@ -38,7 +40,9 @@ class AbstractValueTypeSemantics(TypeSemantics):
But the last element is useless, this makes GetOp easier"""

def get_semantics(self, type: Attribute) -> Attribute:
assert isinstance(type, transfer.AbstractValueType)
assert isinstance(type, transfer.AbstractValueType) or isinstance(
type, transfer.TupleType
)
curTy = type.get_fields()[-1]
isIntegerTy = isinstance(curTy, IntegerType)
curLoweredTy = SMTLowerer.lower_type(curTy)
Expand Down Expand Up @@ -598,7 +602,9 @@ def get_semantics(
)
numBits = numBitsOp.value.value.data
bitPosition = bitPositionOp.value.value.data
extractOp = smt_bv.ExtractOp(operands[0], numBits + bitPosition, bitPosition)
extractOp = smt_bv.ExtractOp(
operands[0], numBits + bitPosition - 1, bitPosition
)
rewriter.insert_op_before_matched_op(extractOp)
return ((extractOp.res,), effect_state)

Expand Down Expand Up @@ -637,6 +643,102 @@ def get_semantics(
return ((res.res,), effect_state)


class ReverseBitsOpSemantics(OperationSemantics):
def get_semantics(
self,
operands: Sequence[SSAValue],
results: Sequence[Attribute],
attributes: Mapping[str, Attribute | SSAValue],
effect_state: SSAValue | None,
rewriter: PatternRewriter,
) -> tuple[Sequence[SSAValue], SSAValue | None]:
op_ty = operands[0].type
assert isinstance(op_ty, smt_bv.BitVectorType)
res = reverse_bits(operands[0])
rewriter.insert_op_before_matched_op(res)
return ((res[-1].results[0],), effect_state)


class ConstRangeForOpSemantics(OperationSemantics):
def get_semantics(
self,
operands: Sequence[SSAValue],
results: Sequence[Attribute],
attributes: Mapping[str, Attribute | SSAValue],
effect_state: SSAValue | None,
rewriter: PatternRewriter,
) -> tuple[Sequence[SSAValue], SSAValue | None]:
cur_op = rewriter.current_operation
lb = operands[0].owner
ub = operands[1].owner
step = operands[2].owner
assert (
isinstance(lb, smt_bv.ConstantOp)
and "loop lower bound has to be a constant"
)
assert (
isinstance(ub, smt_bv.ConstantOp)
and "loop upper bound has to be a constant"
)
Hatsunespica marked this conversation as resolved.
Show resolved Hide resolved
assert isinstance(step, smt_bv.ConstantOp) and "loop step has to be a constant"
lb_int = lb.value.value.data
ub_int = ub.value.value.data
step_int = step.value.value.data

assert step_int != 0 and "step size should not be zero"
if step_int > 0:
assert (
ub_int > lb_int
and "the upper bound should be larger than the lower bound"
)
Hatsunespica marked this conversation as resolved.
Show resolved Hide resolved
else:
assert (
ub_int < lb_int
and "the upper bound should be smaller than the lower bound"
)
Hatsunespica marked this conversation as resolved.
Show resolved Hide resolved

iter_args = operands[3:]
iter_args_num = len(iter_args)

indvar, *block_iter_args = cur_op.regions[0].block.args

value_map: dict[SSAValue, SSAValue] = {}

value_map[indvar] = operands[0]
for i in range(iter_args_num):
value_map[block_iter_args[i]] = iter_args[i]
last_result = None
for i in range(lb_int, ub_int, step_int):
Hatsunespica marked this conversation as resolved.
Show resolved Hide resolved
for cur_op in cur_op.regions[0].block.ops:
if not isinstance(cur_op, transfer.NextLoopOp):
clone_op = cur_op.clone()
for idx in range(len(clone_op.operands)):
if cur_op.operands[idx] in value_map:
clone_op.operands[idx] = value_map[cur_op.operands[idx]]
if len(cur_op.results) != 0:
value_map[cur_op.results[0]] = clone_op.results[0]
rewriter.insert_op_before_matched_op(clone_op)
continue
if isinstance(cur_op, transfer.NextLoopOp):
if i + step_int < ub_int:
new_value_map: dict[SSAValue, SSAValue] = {}
cur_ind = transfer.Constant(operands[1], i + step_int).result
new_value_map[indvar] = cur_ind
rewriter.insert_op_before_matched_op(cur_ind.owner)
for idx, arg in enumerate(block_iter_args):
new_value_map[arg] = value_map[cur_op.operands[idx]]
value_map = new_value_map
else:
make_res = [value_map[arg] for arg in cur_op.arguments]
assert (
len(make_res) == 1
and "current we only support for one returned value from for"
)
Hatsunespica marked this conversation as resolved.
Show resolved Hide resolved
last_result = make_res[0]
assert last_result is not None
return ((last_result,), effect_state)


transfer_semantics: dict[type[Operation], OperationSemantics] = {
transfer.Constant: ConstantOpSemantics(),
transfer.AddOp: TrivialOpSemantics(transfer.AddOp, smt_bv.AddOp),
Expand Down Expand Up @@ -673,4 +775,5 @@ def get_semantics(
transfer.IntersectsOp: IntersectsOpSemantics(),
transfer.AddPoisonOp: AddPoisonOpSemantics(),
transfer.RemovePoisonOp: RemovePoisonOpSemantics(),
transfer.ReverseBitsOp: ReverseBitsOpSemantics(),
}
20 changes: 20 additions & 0 deletions xdsl_smt/utils/transfer_to_smt_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,26 @@ def count_ones(b: SSAValue) -> list[Operation]:
return bits + [zero] + bvs + nb


def reverse_bits(bits: SSAValue) -> list[Operation]:
Hatsunespica marked this conversation as resolved.
Show resolved Hide resolved
assert isinstance(bits.type, smt_bv.BitVectorType)
n = bits.type.width.data
if n == 1:
# If width is only one, no need to reverse bit, but just a comparision
zero = smt_bv.ConstantOp(0, 1)
one = smt_bv.ConstantOp(1, 1)
eqOp = smt.EqOp(bits, one.res)
iteOp = smt.IteOp(eqOp.res, one.res, zero.res)
return [zero, one, eqOp, iteOp]
Hatsunespica marked this conversation as resolved.
Show resolved Hide resolved
else:
bits_ops: list[Operation] = [smt_bv.ExtractOp(bits, i, i) for i in range(n)]
cur_bits: SSAValue = bits_ops[0].results[0]
result: list[smt_bv.ConcatOp] = []
for bit in bits_ops[1:]:
result.append(smt_bv.ConcatOp(cur_bits, bit.results[0]))
cur_bits = result[-1].res
return bits_ops + result


pow2 = [2**i for i in range(0, 9)]


Expand Down
Loading