Skip to content

Commit

Permalink
Always dump and load with little endianness (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko authored Aug 4, 2023
1 parent 596fcce commit 521d18e
Showing 1 changed file with 29 additions and 3 deletions.
32 changes: 29 additions & 3 deletions lib/safetensors.ex
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,13 @@ defmodule Safetensors do
def dump(tensors) when is_map(tensors) do
{header_entries, {buffer, _offset}} =
Enum.map_reduce(tensors, {[], 0}, fn {tensor_name, tensor}, {buffer, offset} ->
binary = Nx.to_binary(tensor)
{_, elem_size} = Nx.type(tensor)

binary =
tensor
|> Nx.to_binary()
|> new_byte_order(elem_size, :little)

end_offset = offset + byte_size(binary)

header_entry = {
Expand Down Expand Up @@ -106,11 +112,16 @@ defmodule Safetensors do
"shape" => shape
} = tensor_info

binary = binary_slice(buffer, offset_start, offset_end - offset_start)
{_, elem_size} = type = dtype_to_type(dtype)

binary =
buffer
|> binary_slice(offset_start, offset_end - offset_start)
|> new_byte_order(elem_size, :little)

tensor =
binary
|> Nx.from_binary(dtype_to_type(dtype))
|> Nx.from_binary(type)
|> Nx.reshape(List.to_tuple(shape))

{tensor_name, tensor}
Expand All @@ -124,4 +135,19 @@ defmodule Safetensors do
defp dtype_to_type(dtype) do
@dtype_to_type[dtype] || raise "unrecognized dtype #{inspect(dtype)}"
end

defp new_byte_order(binary, size, endianness) do
if System.endianness() == endianness do
binary
else
data =
for <<data::size(size)-binary <- binary>> do
data
|> :binary.decode_unsigned()
|> :binary.encode_unsigned(endianness)
end

IO.iodata_to_binary(data)
end
end
end

0 comments on commit 521d18e

Please sign in to comment.