Skip to content

Commit

Permalink
Implement lambda with explicit param types and single-expr body
Browse files Browse the repository at this point in the history
Change llvm::StringMap to std::unordered_map and llvm::SmallVector to
std::vector for better debuggability.
  • Loading branch information
emlai committed Oct 23, 2017
1 parent 9423998 commit a08e68e
Show file tree
Hide file tree
Showing 12 changed files with 236 additions and 71 deletions.
25 changes: 17 additions & 8 deletions src/ast/ast-printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,17 @@ std::ostream& operator<<(std::ostream& out, const Expr& expr);
std::ostream& operator<<(std::ostream& out, const Stmt& stmt);
std::ostream& operator<<(std::ostream& out, const Decl& decl);

std::ostream& operator<<(std::ostream& out, llvm::ArrayRef<ParamDecl> params) {
out << "(";

for (const ParamDecl& param : params) {
out << param;
if (&param != &params.back()) out << " ";
}

return out << ")";
}

std::ostream& operator<<(std::ostream& out, const VarExpr& expr) {
return out << expr.getIdentifier();
}
Expand Down Expand Up @@ -102,6 +113,10 @@ std::ostream& operator<<(std::ostream& out, const UnwrapExpr& expr) {
return out << "(unwrap " << expr.getOperand() << ")";
}

std::ostream& operator<<(std::ostream& out, const LambdaExpr& expr) {
return out << "(lambda " << expr.getParams() << " " << *expr.getBody() << ")";
}

std::ostream& operator<<(std::ostream& out, const Expr& expr) {
switch (expr.getKind()) {
case ExprKind::VarExpr: return out << llvm::cast<VarExpr>(expr);
Expand All @@ -121,6 +136,7 @@ std::ostream& operator<<(std::ostream& out, const Expr& expr) {
case ExprKind::MemberExpr: return out << llvm::cast<MemberExpr>(expr);
case ExprKind::SubscriptExpr: return out << llvm::cast<SubscriptExpr>(expr);
case ExprKind::UnwrapExpr: return out << llvm::cast<UnwrapExpr>(expr);
case ExprKind::LambdaExpr: return out << llvm::cast<LambdaExpr>(expr);
}
llvm_unreachable("all cases handled");
}
Expand Down Expand Up @@ -243,14 +259,7 @@ std::ostream& operator<<(std::ostream& out, const ParamDecl& decl) {
std::ostream& operator<<(std::ostream& out, const FunctionDecl& decl) {
out << br << (decl.isExtern() ? "(extern-function-decl " : "(function-decl ");
delta::operator<<(out, decl.getName());

out << " (";
for (const ParamDecl& param : decl.getParams()) {
out << param;
if (&param != &decl.getParams().back()) out << " ";
}
out << ") " << decl.getReturnType();

out << " " << decl.getParams() << " " << decl.getReturnType();
if (!decl.isExtern()) out << decl.getBody();
return out << ")";
}
Expand Down
111 changes: 82 additions & 29 deletions src/ast/expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ bool Expr::isLvalue() const {
case ExprKind::IntLiteralExpr: case ExprKind::FloatLiteralExpr: case ExprKind::SizeofExpr:
case ExprKind::BoolLiteralExpr: case ExprKind::CastExpr: case ExprKind::UnwrapExpr:
case ExprKind::NullLiteralExpr: case ExprKind::BinaryExpr: case ExprKind::CallExpr:
case ExprKind::LambdaExpr:
return false;
case ExprKind::PrefixExpr:
return llvm::cast<PrefixExpr>(this)->getOperator() == STAR;
Expand All @@ -38,6 +39,8 @@ void Expr::setMoved(bool moved) {
}

std::unique_ptr<Expr> Expr::instantiate(const llvm::StringMap<Type>& genericArgs) const {
std::unique_ptr<Expr> instantiation;

switch (getKind()) {
case ExprKind::VarExpr: {
auto* varExpr = llvm::cast<VarExpr>(this);
Expand All @@ -50,60 +53,72 @@ std::unique_ptr<Expr> Expr::instantiate(const llvm::StringMap<Type>& genericArgs
identifier = varExpr->getIdentifier();
}

return llvm::make_unique<VarExpr>(std::move(identifier), varExpr->getLocation());
instantiation = llvm::make_unique<VarExpr>(std::move(identifier), varExpr->getLocation());
llvm::cast<VarExpr>(*instantiation).setDecl(varExpr->getDecl());
break;
}
case ExprKind::StringLiteralExpr: {
auto* stringLiteralExpr = llvm::cast<StringLiteralExpr>(this);
return llvm::make_unique<StringLiteralExpr>(stringLiteralExpr->getValue(),
stringLiteralExpr->getLocation());
instantiation = llvm::make_unique<StringLiteralExpr>(stringLiteralExpr->getValue(),
stringLiteralExpr->getLocation());
break;
}
case ExprKind::CharacterLiteralExpr: {
auto* characterLiteralExpr = llvm::cast<CharacterLiteralExpr>(this);
return llvm::make_unique<CharacterLiteralExpr>(characterLiteralExpr->getValue(),
characterLiteralExpr->getLocation());
instantiation = llvm::make_unique<CharacterLiteralExpr>(characterLiteralExpr->getValue(),
characterLiteralExpr->getLocation());
break;
}
case ExprKind::IntLiteralExpr: {
auto* intLiteralExpr = llvm::cast<IntLiteralExpr>(this);
return llvm::make_unique<IntLiteralExpr>(intLiteralExpr->getValue(),
intLiteralExpr->getLocation());
instantiation = llvm::make_unique<IntLiteralExpr>(intLiteralExpr->getValue(),
intLiteralExpr->getLocation());
break;
}
case ExprKind::FloatLiteralExpr: {
auto* floatLiteralExpr = llvm::cast<FloatLiteralExpr>(this);
return llvm::make_unique<FloatLiteralExpr>(floatLiteralExpr->getValue(),
floatLiteralExpr->getLocation());
instantiation = llvm::make_unique<FloatLiteralExpr>(floatLiteralExpr->getValue(),
floatLiteralExpr->getLocation());
break;
}
case ExprKind::BoolLiteralExpr: {
auto* boolLiteralExpr = llvm::cast<BoolLiteralExpr>(this);
return llvm::make_unique<BoolLiteralExpr>(boolLiteralExpr->getValue(),
boolLiteralExpr->getLocation());
instantiation = llvm::make_unique<BoolLiteralExpr>(boolLiteralExpr->getValue(),
boolLiteralExpr->getLocation());
break;
}
case ExprKind::NullLiteralExpr: {
auto* nullLiteralExpr = llvm::cast<NullLiteralExpr>(this);
return llvm::make_unique<NullLiteralExpr>(nullLiteralExpr->getLocation());
instantiation = llvm::make_unique<NullLiteralExpr>(nullLiteralExpr->getLocation());
break;
}
case ExprKind::ArrayLiteralExpr: {
auto* arrayLiteralExpr = llvm::cast<ArrayLiteralExpr>(this);
auto elements = ::instantiate(arrayLiteralExpr->getElements(), genericArgs);
return llvm::make_unique<ArrayLiteralExpr>(std::move(elements),
arrayLiteralExpr->getLocation());
instantiation = llvm::make_unique<ArrayLiteralExpr>(std::move(elements),
arrayLiteralExpr->getLocation());
break;
}
case ExprKind::TupleExpr: {
auto* tupleExpr = llvm::cast<TupleExpr>(this);
auto elements = ::instantiate(tupleExpr->getElements(), genericArgs);
return llvm::make_unique<TupleExpr>(std::move(elements), tupleExpr->getLocation());
instantiation = llvm::make_unique<TupleExpr>(std::move(elements), tupleExpr->getLocation());
break;
}
case ExprKind::PrefixExpr: {
auto* prefixExpr = llvm::cast<PrefixExpr>(this);
auto operand = prefixExpr->getOperand().instantiate(genericArgs);
return llvm::make_unique<PrefixExpr>(prefixExpr->getOperator(), std::move(operand),
prefixExpr->getLocation());
instantiation = llvm::make_unique<PrefixExpr>(prefixExpr->getOperator(), std::move(operand),
prefixExpr->getLocation());
break;
}
case ExprKind::BinaryExpr: {
auto* binaryExpr = llvm::cast<BinaryExpr>(this);
auto lhs = binaryExpr->getLHS().instantiate(genericArgs);
auto rhs = binaryExpr->getRHS().instantiate(genericArgs);
return llvm::make_unique<BinaryExpr>(binaryExpr->getOperator(), std::move(lhs),
std::move(rhs), binaryExpr->getLocation());
instantiation = llvm::make_unique<BinaryExpr>(binaryExpr->getOperator(), std::move(lhs),
std::move(rhs), binaryExpr->getLocation());
break;
}
case ExprKind::CallExpr: {
auto* callExpr = llvm::cast<CallExpr>(this);
Expand All @@ -114,40 +129,64 @@ std::unique_ptr<Expr> Expr::instantiate(const llvm::StringMap<Type>& genericArgs
auto callGenericArgs = map(callExpr->getGenericArgs(), [&](Type type) {
return type.resolve(genericArgs);
});
return llvm::make_unique<CallExpr>(std::move(callee), std::move(args),
std::move(callGenericArgs), callExpr->getLocation());
instantiation = llvm::make_unique<CallExpr>(std::move(callee), std::move(args),
std::move(callGenericArgs),
callExpr->getLocation());
llvm::cast<CallExpr>(*instantiation).setCalleeDecl(callExpr->getCalleeDecl());
break;
}
case ExprKind::CastExpr: {
auto* castExpr = llvm::cast<CastExpr>(this);
auto targetType = castExpr->getTargetType().resolve(genericArgs);
auto expr = castExpr->getExpr().instantiate(genericArgs);
return llvm::make_unique<CastExpr>(targetType, std::move(expr), castExpr->getLocation());
instantiation = llvm::make_unique<CastExpr>(targetType, std::move(expr),
castExpr->getLocation());
break;
}
case ExprKind::SizeofExpr: {
auto* sizeofExpr = llvm::cast<SizeofExpr>(this);
auto type = sizeofExpr->getType().resolve(genericArgs);
return llvm::make_unique<SizeofExpr>(type, sizeofExpr->getLocation());
instantiation = llvm::make_unique<SizeofExpr>(type, sizeofExpr->getLocation());
break;
}
case ExprKind::MemberExpr: {
auto* memberExpr = llvm::cast<MemberExpr>(this);
auto base = memberExpr->getBaseExpr()->instantiate(genericArgs);
return llvm::make_unique<MemberExpr>(std::move(base), memberExpr->getMemberName(),
memberExpr->getLocation());
instantiation = llvm::make_unique<MemberExpr>(std::move(base), memberExpr->getMemberName(),
memberExpr->getLocation());
break;
}
case ExprKind::SubscriptExpr: {
auto* subscriptExpr = llvm::cast<SubscriptExpr>(this);
auto base = subscriptExpr->getBaseExpr()->instantiate(genericArgs);
auto index = subscriptExpr->getIndexExpr()->instantiate(genericArgs);
return llvm::make_unique<SubscriptExpr>(std::move(base), std::move(index),
subscriptExpr->getLocation());
instantiation = llvm::make_unique<SubscriptExpr>(std::move(base), std::move(index),
subscriptExpr->getLocation());
break;
}
case ExprKind::UnwrapExpr: {
auto* unwrapExpr = llvm::cast<UnwrapExpr>(this);
auto operand = unwrapExpr->getOperand().instantiate(genericArgs);
return llvm::make_unique<UnwrapExpr>(std::move(operand), unwrapExpr->getLocation());
instantiation = llvm::make_unique<UnwrapExpr>(std::move(operand), unwrapExpr->getLocation());
break;
}
case ExprKind::LambdaExpr: {
auto* lambdaExpr = llvm::cast<LambdaExpr>(this);
auto params = map(lambdaExpr->getParams(), [&](const ParamDecl& p) {
return ParamDecl(p.getType().resolve(genericArgs), p.getName(), p.getLocation());
});
auto body = lambdaExpr->getBody()->instantiate(genericArgs);
instantiation = llvm::make_unique<LambdaExpr>(std::move(params), std::move(body),
lambdaExpr->getLocation());
break;
}
}
llvm_unreachable("all cases handled");

if (hasType()) {
instantiation->setType(getType());
}

return instantiation;
}

llvm::StringRef CallExpr::getFunctionName() const {
Expand Down Expand Up @@ -177,3 +216,17 @@ bool BinaryExpr::isBuiltinOp() const {
if (op == DOTDOT || op == DOTDOTDOT) return false;
return getLHS().getType().isBuiltinType() && getRHS().getType().isBuiltinType();
}

std::unique_ptr<FunctionDecl> LambdaExpr::lower(Module& module) const {
static uint64_t nameCounter = 0;

FunctionProto proto("__lambda" + std::to_string(nameCounter++), std::vector<ParamDecl>(params),
body->getType(), false, false);
auto functionDecl = llvm::make_unique<FunctionDecl>(std::move(proto), std::vector<Type>(),
module, getLocation());
std::vector<std::unique_ptr<Stmt>> body;
auto returnValue = getBody()->instantiate({});
body.push_back(llvm::make_unique<ReturnStmt>(std::move(returnValue), getBody()->getLocation()));
functionDecl->setBody(std::move(body));
return functionDecl;
}
26 changes: 24 additions & 2 deletions src/ast/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
namespace delta {

class Decl;
class FunctionDecl;
class Module;

enum class ExprKind {
VarExpr,
Expand All @@ -31,6 +33,7 @@ enum class ExprKind {
MemberExpr,
SubscriptExpr,
UnwrapExpr,
LambdaExpr
};

class Expr {
Expand All @@ -54,9 +57,11 @@ class Expr {
bool isMemberExpr() const { return getKind() == ExprKind::MemberExpr; }
bool isSubscriptExpr() const { return getKind() == ExprKind::SubscriptExpr; }
bool isUnwrapExpr() const { return getKind() == ExprKind::UnwrapExpr; }
bool isLambdaExpr() const { return getKind() == ExprKind::LambdaExpr; }

ExprKind getKind() const { return kind; }
Type getType() const { return type; }
bool hasType() const { return type.get() != nullptr; }
Type getType() const { ASSERT(type); return type; }
void setType(Type type) { ASSERT(type); this->type = type; }
bool isLvalue() const;
bool isRvalue() const { return !isLvalue(); }
Expand Down Expand Up @@ -326,7 +331,7 @@ class SubscriptExpr : public CallExpr {
};

/// A postfix expression that unwraps an optional (nullable) value, yielding the value wrapped by
// the optional, for example 'foo!'. If the optional is null, the operation triggers an assertion
/// the optional, for example 'foo!'. If the optional is null, the operation triggers an assertion
/// error (by default), or causes undefined behavior (in unchecked mode).
class UnwrapExpr : public Expr {
public:
Expand All @@ -339,4 +344,21 @@ class UnwrapExpr : public Expr {
std::unique_ptr<Expr> operand;
};

class LambdaExpr : public Expr {
public:
LambdaExpr(std::vector<ParamDecl>&& params, std::unique_ptr<Expr> body,
SourceLocation location)
: Expr(ExprKind::LambdaExpr, location), params(std::move(params)),
body(std::move(body)) {}
llvm::ArrayRef<ParamDecl> getParams() const { return params; }
llvm::MutableArrayRef<ParamDecl> getParams() { return params; }
Expr* getBody() const { return body.get(); }
std::unique_ptr<FunctionDecl> lower(Module& module) const;
static bool classof(const Expr* e) { return e->getKind() == ExprKind::LambdaExpr; }

private:
std::vector<ParamDecl> params;
std::unique_ptr<Expr> body;
};

}
19 changes: 19 additions & 0 deletions src/irgen/irgen-expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,23 @@ llvm::Value* IRGenerator::codegenUnwrapExpr(const UnwrapExpr& expr) {
return codegenExpr(expr.getOperand());
}

llvm::Value* IRGenerator::codegenLambdaExpr(const LambdaExpr& expr) {
auto functionDecl = expr.lower(*currentDecl->getModule());

auto insertBlockBackup = builder.GetInsertBlock();
auto insertPointBackup = builder.GetInsertPoint();
auto scopesBackup = std::move(scopes);

codegenFunctionDecl(*functionDecl);

scopes = std::move(scopesBackup);
if (insertBlockBackup) builder.SetInsertPoint(insertBlockBackup, insertPointBackup);

VarExpr varExpr(functionDecl->getName(), functionDecl->getLocation());
varExpr.setDecl(functionDecl.get());
return codegenVarExpr(varExpr);
}

llvm::Value* IRGenerator::codegenExpr(const Expr& expr) {
switch (expr.getKind()) {
case ExprKind::VarExpr: return codegenVarExpr(llvm::cast<VarExpr>(expr));
Expand All @@ -494,6 +511,7 @@ llvm::Value* IRGenerator::codegenExpr(const Expr& expr) {
case ExprKind::MemberExpr: return codegenMemberExpr(llvm::cast<MemberExpr>(expr));
case ExprKind::SubscriptExpr: return codegenSubscriptExpr(llvm::cast<SubscriptExpr>(expr));
case ExprKind::UnwrapExpr: return codegenUnwrapExpr(llvm::cast<UnwrapExpr>(expr));
case ExprKind::LambdaExpr: return codegenLambdaExpr(llvm::cast<LambdaExpr>(expr));
}
llvm_unreachable("all cases handled");
}
Expand All @@ -517,6 +535,7 @@ llvm::Value* IRGenerator::codegenLvalueExpr(const Expr& expr) {
case ExprKind::MemberExpr: return codegenLvalueMemberExpr(llvm::cast<MemberExpr>(expr));
case ExprKind::SubscriptExpr: return codegenLvalueSubscriptExpr(llvm::cast<SubscriptExpr>(expr));
case ExprKind::UnwrapExpr: return codegenUnwrapExpr(llvm::cast<UnwrapExpr>(expr));
case ExprKind::LambdaExpr: return codegenLambdaExpr(llvm::cast<LambdaExpr>(expr));
}
llvm_unreachable("all cases handled");
}
2 changes: 2 additions & 0 deletions src/irgen/irgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ llvm::Value* IRGenerator::findValue(llvm::StringRef name, const Decl* decl) {
return value;
}

ASSERT(decl);

switch (decl->getKind()) {
case DeclKind::VarDecl:
return codegenVarDecl(*llvm::cast<VarDecl>(decl));
Expand Down
Loading

0 comments on commit a08e68e

Please sign in to comment.