Skip to content

Commit

Permalink
Fix dpctl.lsplatform() to output in jupyter notebook (#800)
Browse files Browse the repository at this point in the history
* Add DPCTLPlatformMgr_GetInfo function

* Fix lsplatform function to use DPCTLPlatformMgr_GetInfo

* Added tests for DPCTLPlatformMgr_GetInfo

* Added tests for handlign of null PRef in DPCTLPlatformMgr_GetInfo

Co-authored-by: Oleksandr Pavlyk <[email protected]>
  • Loading branch information
oleksandr-pavlyk authored Mar 28, 2022
2 parents a20344b + 49e3419 commit a94e63b
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 4 deletions.
1 change: 1 addition & 0 deletions dpctl/_backend.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ cdef extern from "syclinterface/dpctl_sycl_platform_manager.h":
DPCTLPlatformVectorRef,
size_t index)
cdef void DPCTLPlatformMgr_PrintInfo(const DPCTLSyclPlatformRef, size_t)
cdef const char *DPCTLPlatformMgr_GetInfo(const DPCTLSyclPlatformRef, size_t)


cdef extern from "syclinterface/dpctl_sycl_platform_interface.h":
Expand Down
7 changes: 6 additions & 1 deletion dpctl/_sycl_platform.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ from ._backend cimport ( # noqa: E211
DPCTLPlatform_GetPlatforms,
DPCTLPlatform_GetVendor,
DPCTLPlatform_GetVersion,
DPCTLPlatformMgr_GetInfo,
DPCTLPlatformMgr_PrintInfo,
DPCTLPlatformVector_Delete,
DPCTLPlatformVector_GetAt,
Expand Down Expand Up @@ -323,6 +324,7 @@ def lsplatform(verbosity=0):
cdef DPCTLPlatformVectorRef PVRef = NULL
cdef size_t v = 0
cdef size_t size = 0
cdef const char * info_str = NULL
cdef DPCTLSyclPlatformRef PRef = NULL

if not isinstance(verbosity, int):
Expand All @@ -347,8 +349,11 @@ def lsplatform(verbosity=0):
if v != 0:
print("Platform ", i, "::")
PRef = DPCTLPlatformVector_GetAt(PVRef, i)
DPCTLPlatformMgr_PrintInfo(PRef, v)
info_str = DPCTLPlatformMgr_GetInfo(PRef,v)
py_info = <bytes> info_str
DPCTLCString_Delete(info_str)
DPCTLPlatform_Delete(PRef)
print(py_info.decode("utf-8"),end='')
DPCTLPlatformVector_Delete(PVRef)


Expand Down
24 changes: 24 additions & 0 deletions libsyclinterface/include/dpctl_sycl_platform_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,30 @@ DPCTL_API
void DPCTLPlatformMgr_PrintInfo(__dpctl_keep const DPCTLSyclPlatformRef PRef,
size_t verbosity);

/*!
* @brief Returns a set of platform info attributes as a string.
*
* The helper function is used to get metadata about a given platform. The
* amount of information received is controlled by the verbosity level.
*
* Verbosity level 0: Returns only the name of the platform.
* Verbosity level 1: Returns the name, version, vendor, backend, number of
* devices in the platform.
* Verbosity level 2: Returns everything in level 1 and also returns the name,
* version, and filter string for each device in the
* platform.
*
* @param PRef A #DPCTLSyclPlatformRef opaque pointer.
* @param verbosity Verbosilty level to control how much information is
* printed out.
* @return A formatted C string capturing the information about the
* sycl::platform argument.
*/
DPCTL_API
__dpctl_give const char *
DPCTLPlatformMgr_GetInfo(__dpctl_keep const DPCTLSyclPlatformRef PRef,
size_t verbosity);

/*! @} */

DPCTL_C_EXTERN_C_END
24 changes: 21 additions & 3 deletions libsyclinterface/source/dpctl_sycl_platform_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "dpctl_sycl_platform_manager.h"
#include "Support/CBindingWrapping.h"
#include "dpctl_error_handlers.h"
#include "dpctl_string_utils.hpp"
#include "dpctl_sycl_platform_interface.h"
#include "dpctl_utils_helper.h"
#include <CL/sycl.hpp>
Expand All @@ -41,7 +42,7 @@ namespace
{
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(platform, DPCTLSyclPlatformRef);

void platform_print_info_impl(const platform &p, size_t verbosity)
std::string platform_print_info_impl(const platform &p, size_t verbosity)
{
std::stringstream ss;

Expand Down Expand Up @@ -96,7 +97,7 @@ void platform_print_info_impl(const platform &p, size_t verbosity)
}
}

std::cout << ss.str();
return ss.str();
}

} // namespace
Expand All @@ -111,10 +112,27 @@ void DPCTLPlatformMgr_PrintInfo(__dpctl_keep const DPCTLSyclPlatformRef PRef,
{
auto p = unwrap(PRef);
if (p) {
platform_print_info_impl(*p, verbosity);
std::cout << platform_print_info_impl(*p, verbosity);
}
else {
error_handler("Platform reference is NULL.", __FILE__, __func__,
__LINE__);
}
}

__dpctl_give const char *
DPCTLPlatformMgr_GetInfo(__dpctl_keep const DPCTLSyclPlatformRef PRef,
size_t verbosity)
{
const char *cstr_info = nullptr;
auto p = unwrap(PRef);
if (p) {
auto infostr = platform_print_info_impl(*p, verbosity);
cstr_info = dpctl::helper::cstring_from_string(infostr);
}
else {
error_handler("Platform reference is NULL.", __FILE__, __func__,
__LINE__);
}
return cstr_info;
}
48 changes: 48 additions & 0 deletions libsyclinterface/tests/test_sycl_platform_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,14 @@ TEST_P(TestDPCTLSyclPlatformInterface, ChkCopyNullArg)
EXPECT_NO_FATAL_FAILURE(DPCTLPlatform_Delete(Copied_PRef));
}

TEST_P(TestDPCTLSyclPlatformInterface, ChkGetInfo)
{
const char *info_str = nullptr;
EXPECT_NO_FATAL_FAILURE(info_str = DPCTLPlatformMgr_GetInfo(PRef, 0));
ASSERT_TRUE(info_str != nullptr);
EXPECT_NO_FATAL_FAILURE(DPCTLCString_Delete(info_str));
}

TEST_P(TestDPCTLSyclPlatformInterface, ChkPrintInfo)
{
EXPECT_NO_FATAL_FAILURE(DPCTLPlatformMgr_PrintInfo(PRef, 0));
Expand Down Expand Up @@ -255,6 +263,46 @@ TEST_F(TestDPCTLSyclDefaultPlatform, ChkGetBackend)
check_platform_backend(PRef);
}

TEST_F(TestDPCTLSyclDefaultPlatform, ChkGetInfo0)
{
const char *info_str = nullptr;
EXPECT_NO_FATAL_FAILURE(info_str = DPCTLPlatformMgr_GetInfo(PRef, 0));
ASSERT_TRUE(info_str != nullptr);
EXPECT_NO_FATAL_FAILURE(DPCTLCString_Delete(info_str));
}

TEST_F(TestDPCTLSyclDefaultPlatform, ChkGetInfo1)
{
const char *info_str = nullptr;
EXPECT_NO_FATAL_FAILURE(info_str = DPCTLPlatformMgr_GetInfo(PRef, 1));
ASSERT_TRUE(info_str != nullptr);
EXPECT_NO_FATAL_FAILURE(DPCTLCString_Delete(info_str));
}

TEST_F(TestDPCTLSyclDefaultPlatform, ChkGetInfo2)
{
const char *info_str = nullptr;
EXPECT_NO_FATAL_FAILURE(info_str = DPCTLPlatformMgr_GetInfo(PRef, 2));
ASSERT_TRUE(info_str != nullptr);
EXPECT_NO_FATAL_FAILURE(DPCTLCString_Delete(info_str));
}

TEST_F(TestDPCTLSyclDefaultPlatform, ChkGetInfo3)
{
const char *info_str = nullptr;
EXPECT_NO_FATAL_FAILURE(info_str = DPCTLPlatformMgr_GetInfo(PRef, 3));
ASSERT_TRUE(info_str != nullptr);
EXPECT_NO_FATAL_FAILURE(DPCTLCString_Delete(info_str));
}

TEST_F(TestDPCTLSyclDefaultPlatform, ChkGetInfoNull)
{
const char *info_str = nullptr;
DPCTLSyclPlatformRef NullPRef = nullptr;
EXPECT_NO_FATAL_FAILURE(info_str = DPCTLPlatformMgr_GetInfo(NullPRef, 0));
ASSERT_TRUE(info_str == nullptr);
}

TEST_F(TestDPCTLSyclDefaultPlatform, ChkPrintInfo0)
{
EXPECT_NO_FATAL_FAILURE(DPCTLPlatformMgr_PrintInfo(PRef, 0));
Expand Down

0 comments on commit a94e63b

Please sign in to comment.