dpt.put raises a TypeError when vals is a usm_ndarray with a different dtype than x #1382

Closed
vlad-perevezentsev opened this issue Aug 31, 2023 · 1 comment

@vlad-perevezentsev
Collaborator

The documentation for dpt.put does not cover the case where vals is a usm_ndarray whose data type differs from that of x.
In this case the call raises a TypeError from the dpctl backend.

I think we should cast vals to the data type of x when the types do not match, as numpy does.

The example below demonstrates this case:

import dpctl.tensor as dpt

x = dpt.arange(10)
ind = dpt.asarray([0])
vals = dpt.asarray([10], dtype='f4')

dpt.put(x, ind, vals)

hev, _ = ti._put(x, (indices,), vals, axis, mode, sycl_queue=exec_q)
    214 hev.wait()

TypeError: Array data types are not the same.

# numpy casts vals to the dtype of x in the same situation

import numpy

x_np = dpt.asnumpy(x)

numpy.put(x_np, dpt.asnumpy(ind), dpt.asnumpy(vals))
x_np
# array([10,  1,  2,  3,  4,  5,  6,  7,  8,  9])
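The casting behaviour requested above can be sketched with NumPy alone. This is a minimal illustration, not the fix that was applied in dpctl: it only checks that the implicit cast performed by numpy.put gives the same result as casting vals to x.dtype by hand. The names implicit and explicit are introduced here for illustration. On the dpctl side, an analogous explicit cast with dpt.astype(vals, x.dtype) before calling dpt.put is assumed to avoid the dtype mismatch, but that is not verified here.

```python
import numpy as np

x = np.arange(10)                     # integer array, as in the report
ind = np.asarray([0])
vals = np.asarray([10], dtype="f4")   # float32 values, dtype differs from x

# numpy.put casts vals to the dtype of the target array implicitly.
implicit = x.copy()
np.put(implicit, ind, vals)

# The same result with the cast written out explicitly.
explicit = x.copy()
np.put(explicit, ind, vals.astype(x.dtype))

# The target keeps its integer dtype and both approaches agree.
print(implicit.dtype == x.dtype, np.array_equal(implicit, explicit))
```

Writing the cast explicitly makes the conversion visible at the call site, which may be preferable when a float-to-integer conversion could silently truncate values.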

@ndgrigorian
Collaborator

Resolved by #1647
