Skip to content

Commit

Permalink
fix nested calls with complex return in condition
Browse files Browse the repository at this point in the history
  • Loading branch information
mhasel committed Jan 9, 2025
1 parent 0dc2923 commit 89b7201
Show file tree
Hide file tree
Showing 8 changed files with 316 additions and 97 deletions.
27 changes: 14 additions & 13 deletions compiler/plc_ast/src/mut_visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,17 +208,6 @@ pub trait AstVisitorMut: Sized {
fn visit_allocation(&mut self, _node: &mut AstNode) {}
}

/// Helper method that walks through a slice of `ConditionalBlock` and applies the visitor's `walk` method to each node.
fn walk_conditional_blocks<V>(visitor: &mut V, blocks: &mut [ConditionalBlock])
where
V: AstVisitorMut,
{
for b in blocks {
visit_nodes!(visitor, &mut b.condition);
visit_all_nodes_mut!(visitor, &mut b.body);
}
}

impl WalkerMut for AstLiteral {
fn walk<V>(&mut self, _visitor: &mut V)
where
Expand Down Expand Up @@ -321,14 +310,26 @@ impl WalkerMut for CallStatement {
}
}

impl WalkerMut for Vec<ConditionalBlock> {
fn walk<V>(&mut self, visitor: &mut V)
where
V: AstVisitorMut,
{
for b in self {
visit_nodes!(visitor, &mut b.condition);
visit_all_nodes_mut!(visitor, &mut b.body);
}
}
}

impl WalkerMut for AstControlStatement {
fn walk<V>(&mut self, visitor: &mut V)
where
V: AstVisitorMut,
{
match self {
AstControlStatement::If(stmt) => {
walk_conditional_blocks(visitor, &mut stmt.blocks);
stmt.blocks.walk(visitor);
visit_all_nodes_mut!(visitor, &mut stmt.else_block);
}
AstControlStatement::WhileLoop(stmt) | AstControlStatement::RepeatLoop(stmt) => {
Expand All @@ -342,7 +343,7 @@ impl WalkerMut for AstControlStatement {
}
AstControlStatement::Case(stmt) => {
visit_nodes!(visitor, &mut stmt.selector);
walk_conditional_blocks(visitor, &mut stmt.case_blocks);
stmt.case_blocks.walk(visitor);
visit_all_nodes_mut!(visitor, &mut stmt.else_block);
}
}
Expand Down
Binary file removed demo
Binary file not shown.
Empty file removed oom
Empty file.
118 changes: 101 additions & 17 deletions src/lowering/calls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,27 @@ use plc_ast::{
ast::{
steal_expression_list, AccessModifier, Allocation, Assignment, AstFactory, AstNode, AstStatement,
CallStatement, CompilationUnit, LinkageType, Pou, Variable, VariableBlock, VariableBlockType,
}, control_statements::{AstControlStatement, ConditionalBlock}, mut_visitor::{AstVisitorMut, WalkerMut}, provider::IdProvider, try_from, try_from_mut
},
control_statements::{AstControlStatement, ConditionalBlock},
mut_visitor::{AstVisitorMut, WalkerMut},
provider::IdProvider,
try_from_mut,
};
use plc_source::source_location::SourceLocation;

use crate::{index::Index, resolver::AnnotationMap};

#[derive(Default, Debug, Clone)]
struct VisitorContext {
is_switch_case: bool,
}

impl VisitorContext {
fn switch_case() -> Self {
Self { is_switch_case: true }
}
}

// Performs lowering for aggregate types defined in functions
#[derive(Default)]
pub struct AggregateTypeLowerer {
Expand All @@ -25,6 +40,7 @@ pub struct AggregateTypeLowerer {
// New statements to be added to the outer scope, i.e. when lowering a conditional block
outer_scope_stmts: Vec<AstNode>,
counter: AtomicI32,
ctx: VisitorContext,
}

impl AggregateTypeLowerer {
Expand All @@ -50,19 +66,32 @@ impl AggregateTypeLowerer {

fn walk_conditional_blocks(&mut self, blocks: &mut Vec<ConditionalBlock>) {
for b in blocks {
let condition = std::mem::take(b.condition.as_mut());
let mut processed_nodes = Box::new(self.map(condition));
// fixme: this breaks SWITCH statements?
if let Some(expressions) = try_from_mut!(processed_nodes, Vec<AstNode>) {
b.condition = Box::new(expressions.pop().expect("Should have at least one expression"));
let expressions = std::mem::take(expressions);
self.outer_scope_stmts.extend(expressions);
if self.ctx.is_switch_case {
b.condition.walk(self);
} else {
b.condition = processed_nodes;
let condition = std::mem::take(b.condition.as_mut());
let mut processed_nodes = Box::new(self.map(condition));
if let Some(expressions) = try_from_mut!(processed_nodes, Vec<AstNode>) {
b.condition = Box::new(expressions.pop().expect("Should have at least one expression"));
let expressions = std::mem::take(expressions);
self.outer_scope_stmts.extend(expressions);
} else {
b.condition = processed_nodes;
}
}
self.steal_and_walk_list(&mut b.body);
}
}

fn walk_with_context<T>(&mut self, t: &mut T, ctx: VisitorContext, f: impl Fn(&mut Self, &mut T))
where
T: WalkerMut,
{
let old = self.ctx.clone();
self.ctx = ctx;
f(self, t);
self.ctx = old;
}
}

impl AstVisitorMut for AggregateTypeLowerer {
Expand Down Expand Up @@ -228,11 +257,15 @@ impl AstVisitorMut for AggregateTypeLowerer {
}
AstControlStatement::Case(stmt) => {
stmt.selector.walk(self);
self.walk_conditional_blocks(&mut stmt.case_blocks);
self.walk_with_context(
&mut stmt.case_blocks,
VisitorContext::switch_case(),
Self::walk_conditional_blocks,
);
self.steal_and_walk_list(&mut stmt.else_block);
}
}

if !self.outer_scope_stmts.is_empty() {
let mut new_stmts = std::mem::take(&mut self.outer_scope_stmts);
let location = node.get_location();
Expand Down Expand Up @@ -280,11 +313,7 @@ mod tests {
"#,
);

let mut lowerer = AggregateTypeLowerer {
index: Some(index),
annotation: None,
..Default::default()
};
let mut lowerer = AggregateTypeLowerer { index: Some(index), annotation: None, ..Default::default() };
lowerer.visit_compilation_unit(&mut unit);
lowerer.index.replace(indexer::index(&unit));
assert_eq!(unit, original_unit);
Expand Down Expand Up @@ -756,7 +785,7 @@ mod tests {
id_provider.clone(),
);

let mut lowerer = AggregateTypeLowerer {
let mut lowerer = AggregateTypeLowerer {
index: Some(index),
annotation: None,
id_provider: id_provider.clone(),
Expand Down Expand Up @@ -933,4 +962,59 @@ mod tests {
assert_debug_snapshot!(index.find_pou_type("MID__STRING").unwrap());
assert_debug_snapshot!(units[0].0.implementations[1]);
}

#[test]
fn nested_complex_calls_in_if_condition() {
let id_provider = IdProvider::default();
let src = r#"
FUNCTION CLEAN : STRING
VAR_INPUT
CX : STRING;
END_VAR
VAR
pos: INT := 1;
END_VAR
IF FIND(CX, MID(CLEAN, 1, pos)) > 0 THEN
pos := pos + 1;
END_IF;
END_FUNCTION
FUNCTION FIND<T: ANY_STRING> : INT
VAR_INPUT
needle: T;
haystack: T;
END_VAR
END_FUNCTION
{external}
FUNCTION FIND__STRING : INT
VAR_INPUT
needle: STRING;
haystack: STRING;
END_VAR
END_FUNCTION
FUNCTION MID<T: ANY_STRING> : T
VAR_INPUT
str: T;
len: INT;
start: INT;
END_VAR
END_FUNCTION
{external}
FUNCTION MID__STRING : STRING
VAR_INPUT
str: STRING;
len: INT;
start: INT;
END_VAR
END_FUNCTION
"#;

let (unit, index, ..) = index_and_lower(src, id_provider.clone());
let (_, _, units) = annotate_and_lower_with_ids(unit, index, id_provider.clone());
let unit = &units[0].0;
assert_debug_snapshot!(unit.implementations[0]);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
---
source: src/lowering/calls.rs
expression: "unit.implementations[0]"
snapshot_kind: text
---
Implementation {
name: "CLEAN",
type_name: "CLEAN",
linkage: Internal,
pou_type: Function,
statements: [
ExpressionList {
expressions: [
Allocation {
name: "__MID0",
reference_type: "STRING",
},
CallStatement {
operator: ReferenceExpr {
kind: Member(
Identifier {
name: "MID",
},
),
base: None,
},
parameters: Some(
ExpressionList {
expressions: [
ReferenceExpr {
kind: Member(
Identifier {
name: "__MID0",
},
),
base: None,
},
ReferenceExpr {
kind: Member(
Identifier {
name: "CLEAN",
},
),
base: None,
},
LiteralInteger {
value: 1,
},
ReferenceExpr {
kind: Member(
Identifier {
name: "pos",
},
),
base: None,
},
],
},
),
},
IfStatement {
blocks: [
ConditionalBlock {
condition: BinaryExpression {
operator: Greater,
left: CallStatement {
operator: ReferenceExpr {
kind: Member(
Identifier {
name: "FIND",
},
),
base: None,
},
parameters: Some(
ExpressionList {
expressions: [
ReferenceExpr {
kind: Member(
Identifier {
name: "CX",
},
),
base: None,
},
ReferenceExpr {
kind: Member(
Identifier {
name: "__MID0",
},
),
base: None,
},
],
},
),
},
right: LiteralInteger {
value: 0,
},
},
body: [
Assignment {
left: ReferenceExpr {
kind: Member(
Identifier {
name: "pos",
},
),
base: None,
},
right: BinaryExpression {
operator: Plus,
left: ReferenceExpr {
kind: Member(
Identifier {
name: "pos",
},
),
base: None,
},
right: LiteralInteger {
value: 1,
},
},
},
],
},
],
else_block: [],
},
],
},
EmptyStatement,
],
location: SourceLocation {
span: Range(
TextLocation {
line: 8,
column: 16,
offset: 191,
}..TextLocation {
line: 10,
column: 23,
offset: 291,
},
),
},
name_location: SourceLocation {
span: Range(
TextLocation {
line: 1,
column: 21,
offset: 22,
}..TextLocation {
line: 1,
column: 26,
offset: 27,
},
),
},
overriding: false,
generic: false,
access: None,
}
Loading

0 comments on commit 89b7201

Please sign in to comment.