From cef82491553a9b4105986dd46d148efda6da22ef Mon Sep 17 00:00:00 2001 From: "Alexander M. Turek" Date: Thu, 13 Jan 2022 00:06:16 +0100 Subject: [PATCH] Support enum cases as parameters --- lib/Doctrine/ORM/AbstractQuery.php | 5 ++ .../ORM/Query/ParameterTypeInferer.php | 14 ++++- phpcs.xml.dist | 6 ++ psalm.xml | 6 ++ .../Tests/Models/Enums/AccessLevel.php | 12 ++++ tests/Doctrine/Tests/Models/Enums/City.php | 12 ++++ .../Tests/Models/Enums/UserStatus.php | 11 ++++ .../Tests/ORM/Functional/QueryTest.php | 62 +++++++++++++++++++ .../ORM/Query/ParameterTypeInfererTest.php | 40 +++++++----- tests/Doctrine/Tests/ORM/Query/QueryTest.php | 26 ++++++++ 10 files changed, 178 insertions(+), 16 deletions(-) create mode 100644 tests/Doctrine/Tests/Models/Enums/AccessLevel.php create mode 100644 tests/Doctrine/Tests/Models/Enums/City.php create mode 100644 tests/Doctrine/Tests/Models/Enums/UserStatus.php diff --git a/lib/Doctrine/ORM/AbstractQuery.php b/lib/Doctrine/ORM/AbstractQuery.php index de37fe3bd75..990794a7e0a 100644 --- a/lib/Doctrine/ORM/AbstractQuery.php +++ b/lib/Doctrine/ORM/AbstractQuery.php @@ -4,6 +4,7 @@ namespace Doctrine\ORM; +use BackedEnum; use Countable; use Doctrine\Common\Cache\Psr6\CacheAdapter; use Doctrine\Common\Cache\Psr6\DoctrineProvider; @@ -424,6 +425,10 @@ public function processParameterValue($value) return $value->name; } + if ($value instanceof BackedEnum) { + return $value->value; + } + if (! is_object($value)) { return $value; } diff --git a/lib/Doctrine/ORM/Query/ParameterTypeInferer.php b/lib/Doctrine/ORM/Query/ParameterTypeInferer.php index 4fad8e1e6cb..3f4dee00969 100644 --- a/lib/Doctrine/ORM/Query/ParameterTypeInferer.php +++ b/lib/Doctrine/ORM/Query/ParameterTypeInferer.php @@ -4,6 +4,7 @@ namespace Doctrine\ORM\Query; +use BackedEnum; use DateInterval; use DateTimeImmutable; use DateTimeInterface; @@ -54,8 +55,19 @@ public static function inferType($value) return Types::DATEINTERVAL; } + if ($value instanceof BackedEnum) { + return is_int($value->value) + ? Types::INTEGER + : Types::STRING; + } + if (is_array($value)) { - return is_int(current($value)) + $firstValue = current($value); + if ($firstValue instanceof BackedEnum) { + $firstValue = $firstValue->value; + } + + return is_int($firstValue) ? Connection::PARAM_INT_ARRAY : Connection::PARAM_STR_ARRAY; } diff --git a/phpcs.xml.dist b/phpcs.xml.dist index d598ac25c4c..2a30cf7d39d 100644 --- a/phpcs.xml.dist +++ b/phpcs.xml.dist @@ -268,10 +268,16 @@ + tests/Doctrine/Tests/Models/Enums/AccessLevel.php + tests/Doctrine/Tests/Models/Enums/City.php tests/Doctrine/Tests/Models/Enums/Suit.php + tests/Doctrine/Tests/Models/Enums/UserStatus.php + tests/Doctrine/Tests/Models/Enums/AccessLevel.php + tests/Doctrine/Tests/Models/Enums/City.php tests/Doctrine/Tests/Models/Enums/Suit.php + tests/Doctrine/Tests/Models/Enums/UserStatus.php diff --git a/psalm.xml b/psalm.xml index b83e3f2a23d..23b2feee7b0 100644 --- a/psalm.xml +++ b/psalm.xml @@ -90,6 +90,12 @@ + + + + + + diff --git a/tests/Doctrine/Tests/Models/Enums/AccessLevel.php b/tests/Doctrine/Tests/Models/Enums/AccessLevel.php new file mode 100644 index 00000000000..e6de8a0a24a --- /dev/null +++ b/tests/Doctrine/Tests/Models/Enums/AccessLevel.php @@ -0,0 +1,12 @@ +getSingleResult(); } + /** + * @requires PHP 8.1 + */ + public function testUseStringEnumCaseAsParameter(): void + { + $user = new CmsUser(); + $user->name = 'John'; + $user->username = 'john'; + $user->status = 'inactive'; + $this->_em->persist($user); + + $user = new CmsUser(); + $user->name = 'Jane'; + $user->username = 'jane'; + $user->status = 'active'; + $this->_em->persist($user); + + unset($user); + + $this->_em->flush(); + $this->_em->clear(); + + $result = $this->_em->createQuery('SELECT u FROM ' . CmsUser::class . ' u WHERE u.status = :status') + ->setParameter('status', UserStatus::Active) + ->getResult(); + + self::assertCount(1, $result); + self::assertSame('jane', $result[0]->username); + } + + /** + * @requires PHP 8.1 + */ + public function testUseIntegerEnumCaseAsParameter(): void + { + $user = new CmsUser(); + $user->name = 'John'; + $user->username = 'john'; + $user->status = '1'; + $this->_em->persist($user); + + $user = new CmsUser(); + $user->name = 'Jane'; + $user->username = 'jane'; + $user->status = '2'; + $this->_em->persist($user); + + unset($user); + + $this->_em->flush(); + $this->_em->clear(); + + $result = $this->_em->createQuery('SELECT u FROM ' . CmsUser::class . ' u WHERE u.status = :status') + ->setParameter('status', AccessLevel::User) + ->getResult(); + + self::assertCount(1, $result); + self::assertSame('jane', $result[0]->username); + } + public function testSetParameters(): void { $parameters = new ArrayCollection(); diff --git a/tests/Doctrine/Tests/ORM/Query/ParameterTypeInfererTest.php b/tests/Doctrine/Tests/ORM/Query/ParameterTypeInfererTest.php index d3acdd827f7..291e46d1520 100644 --- a/tests/Doctrine/Tests/ORM/Query/ParameterTypeInfererTest.php +++ b/tests/Doctrine/Tests/ORM/Query/ParameterTypeInfererTest.php @@ -11,26 +11,36 @@ use Doctrine\DBAL\ParameterType; use Doctrine\DBAL\Types\Types; use Doctrine\ORM\Query\ParameterTypeInferer; +use Doctrine\Tests\Models\Enums\AccessLevel; +use Doctrine\Tests\Models\Enums\UserStatus; use Doctrine\Tests\OrmTestCase; +use Generator; + +use const PHP_VERSION_ID; class ParameterTypeInfererTest extends OrmTestCase { - /** @psalm-return list */ - public function providerParameterTypeInferer(): array + /** @psalm-return Generator */ + public function providerParameterTypeInferer(): Generator { - return [ - [1, Types::INTEGER], - ['bar', ParameterType::STRING], - ['1', ParameterType::STRING], - [new DateTime(), Types::DATETIME_MUTABLE], - [new DateTimeImmutable(), Types::DATETIME_IMMUTABLE], - [new DateInterval('P1D'), Types::DATEINTERVAL], - [[2], Connection::PARAM_INT_ARRAY], - [['foo'], Connection::PARAM_STR_ARRAY], - [['1','2'], Connection::PARAM_STR_ARRAY], - [[], Connection::PARAM_STR_ARRAY], - [true, Types::BOOLEAN], - ]; + yield 'integer' => [1, Types::INTEGER]; + yield 'string' => ['bar', ParameterType::STRING]; + yield 'numeric_string' => ['1', ParameterType::STRING]; + yield 'datetime_object' => [new DateTime(), Types::DATETIME_MUTABLE]; + yield 'datetime_immutable_object' => [new DateTimeImmutable(), Types::DATETIME_IMMUTABLE]; + yield 'date_interval_object' => [new DateInterval('P1D'), Types::DATEINTERVAL]; + yield 'array_of_int' => [[2], Connection::PARAM_INT_ARRAY]; + yield 'array_of_string' => [['foo'], Connection::PARAM_STR_ARRAY]; + yield 'array_of_numeric_string' => [['1', '2'], Connection::PARAM_STR_ARRAY]; + yield 'empty_array' => [[], Connection::PARAM_STR_ARRAY]; + yield 'boolean' => [true, Types::BOOLEAN]; + + if (PHP_VERSION_ID >= 80100) { + yield 'int_backed_enum' => [AccessLevel::Admin, Types::INTEGER]; + yield 'string_backed_enum' => [UserStatus::Active, Types::STRING]; + yield 'array_of_int_backed_enum' => [[AccessLevel::Admin], Connection::PARAM_INT_ARRAY]; + yield 'array_of_string_backed_enum' => [[UserStatus::Active], Connection::PARAM_STR_ARRAY]; + } } /** diff --git a/tests/Doctrine/Tests/ORM/Query/QueryTest.php b/tests/Doctrine/Tests/ORM/Query/QueryTest.php index 31762caa486..2bbecc69555 100644 --- a/tests/Doctrine/Tests/ORM/Query/QueryTest.php +++ b/tests/Doctrine/Tests/ORM/Query/QueryTest.php @@ -24,15 +24,21 @@ use Doctrine\Tests\Models\CMS\CmsAddress; use Doctrine\Tests\Models\CMS\CmsGroup; use Doctrine\Tests\Models\CMS\CmsUser; +use Doctrine\Tests\Models\Enums\AccessLevel; +use Doctrine\Tests\Models\Enums\City; +use Doctrine\Tests\Models\Enums\UserStatus; use Doctrine\Tests\Models\Generic\DateTimeModel; use Doctrine\Tests\OrmTestCase; use Generator; use Psr\Cache\CacheItemPoolInterface; use Symfony\Component\Cache\Adapter\ArrayAdapter; +use function array_map; use function assert; use function method_exists; +use const PHP_VERSION_ID; + class QueryTest extends OrmTestCase { /** @var EntityManagerMock */ @@ -236,6 +242,9 @@ public function testCollectionParameters(): void self::assertEquals($cities, $parameter->getValue()); } + /** + * @psalm-return Generator + */ public function provideProcessParameterValueIterable(): Generator { $baseArray = [ @@ -251,6 +260,10 @@ public function provideProcessParameterValueIterable(): Generator yield 'simple_array' => [$baseArray]; yield 'doctrine_collection' => [new ArrayCollection($baseArray)]; yield 'generator' => [$gen()]; + + if (PHP_VERSION_ID >= 80100) { + yield 'array_of_enum' => [array_map([City::class, 'from'], $baseArray)]; + } } /** @@ -322,6 +335,19 @@ public function testProcessParameterValueNull(): void self::assertNull($query->processParameterValue(null)); } + /** + * @requires PHP 8.1 + */ + public function testProcessParameterValueBackedEnum(): void + { + $query = $this->entityManager->createQuery('SELECT u FROM Doctrine\Tests\Models\CMS\CmsUser u WHERE u.status = :status'); + + self::assertSame('active', $query->processParameterValue(UserStatus::Active)); + self::assertSame(2, $query->processParameterValue(AccessLevel::User)); + self::assertSame(['active'], $query->processParameterValue([UserStatus::Active])); + self::assertSame([2], $query->processParameterValue([AccessLevel::User])); + } + public function testDefaultQueryHints(): void { $config = $this->entityManager->getConfiguration();