From 6339dffef7c3168ebf705b7737e0022d0038d78f Mon Sep 17 00:00:00 2001 From: Jan Nedbal Date: Fri, 12 Jul 2024 13:02:43 +0200 Subject: [PATCH] Fix detection of aggregate functions inside custom functions --- ...eryAggregateFunctionDetectorTreeWalker.php | 295 ++---------------- .../Doctrine/Query/QueryResultTypeWalker.php | 4 + ...eryResultTypeWalkerFetchTypeMatrixTest.php | 33 ++ .../TypedExpressionIntegerWrapFunction.php | 38 +++ 4 files changed, 97 insertions(+), 273 deletions(-) create mode 100644 tests/Platform/TypedExpressionIntegerWrapFunction.php diff --git a/src/Type/Doctrine/Query/QueryAggregateFunctionDetectorTreeWalker.php b/src/Type/Doctrine/Query/QueryAggregateFunctionDetectorTreeWalker.php index 11af2086..c82527a3 100644 --- a/src/Type/Doctrine/Query/QueryAggregateFunctionDetectorTreeWalker.php +++ b/src/Type/Doctrine/Query/QueryAggregateFunctionDetectorTreeWalker.php @@ -4,7 +4,7 @@ use Doctrine\ORM\Query; use Doctrine\ORM\Query\AST; -use function is_string; +use function is_array; class QueryAggregateFunctionDetectorTreeWalker extends Query\TreeWalkerAdapter { @@ -13,294 +13,38 @@ class QueryAggregateFunctionDetectorTreeWalker extends Query\TreeWalkerAdapter public function walkSelectStatement(AST\SelectStatement $selectStatement): void { - $this->doWalkSelectClause($selectStatement->selectClause); + $this->walkNode($selectStatement->selectClause); } /** - * @param AST\SelectClause $selectClause + * @param mixed $node */ - public function doWalkSelectClause($selectClause): void + public function walkNode($node): void { - foreach ($selectClause->selectExpressions as $selectExpression) { - $this->doWalkSelectExpression($selectExpression); - } - } - - /** - * @param AST\SelectExpression $selectExpression - */ - public function doWalkSelectExpression($selectExpression): void - { - $this->doWalkNode($selectExpression->expression); - } - - /** - * @param mixed $expr - */ - private function doWalkNode($expr): void - { - if ($expr instanceof AST\AggregateExpression) { - $this->markAggregateFunctionFound(); - - } elseif ($expr instanceof AST\Functions\FunctionNode) { - if ($this->isAggregateFunction($expr)) { - $this->markAggregateFunctionFound(); - } - - } elseif ($expr instanceof AST\SimpleArithmeticExpression) { - foreach ($expr->arithmeticTerms as $term) { - $this->doWalkArithmeticTerm($term); - } - - } elseif ($expr instanceof AST\ArithmeticTerm) { - $this->doWalkArithmeticTerm($expr); - - } elseif ($expr instanceof AST\ArithmeticFactor) { - $this->doWalkArithmeticFactor($expr); - - } elseif ($expr instanceof AST\ParenthesisExpression) { - $this->doWalkArithmeticPrimary($expr->expression); - - } elseif ($expr instanceof AST\NullIfExpression) { - $this->doWalkNullIfExpression($expr); - - } elseif ($expr instanceof AST\CoalesceExpression) { - $this->doWalkCoalesceExpression($expr); - - } elseif ($expr instanceof AST\GeneralCaseExpression) { - $this->doWalkGeneralCaseExpression($expr); - - } elseif ($expr instanceof AST\SimpleCaseExpression) { - $this->doWalkSimpleCaseExpression($expr); - - } elseif ($expr instanceof AST\ArithmeticExpression) { - $this->doWalkArithmeticExpression($expr); - - } elseif ($expr instanceof AST\ComparisonExpression) { - $this->doWalkComparisonExpression($expr); - - } elseif ($expr instanceof AST\BetweenExpression) { - $this->doWalkBetweenExpression($expr); - } - } - - public function doWalkCoalesceExpression(AST\CoalesceExpression $coalesceExpression): void - { - foreach ($coalesceExpression->scalarExpressions as $scalarExpression) { - $this->doWalkSimpleArithmeticExpression($scalarExpression); - } - } - - public function doWalkNullIfExpression(AST\NullIfExpression $nullIfExpression): void - { - if (!is_string($nullIfExpression->firstExpression)) { - $this->doWalkSimpleArithmeticExpression($nullIfExpression->firstExpression); - } - - if (is_string($nullIfExpression->secondExpression)) { + if (!$node instanceof AST\Node) { return; } - $this->doWalkSimpleArithmeticExpression($nullIfExpression->secondExpression); - } - - public function doWalkGeneralCaseExpression(AST\GeneralCaseExpression $generalCaseExpression): void - { - foreach ($generalCaseExpression->whenClauses as $whenClause) { - $this->doWalkConditionalExpression($whenClause->caseConditionExpression); - $this->doWalkSimpleArithmeticExpression($whenClause->thenScalarExpression); - } - - $this->doWalkSimpleArithmeticExpression($generalCaseExpression->elseScalarExpression); - } - - public function doWalkSimpleCaseExpression(AST\SimpleCaseExpression $simpleCaseExpression): void - { - foreach ($simpleCaseExpression->simpleWhenClauses as $simpleWhenClause) { - $this->doWalkSimpleArithmeticExpression($simpleWhenClause->caseScalarExpression); - $this->doWalkSimpleArithmeticExpression($simpleWhenClause->thenScalarExpression); - } - - $this->doWalkSimpleArithmeticExpression($simpleCaseExpression->elseScalarExpression); - } - - /** - * @param AST\ConditionalExpression|AST\Phase2OptimizableConditional $condExpr - */ - public function doWalkConditionalExpression($condExpr): void - { - if (!$condExpr instanceof AST\ConditionalExpression) { - $this->doWalkConditionalTerm($condExpr); // @phpstan-ignore-line PHPStan do not read @psalm-inheritors of Phase2OptimizableConditional - return; - } - - foreach ($condExpr->conditionalTerms as $conditionalTerm) { - $this->doWalkConditionalTerm($conditionalTerm); - } - } - - /** - * @param AST\ConditionalTerm|AST\ConditionalPrimary|AST\ConditionalFactor $condTerm - */ - public function doWalkConditionalTerm($condTerm): void - { - if (!$condTerm instanceof AST\ConditionalTerm) { - $this->doWalkConditionalFactor($condTerm); + if ($this->isAggregateFunction($node)) { + $this->markAggregateFunctionFound(); return; } - foreach ($condTerm->conditionalFactors as $conditionalFactor) { - $this->doWalkConditionalFactor($conditionalFactor); - } - } + foreach ((array) $node as $property) { + if ($property instanceof AST\Node) { + $this->walkNode($property); + } - /** - * @param AST\ConditionalFactor|AST\ConditionalPrimary $factor - */ - public function doWalkConditionalFactor($factor): void - { - if (!$factor instanceof AST\ConditionalFactor) { - $this->doWalkConditionalPrimary($factor); - } else { - $this->doWalkConditionalPrimary($factor->conditionalPrimary); - } - } + if (is_array($property)) { + foreach ($property as $propertyValue) { + $this->walkNode($propertyValue); + } + } - /** - * @param AST\ConditionalPrimary $primary - */ - public function doWalkConditionalPrimary($primary): void - { - if ($primary->isSimpleConditionalExpression()) { - if ($primary->simpleConditionalExpression instanceof AST\ComparisonExpression) { - $this->doWalkComparisonExpression($primary->simpleConditionalExpression); + if ($this->wasAggregateFunctionFound()) { return; } - $this->doWalkNode($primary->simpleConditionalExpression); } - - if (!$primary->isConditionalExpression()) { - return; - } - - if ($primary->conditionalExpression === null) { - return; - } - - $this->doWalkConditionalExpression($primary->conditionalExpression); - } - - /** - * @param AST\BetweenExpression $betweenExpr - */ - public function doWalkBetweenExpression($betweenExpr): void - { - $this->doWalkArithmeticExpression($betweenExpr->expression); - $this->doWalkArithmeticExpression($betweenExpr->leftBetweenExpression); - $this->doWalkArithmeticExpression($betweenExpr->rightBetweenExpression); - } - - /** - * @param AST\ComparisonExpression $compExpr - */ - public function doWalkComparisonExpression($compExpr): void - { - $leftExpr = $compExpr->leftExpression; - $rightExpr = $compExpr->rightExpression; - - if ($leftExpr instanceof AST\Node) { - $this->doWalkNode($leftExpr); - } - - if (!($rightExpr instanceof AST\Node)) { - return; - } - - $this->doWalkNode($rightExpr); - } - - /** - * @param AST\ArithmeticExpression $arithmeticExpr - */ - public function doWalkArithmeticExpression($arithmeticExpr): void - { - if (!$arithmeticExpr->isSimpleArithmeticExpression()) { - return; - } - - if ($arithmeticExpr->simpleArithmeticExpression === null) { - return; - } - - $this->doWalkSimpleArithmeticExpression($arithmeticExpr->simpleArithmeticExpression); - } - - /** - * @param AST\Node|string $simpleArithmeticExpr - */ - public function doWalkSimpleArithmeticExpression($simpleArithmeticExpr): void - { - if (!$simpleArithmeticExpr instanceof AST\SimpleArithmeticExpression) { - $this->doWalkArithmeticTerm($simpleArithmeticExpr); - return; - } - - foreach ($simpleArithmeticExpr->arithmeticTerms as $term) { - $this->doWalkArithmeticTerm($term); - } - } - - /** - * @param AST\Node|string $term - */ - public function doWalkArithmeticTerm($term): void - { - if (is_string($term)) { - return; - } - - if (!$term instanceof AST\ArithmeticTerm) { - $this->doWalkArithmeticFactor($term); - return; - } - - foreach ($term->arithmeticFactors as $factor) { - $this->doWalkArithmeticFactor($factor); - } - } - - /** - * @param AST\Node|string $factor - */ - public function doWalkArithmeticFactor($factor): void - { - if (is_string($factor)) { - return; - } - - if (!$factor instanceof AST\ArithmeticFactor) { - $this->doWalkArithmeticPrimary($factor); - return; - } - - $this->doWalkArithmeticPrimary($factor->arithmeticPrimary); - } - - /** - * @param AST\Node|string $primary - */ - public function doWalkArithmeticPrimary($primary): void - { - if ($primary instanceof AST\SimpleArithmeticExpression) { - $this->doWalkSimpleArithmeticExpression($primary); - return; - } - - if (!($primary instanceof AST\Node)) { - return; - } - - $this->doWalkNode($primary); } private function isAggregateFunction(AST\Node $node): bool @@ -318,4 +62,9 @@ private function markAggregateFunctionFound(): void $this->_getQuery()->setHint(self::HINT_HAS_AGGREGATE_FUNCTION, true); } + private function wasAggregateFunctionFound(): bool + { + return $this->_getQuery()->hasHint(self::HINT_HAS_AGGREGATE_FUNCTION); + } + } diff --git a/src/Type/Doctrine/Query/QueryResultTypeWalker.php b/src/Type/Doctrine/Query/QueryResultTypeWalker.php index 5157a631..2ce3e1ce 100644 --- a/src/Type/Doctrine/Query/QueryResultTypeWalker.php +++ b/src/Type/Doctrine/Query/QueryResultTypeWalker.php @@ -1226,6 +1226,10 @@ public function walkSelectExpression($selectExpression): string $this->resolveDoctrineType($dbalTypeName, null, TypeCombinator::containsNull($type)) ); + if ($this->hasAggregateWithoutGroupBy() && !$expr instanceof AST\Functions\CountFunction) { + $type = TypeCombinator::addNull($type); + } + } else { // Expressions default to Doctrine's StringType, whose // convertToPHPValue() is a no-op. So the actual type depends on diff --git a/tests/Platform/QueryResultTypeWalkerFetchTypeMatrixTest.php b/tests/Platform/QueryResultTypeWalkerFetchTypeMatrixTest.php index dcf9221f..4df903c0 100644 --- a/tests/Platform/QueryResultTypeWalkerFetchTypeMatrixTest.php +++ b/tests/Platform/QueryResultTypeWalkerFetchTypeMatrixTest.php @@ -3961,6 +3961,38 @@ public static function provideCases(): iterable 'stringify' => self::STRINGIFY_DEFAULT, ]; + yield 'INT_WRAP(MIN(t.col_float)) + no data' => [ + 'data' => self::dataNone(), + 'select' => 'SELECT INT_WRAP(MIN(t.col_float)) FROM %s t', + 'mysql' => self::intOrNull(), + 'sqlite' => self::intOrNull(), + 'pdo_pgsql' => self::intOrNull(), + 'pgsql' => self::intOrNull(), + 'mssql' => self::intOrNull(), + 'mysqlResult' => null, + 'sqliteResult' => null, + 'pdoPgsqlResult' => null, + 'pgsqlResult' => null, + 'mssqlResult' => null, + 'stringify' => self::STRINGIFY_NONE, + ]; + + yield 'INT_WRAP(MIN(t.col_float))' => [ + 'data' => self::dataDefault(), + 'select' => 'SELECT INT_WRAP(MIN(t.col_float)) FROM %s t', + 'mysql' => self::intOrNull(), + 'sqlite' => self::intOrNull(), + 'pdo_pgsql' => self::intOrNull(), + 'pgsql' => self::intOrNull(), + 'mssql' => self::intOrNull(), + 'mysqlResult' => 0, + 'sqliteResult' => 0, + 'pdoPgsqlResult' => 0, + 'pgsqlResult' => 0, + 'mssqlResult' => 0, + 'stringify' => self::STRINGIFY_NONE, + ]; + yield 'COALESCE(t.col_datetime, t.col_datetime)' => [ 'data' => self::dataDefault(), 'select' => 'SELECT COALESCE(t.col_datetime, t.col_datetime) FROM %s t', @@ -5018,6 +5050,7 @@ private function createOrmConfig(): Configuration $config->addCustomStringFunction('INT_PI', TypedExpressionIntegerPiFunction::class); $config->addCustomStringFunction('BOOL_PI', TypedExpressionBooleanPiFunction::class); $config->addCustomStringFunction('STRING_PI', TypedExpressionStringPiFunction::class); + $config->addCustomStringFunction('INT_WRAP', TypedExpressionIntegerWrapFunction::class); return $config; } diff --git a/tests/Platform/TypedExpressionIntegerWrapFunction.php b/tests/Platform/TypedExpressionIntegerWrapFunction.php new file mode 100644 index 00000000..e46d39bc --- /dev/null +++ b/tests/Platform/TypedExpressionIntegerWrapFunction.php @@ -0,0 +1,38 @@ +walkArithmeticPrimary($this->expr) . ')'; + } + + public function parse(Parser $parser): void + { + $parser->match(TokenType::T_IDENTIFIER); + $parser->match(TokenType::T_OPEN_PARENTHESIS); + $this->expr = $parser->ArithmeticPrimary(); + $parser->match(TokenType::T_CLOSE_PARENTHESIS); + } + + public function getReturnType(): Type + { + return Type::getType(Types::INTEGER); + } + +}