Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend cpp translator to support backwards transfer functions #45

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 63 additions & 4 deletions xdsl_smt/cli/cpp_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,17 @@
from xdsl.parser import Parser

from xdsl.dialects.arith import Arith
from xdsl.dialects.builtin import Builtin, ModuleOp
from xdsl.dialects.func import Func, FuncOp
from xdsl.dialects.func import Func
from xdsl_smt.dialects.transfer import Transfer
from xdsl_smt.dialects.llvm_dialect import LLVM
from xdsl_smt.passes.transfer_lower import LowerToCpp, addDispatcher, addInductionOps
from xdsl.dialects.func import FuncOp, Return
from xdsl.dialects.builtin import (
Builtin,
ModuleOp,
IntegerAttr,
StringAttr,
)


def register_all_arguments(arg_parser: argparse.ArgumentParser):
Expand All @@ -34,6 +40,38 @@ def parse_file(ctx: MLContext, file: str | None) -> Operation:
return module


def is_transfer_function(func: FuncOp) -> bool:
return "applied_to" in func.attributes


def is_forward(func: FuncOp) -> bool:
if "is_forward" in func.attributes:
forward = func.attributes["is_forward"]
assert isinstance(forward, IntegerAttr)
return forward.value.data == 1
return False


def getCounterexampleFunc(func: FuncOp) -> str | None:
if "soundness_counterexample" not in func.attributes:
return None
attr = func.attributes["soundness_counterexample"]
assert isinstance(attr, StringAttr)
return attr.data


def checkFunctionValidity(func: FuncOp) -> bool:
if len(func.function_type.inputs) != len(func.args):
return False
for func_type_arg, arg in zip(func.function_type.inputs, func.args):
if func_type_arg != arg.type:
return False
return_op = func.body.block.last_op
if not (return_op is not None and isinstance(return_op, Return)):
return False
return return_op.operands[0].type == func.function_type.outputs.data[0]


def main() -> None:
ctx = MLContext()
arg_parser = argparse.ArgumentParser()
Expand All @@ -51,15 +89,36 @@ def main() -> None:
module = parse_file(ctx, args.transfer_functions)
assert isinstance(module, ModuleOp)

allFuncMapping = {}
allFuncMapping: dict[str, FuncOp] = {}
forward = False
counterexampleFuncs: set[str] = set()
with open("tmp.cpp", "w") as fout:
LowerToCpp.fout = fout
for func in module.ops:
if isinstance(func, FuncOp):
if is_transfer_function(func):
forward |= is_transfer_function(func) and is_forward(func)
counterexampleFunc = getCounterexampleFunc(func)
if counterexampleFunc is not None:
counterexampleFuncs.add(counterexampleFunc)
allFuncMapping[func.sym_name.data] = func

# check function validity
if not checkFunctionValidity(func):
print(func.sym_name)
# check function validity

for counterexample in counterexampleFuncs:
assert counterexample in allFuncMapping
allFuncMapping[counterexample].detach()
del allFuncMapping[counterexample]
for func in module.ops:
if isinstance(func, FuncOp):
allFuncMapping[func.sym_name.data] = func
# HACK: we know the pass won't check that the operation is a module
LowerToCpp(fout).apply(ctx, cast(ModuleOp, func))
addInductionOps(fout)
addDispatcher(fout)
addDispatcher(fout, forward)

# printer = Printer(target=Printer.Target.MLIR)
# printer.print_op(module)
Expand Down
10 changes: 4 additions & 6 deletions xdsl_smt/passes/transfer_lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,19 +80,17 @@ def addInductionOps(fout: TextIO):
fout.write(lowerInductionOps(inductionOp))


def addDispatcher(fout: TextIO):
def addDispatcher(fout: TextIO, is_forward: bool):
global needDispatch
if len(needDispatch) != 0:
# print(lowerDispatcher(needDispatch))
fout.write(lowerDispatcher(needDispatch))
fout.write(lowerDispatcher(needDispatch, is_forward))


@dataclass
@dataclass(frozen=True)
class LowerToCpp(ModulePass):
name = "trans_lower"

def __init__(self, fout):
self.fout = fout
fout: TextIO = None

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
walker = PatternRewriteWalker(
Expand Down
121 changes: 93 additions & 28 deletions xdsl_smt/utils/lower_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,16 @@
ConstRangeForOp,
RepeatOp,
IntersectsOp,
FromArithOp,
# FromArithOp,
TupleType,
AddPoisonOp,
RemovePoisonOp,
)
from xdsl.dialects.func import FuncOp, Return, Call
from xdsl.pattern_rewriter import *
from functools import singledispatch
from typing import TypeVar, cast
from xdsl.dialects.builtin import Signedness, IntegerType, IndexType
from xdsl.dialects.builtin import Signedness, IntegerType, IndexType, IntegerAttr
from xdsl.ir import Operation
import xdsl.dialects.arith as arith

Expand Down Expand Up @@ -81,9 +84,9 @@
".ugt",
".uge",
],
"transfer.fromArith": "APInt",
"transfer.make": "std::make_tuple",
"transfer.get": "std::get<{0}>",
# "transfer.fromArith": "APInt",
"transfer.make": "{{{0}}}",
"transfer.get": "[{0}]",
"transfer.shl": ".shl",
"transfer.ashr": ".ashr",
"transfer.lshr": ".lshr",
Expand All @@ -96,8 +99,12 @@
"func.return": "return",
"transfer.constant": "APInt",
"arith.select": ["?", ":"],
"arith.cmpi": ["==", "!=", "<", "<=", ">", ">="],
"transfer.get_all_ones": "APInt::getAllOnes",
"transfer.select": ["?", ":"],
"transfer.reverse_bits": ".reverseBits",
"transfer.add_poison": " ",
"transfer.remove_poison": " ",
}
# transfer.constRangeLoop and NextLoop are controller operations, should be handle specially

Expand All @@ -120,25 +127,22 @@ def lowerType(typ, specialOp=None):
return "unsigned"
if isinstance(typ, TransIntegerType):
return "APInt"
elif isinstance(typ, AbstractValueType):
typeName = "std::tuple<"
elif isinstance(typ, AbstractValueType) or isinstance(typ, TupleType):
fields = typ.get_fields()
typeName += lowerType(fields[0])
typeName = lowerType(fields[0])
for i in range(1, len(fields)):
typeName += ","
typeName += lowerType(fields[i])
typeName += ">"
return typeName
assert lowerType(fields[i]) == typeName
return "std::vector<" + typeName + ">"
elif isinstance(typ, IntegerType):
return "int"
elif isinstance(typ, IndexType):
return "int"
print(typ)
assert False and "unsupported type"


CPP_CLASS_KEY = "CPPCLASS"
INDUCTION_KEY = "induction"
OPERATION_NO = "operationNo"


def lowerInductionOps(inductionOp: list[FuncOp]):
Expand All @@ -164,32 +168,38 @@ def lowerInductionOps(inductionOp: list[FuncOp]):
return result


def lowerDispatcher(needDispatch: list[FuncOp]):
def lowerDispatcher(needDispatch: list[FuncOp], is_forward: bool):
if len(needDispatch) > 0:
returnedType = needDispatch[0].function_type.outputs.data[0]
for func in needDispatch:
if func.function_type.outputs.data[0] != returnedType:
print(func)
print(func.function_type.outputs.data[0])
assert (
"we assume all transfer functions have the same returned type"
and False
)
returnedType = lowerType(returnedType)
funcName = "naiveDispatcher"
# we assume all operands have the same type as expr
expr = "(Operation* op, ArrayRef<" + returnedType + "> operands)"
# User should tell the generator all operands
if is_forward:
expr = "(Operation* op, std::vector<std::vector<llvm::APInt>> operands)"
else:
expr = "(Operation* op, std::vector<std::vector<llvm::APInt>> operands, unsigned operationNo)"
functionSignature = (
"std::optional<" + returnedType + "> " + funcName + expr + "{{\n{0}}}\n\n"
)
indent = "\t"
dyn_cast = (
indent
+ "if(auto castedOp=dyn_cast<{0}>(op);castedOp){{\n{1}"
+ "if(auto castedOp=dyn_cast<{0}>(op);castedOp&&{1}){{\n{2}"
+ indent
+ "}}\n"
)
return_inst = indent + indent + "return {0}({1});\n"

def handleOneTransferFunction(func: FuncOp) -> str:
def handleOneTransferFunction(func: FuncOp, operationNo: int) -> str:
blockStr = ""
for cppClass in func.attributes[CPP_CLASS_KEY]:
argStr = ""
Expand All @@ -201,12 +211,21 @@ def handleOneTransferFunction(func: FuncOp) -> str:
for i in range(1, len(func.args)):
argStr += ", operands[" + str(i) + "]"
ifBody = return_inst.format(func.sym_name.data, argStr)
blockStr += dyn_cast.format(cppClass.data, ifBody)
if operationNo == -1:
operationNoStr = "true"
else:
operationNoStr = "operationNo == " + str(operationNo)
blockStr += dyn_cast.format(cppClass.data, operationNoStr, ifBody)
return blockStr

funcBody = ""
for func in needDispatch:
funcBody += handleOneTransferFunction(func)
if is_forward:
funcBody += handleOneTransferFunction(func)
else:
operationNo = func.attributes[OPERATION_NO]
assert isinstance(operationNo, IntegerAttr)
funcBody += handleOneTransferFunction(func, operationNo.value.data)
funcBody += indent + "return {};\n"
return functionSignature.format(funcBody)

Expand Down Expand Up @@ -302,13 +321,10 @@ def _(op: arith.Cmpi):
returnedValue = op.results[0].name_hint
equals = "="
operandsName = [oper.name_hint for oper in op.operands]
assert len(operandsName) == 2
predicate = op.predicate.value.data
operName = operNameToCpp[op.name][predicate]
expr = operandsName[0] + operName + "("
if len(operandsName) > 1:
expr += operandsName[1]
for i in range(2, len(operandsName)):
expr += "," + operandsName[i]
expr = "(" + operandsName[0] + operName + operandsName[1]
expr += ")"
return indent + returnedType + " " + returnedValue + equals + expr + ends

Expand Down Expand Up @@ -355,17 +371,32 @@ def _(op: GetOp):
+ " "
+ returnedValue
+ equals
+ operNameToCpp[op.name].format(index)
+ "("
+ op.operands[0].name_hint
+ ")"
+ operNameToCpp[op.name].format(index)
+ ends
)


@lowerOperation.register
def _(op: MakeOp):
return lowerToNonClassMethod(op)
returnedType = lowerType(op.results[0].type, op)
returnedValue = op.results[0].name_hint
equals = "="
expr = ""
if len(op.operands) > 0:
expr += op.operands[0].name_hint
for i in range(1, len(op.operands)):
expr += "," + op.operands[i].name_hint
return (
indent
+ returnedType
+ " "
+ returnedValue
+ equals
+ returnedType
+ operNameToCpp[op.name].format(expr)
+ ends
)


@lowerOperation.register
Expand Down Expand Up @@ -402,6 +433,7 @@ def _(op: Return):
return indent + opName + operand + ends


"""
@lowerOperation.register
def _(op: FromArithOp):
opTy = op.op.type
Expand All @@ -421,6 +453,7 @@ def _(op: FromArithOp):
+ ")"
+ ends
)
"""


@lowerOperation.register
Expand Down Expand Up @@ -794,3 +827,35 @@ def _(op: RepeatOp):
)
forEnd = indent + "}\n"
return initExpr + forHead + forBody + forEnd


@lowerOperation.register
def _(op: AddPoisonOp):
returnedType = lowerType(op.results[0].type)
returnedValue = op.results[0].name_hint
opName = operNameToCpp[op.name]
return (
indent
+ returnedType
+ " "
+ returnedValue
+ " = "
+ op.operands[0].name_hint
+ ends
)


@lowerOperation.register
def _(op: RemovePoisonOp):
returnedType = lowerType(op.results[0].type)
returnedValue = op.results[0].name_hint
opName = operNameToCpp[op.name]
return (
indent
+ returnedType
+ " "
+ returnedValue
+ " = "
+ op.operands[0].name_hint
+ ends
)
Loading