Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Arrow support for image_file and image-related bug-fixes #760

Merged
merged 4 commits into from
Jan 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ wgpu = { git = "https://github.com/gfx-rs/wgpu.git", ref = "a377ae2b7fe6c1c94127
wgpu-core = { git = "https://github.com/gfx-rs/wgpu.git", ref = "a377ae2b7fe6c1c9412751166f0917e617164e49" }
#wgpu = { path = "../wgpu/wgpu" }

# Upstream PR https://github.com/jorgecarleitao/arrow2/pull/1351
arrow2 = { git = "https://github.com/rerun-io/arrow2", rev = "f134e58bb554b069392f6bd495aa43aa06b58944" }
# Upstream PRs https://github.com/DataEngineeringLabs/arrow2-convert/pull/90
arrow2_convert = { git = "https://github.com/rerun-io/arrow2-convert", rev = "7e0a3a3881eb4577f95d1b7af76e7f943c4fa53d" }
# Upstream PR https://github.com/jorgecarleitao/arrow2/pull/1360
arrow2 = { git = "https://github.com/rerun-io/arrow2", rev = "c6ef5e3dde4a18c35c8a1d91d9a741835be0305f" }
# Upstream PR https://github.com/DataEngineeringLabs/arrow2-convert/pull/91
arrow2_convert = { git = "https://github.com/rerun-io/arrow2-convert", rev = "aa48be082d0039e8c1a67b581ce5f697b89f3765" }
#arrow2 = { path = "../arrow2" }
#arrow2_convert = { path = "../arrow2-convert/arrow2_convert" }
5 changes: 5 additions & 0 deletions crates/re_log_types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ ruzstd = { version = "0.3.0", optional = true } # works on wasm
criterion = "0.4"
mimalloc = "0.1"
serde_test = { version = "1" }
arrow2 = { workspace = true, features = [
"io_ipc",
"io_print",
"compute_concatenate",
] }
Comment on lines +96 to +100
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably not needed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's nice having the concatenation test in there given what a problem it caused for images.


[lib]
bench = false
Expand Down
61 changes: 61 additions & 0 deletions crates/re_log_types/src/field_types/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ impl ArrowDeserialize for TensorId {
/// DataType::List(Box::new(Field::new("item", DataType::Float64, false))),
/// false
/// ),
/// Field::new("JPEG", DataType::Binary, false),
/// ],
/// None,
/// UnionMode::Dense
Expand All @@ -159,6 +160,7 @@ pub enum TensorData {
//F16(Vec<arrow2::types::f16>),
F32(Vec<f32>),
F64(Vec<f64>),
JPEG(Vec<u8>),
}

/// Flattened `Tensor` data payload
Expand Down Expand Up @@ -367,6 +369,10 @@ impl From<&Tensor> for ClassicTensor {
crate::TensorDataType::F64,
TensorDataStore::Dense(Arc::from(bytemuck::cast_slice(data.as_slice()))),
),
TensorData::JPEG(data) => (
crate::TensorDataType::U8,
TensorDataStore::Jpeg(Arc::from(data.as_slice())),
),
};

ClassicTensor::new(
Expand Down Expand Up @@ -511,3 +517,58 @@ fn test_arrow() {
let tensors_out: Vec<Tensor> = TryIntoCollection::try_into_collection(array).unwrap();
assert_eq!(tensors_in, tensors_out);
}

#[test]
fn test_concat_and_slice() {
    use crate::msg_bundle::wrap_in_listarray;
    use arrow2::array::ListArray;
    use arrow2::compute::concatenate::concatenate;
    use arrow2_convert::{deserialize::TryIntoCollection, serialize::TryIntoArrow};

    // Build a one-element batch holding a single JPEG tensor with the given payload.
    let make_batch = |payload: Vec<u8>| {
        vec![Tensor {
            tensor_id: TensorId::random(),
            shape: vec![TensorDimension {
                size: 4,
                name: None,
            }],
            data: TensorData::JPEG(payload),
            meaning: TensorDataMeaning::Unknown,
        }]
    };

    let batch_a = make_batch(vec![1, 2, 3, 4]);
    let batch_b = make_batch(vec![5, 6, 7, 8]);

    // Serialize each batch to arrow and wrap it as a one-row ListArray.
    let arrow_a: Box<dyn arrow2::array::Array> = batch_a.iter().try_into_arrow().unwrap();
    let wrapped_a = wrap_in_listarray(arrow_a).boxed();
    let arrow_b: Box<dyn arrow2::array::Array> = batch_b.iter().try_into_arrow().unwrap();
    let wrapped_b = wrap_in_listarray(arrow_b).boxed();

    // Round-trip the first list entry before any concatenation takes place.
    let first_entry = wrapped_a
        .as_any()
        .downcast_ref::<ListArray<i32>>()
        .unwrap()
        .value(0);

    let round_tripped: Vec<Tensor> = TryIntoCollection::try_into_collection(first_entry).unwrap();

    assert_eq!(batch_a, round_tripped);

    // Concatenate both lists, slice out the second entry, and round-trip that too.
    let joined = concatenate(&[wrapped_a.as_ref(), wrapped_b.as_ref()]).unwrap();

    let second_entry = joined
        .as_any()
        .downcast_ref::<ListArray<i32>>()
        .unwrap()
        .value(1);

    let round_tripped: Vec<Tensor> = TryIntoCollection::try_into_collection(second_entry).unwrap();

    assert_eq!(batch_b[0], round_tripped[0]);
}
9 changes: 2 additions & 7 deletions rerun_py/rerun/log/file.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import logging
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Optional

import numpy as np
import numpy.typing as npt
from rerun.log import EXP_ARROW

from rerun import bindings

Expand Down Expand Up @@ -84,8 +82,5 @@ def log_image_file(
"""
img_format = getattr(img_format, "value", None)

if EXP_ARROW.classic_log_gate():
bindings.log_image_file(obj_path, img_path, img_format, timeless)

if EXP_ARROW.arrow_log_gate():
logging.warning("log_image_file() not yet implemented for Arrow.")
# Image file arrow handling happens inside the python bridge
bindings.log_image_file(obj_path, img_path, img_format, timeless)
4 changes: 2 additions & 2 deletions rerun_py/rerun/log/lines.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ def log_line_segments(
radii = _normalize_radii([stroke_width / 2])
comps[1]["rerun.radius"] = RadiusArray.from_numpy(radii)

bindings.log_arrow_msg(f"arrow/{obj_path}", components=comps[0], timeless=timeless)
bindings.log_arrow_msg(obj_path, components=comps[0], timeless=timeless)

if comps[1]:
comps[1]["rerun.instance"] = InstanceArray.splat()
bindings.log_arrow_msg(f"arrow/{obj_path}", components=comps[1], timeless=timeless)
bindings.log_arrow_msg(obj_path, components=comps[1], timeless=timeless)
74 changes: 49 additions & 25 deletions rerun_py/src/python_bridge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1746,8 +1746,6 @@ fn log_mesh_file(
let mut session = global_session();
let obj_path = session.classic_prefix_obj_path(obj_path);

session.register_type(obj_path.obj_type_path(), ObjectType::Mesh3D);

let time_point = time(timeless);

let mesh3d = Mesh3D::Encoded(EncodedMesh3D {
Expand All @@ -1764,9 +1762,7 @@ fn log_mesh_file(
//
// TODO(jleibs) replace with python-native implementation
if session.arrow_log_gate() {
let mut arrow_path = "arrow/".to_owned();
arrow_path.push_str(obj_path_str);
let arrow_path = parse_obj_path(arrow_path.as_str())?;
let arrow_path = session.arrow_prefix_obj_path(obj_path.clone());

let bundle = MsgBundle::new(
MsgId::random(),
Expand All @@ -1781,6 +1777,8 @@ fn log_mesh_file(
}

if session.classic_log_gate() {
let obj_path = session.arrow_prefix_obj_path(obj_path);
session.register_type(obj_path.obj_type_path(), ObjectType::Mesh3D);
session.send_data(
&time_point,
(&obj_path, "mesh"),
Expand Down Expand Up @@ -1813,7 +1811,7 @@ fn log_image_file(
};

use image::ImageDecoder as _;
let ((w, h), data) = match img_format {
let (w, h) = match img_format {
image::ImageFormat::Jpeg => {
use image::codecs::jpeg::JpegDecoder;
let jpeg = JpegDecoder::new(Cursor::new(&img_bytes))
Expand All @@ -1828,7 +1826,7 @@ fn log_image_file(
)));
}

(jpeg.dimensions(), TensorDataStore::Jpeg(img_bytes.into()))
jpeg.dimensions()
}
_ => {
return Err(PyTypeError::new_err(format!(
Expand All @@ -1839,27 +1837,53 @@ fn log_image_file(
};

let mut session = global_session();
let obj_path = session.classic_prefix_obj_path(obj_path);

session.register_type(obj_path.obj_type_path(), ObjectType::Image);

let time_point = time(timeless);

session.send_data(
&time_point,
(&obj_path, "tensor"),
LoggedData::Single(Data::Tensor(re_log_types::ClassicTensor::new(
TensorId::random(),
vec![
TensorDimension::height(h as _),
TensorDimension::width(w as _),
TensorDimension::depth(3),
],
TensorDataType::U8,
re_log_types::field_types::TensorDataMeaning::Unknown,
data,
))),
);
if session.arrow_log_gate() {
let arrow_path = session.arrow_prefix_obj_path(obj_path.clone());
let bundle = MsgBundle::new(
MsgId::random(),
arrow_path,
time_point.clone(),
vec![vec![re_log_types::field_types::Tensor {
tensor_id: TensorId::random(),
shape: vec![
TensorDimension::height(h as _),
TensorDimension::width(w as _),
TensorDimension::depth(3),
],
data: re_log_types::field_types::TensorData::JPEG(img_bytes.clone()),
meaning: re_log_types::field_types::TensorDataMeaning::Unknown,
}]
.try_into()
.unwrap()],
);

let msg = bundle.try_into().unwrap();

session.send(LogMsg::ArrowMsg(msg));
}

if session.classic_log_gate() {
let obj_path = session.classic_prefix_obj_path(obj_path);
session.register_type(obj_path.obj_type_path(), ObjectType::Image);
session.send_data(
&time_point,
(&obj_path, "tensor"),
LoggedData::Single(Data::Tensor(re_log_types::ClassicTensor::new(
TensorId::random(),
vec![
TensorDimension::height(h as _),
TensorDimension::width(w as _),
TensorDimension::depth(3),
],
TensorDataType::U8,
re_log_types::field_types::TensorDataMeaning::Unknown,
TensorDataStore::Jpeg(img_bytes.into()),
))),
);
}

Ok(())
}
Expand Down