Skip to content

Commit

Permalink
Unify MeanOp and fuse RepFixedpointMeanOp with RingFixedpointMeanOp (#941)
Browse files Browse the repository at this point in the history

* Unify MeanOp

* Fuse RepFixedpointMean with RingFixedpointMean

Co-authored-by: Lex Vorona <[email protected]>
  • Loading branch information
yanndupis and Lex Vorona authored Mar 15, 2022
1 parent cbb20b0 commit d2d56d2
Show file tree
Hide file tree
Showing 16 changed files with 66 additions and 157 deletions.
2 changes: 1 addition & 1 deletion moose/src/compilation/deprecated_logical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ fn lower_op(op: &Operation) -> Operation {
},
(Placement::Host(_), Operator::Mean(ref i)) => Operation {
name: op.name.clone(),
kind: HostMeanOp {
kind: MeanOp {
sig: Signature::unary(lower_ty(i.sig.arg(0).unwrap()), lower_ty(i.sig.ret())),
axis: i.axis,
}
Expand Down
14 changes: 6 additions & 8 deletions moose/src/compilation/networking.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ mod tests {
y = Constant{value=HostFloat32Tensor([[1.0, 2.0], [3.0, 4.0]])}: () -> HostFloat32Tensor @Host(alice)
mul = Mul: (HostFloat32Tensor, HostFloat32Tensor) -> HostFloat32Tensor (x, y) @Host(alice)
dot = Dot: (HostFloat32Tensor, HostFloat32Tensor) -> HostFloat32Tensor (x, y) @Host(alice)
mean = HostMean{}: (HostFloat32Tensor) -> HostFloat32Tensor (dot) @Host(alice)"#;
mean = Mean{}: (HostFloat32Tensor) -> HostFloat32Tensor (dot) @Host(alice)"#;

let comp = NetworkingPass::pass(&source.try_into()?)?
.unwrap()
Expand All @@ -141,9 +141,8 @@ mod tests {
assert!(comp.contains(
"dot = Dot: (HostFloat32Tensor, HostFloat32Tensor) -> HostFloat32Tensor (x, y) @Host(alice)"
));
assert!(comp.contains(
"mean = HostMean: (HostFloat32Tensor) -> HostFloat32Tensor (dot) @Host(alice)"
));
assert!(comp
.contains("mean = Mean: (HostFloat32Tensor) -> HostFloat32Tensor (dot) @Host(alice)"));
Ok(())
}

Expand All @@ -154,7 +153,7 @@ mod tests {
y = Constant{value=HostFloat32Tensor([[1.0, 2.0], [3.0, 4.0]])}: () -> HostFloat32Tensor @Host(bob)
mul = Mul: (HostFloat32Tensor, HostFloat32Tensor) -> HostFloat32Tensor (x, y) @Host(alice)
dot = Dot: (HostFloat32Tensor, HostFloat32Tensor) -> HostFloat32Tensor (x, y) @Host(alice)
mean = HostMean{}: (HostFloat32Tensor) -> HostFloat32Tensor (dot) @Host(alice)"#;
mean = Mean{}: (HostFloat32Tensor) -> HostFloat32Tensor (dot) @Host(alice)"#;
let comp = NetworkingPass::pass(&source.try_into()?)?
.unwrap()
.to_textual();
Expand All @@ -166,9 +165,8 @@ mod tests {
assert!(comp.contains(r#"receive_0 = Receive{rendezvous_key = 00000000000000000000000000000000, sender = "bob"}: () -> HostFloat32Tensor () @Host(alice)"#));
assert!(comp.contains("mul = Mul: (HostFloat32Tensor, HostFloat32Tensor) -> HostFloat32Tensor (x, receive_0) @Host(alice)"));
assert!(comp.contains("dot = Dot: (HostFloat32Tensor, HostFloat32Tensor) -> HostFloat32Tensor (x, receive_0) @Host(alice)"));
assert!(comp.contains(
"mean = HostMean: (HostFloat32Tensor) -> HostFloat32Tensor (dot) @Host(alice)"
));
assert!(comp
.contains("mean = Mean: (HostFloat32Tensor) -> HostFloat32Tensor (dot) @Host(alice)"));
Ok(())
}

Expand Down
2 changes: 1 addition & 1 deletion moose/src/compilation/typing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ mod tests {
y = Constant{value=HostFloat32Tensor([[1.0, 2.0], [3.0, 4.0]])}: () -> HostFloat32Tensor @Host(alice)
mul = Mul: (HostFloat32Tensor, HostFloat32Tensor) -> HostFloat32Tensor (x, y) @Host(alice)
dot = Dot: (HostFloat32Tensor, HostFloat32Tensor) -> HostFloat32Tensor (x, y) @Host(alice)
mean = HostMean{}: (HostFloat32Tensor) -> HostFloat32Tensor (dot) @Host(alice)
mean = Mean{}: (HostFloat32Tensor) -> HostFloat32Tensor (dot) @Host(alice)
constant_0 = Constant{value = HostString("regression_weights")}: () -> HostString () @Host(alice)
save = Save: (HostString, Unknown) -> HostUnit (constant_0, mean) @Host(alice)
"#;
Expand Down
31 changes: 0 additions & 31 deletions moose/src/computation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,6 @@ operators![
Shl,
Shr,
Abs,
HostMean,
Diag,
Sign,
RingFixedpointMean,
Expand All @@ -887,7 +886,6 @@ operators![
// Fixed-point operators
FixedpointEncode,
FixedpointDecode,
FixedpointMean,
Pow2,
Exp,
Sigmoid,
Expand All @@ -896,16 +894,13 @@ operators![
EqualZero,
LessThan,
GreaterThan,
// Floating-point operators
FloatingpointMean,
// Additive operators
AdtToRep,
// Replicated operators
Share,
Reveal,
Fill,
Msb,
RepFixedpointMean,
AddN,
TruncPr,
RepToAdt,
Expand Down Expand Up @@ -1141,12 +1136,6 @@ pub struct SignOp {
pub sig: Signature,
}

#[derive(Serialize, Deserialize, PartialEq, Eq, Hash, Clone, Debug, ShortName, FromTextual)]
pub struct HostMeanOp {
pub sig: Signature,
pub axis: Option<u32>,
}

#[derive(
Serialize, Deserialize, PartialEq, Eq, Hash, Clone, Debug, ShortName, ToTextual, FromTextual,
)]
Expand Down Expand Up @@ -1326,25 +1315,13 @@ pub struct FixedpointDecodeOp {
pub fractional_precision: u32,
}

#[derive(Serialize, Deserialize, PartialEq, Eq, Hash, Clone, Debug, ShortName, FromTextual)]
pub struct FixedpointMeanOp {
pub sig: Signature,
pub axis: Option<u32>,
}

#[derive(
Serialize, Deserialize, PartialEq, Eq, Hash, Clone, Debug, ShortName, ToTextual, FromTextual,
)]
pub struct NegOp {
pub sig: Signature,
}

#[derive(Serialize, Deserialize, PartialEq, Eq, Hash, Clone, Debug, ShortName, FromTextual)]
pub struct FloatingpointMeanOp {
pub sig: Signature,
pub axis: Option<u32>,
}

#[derive(
Serialize, Deserialize, PartialEq, Eq, Hash, Clone, Debug, ShortName, ToTextual, FromTextual,
)]
Expand Down Expand Up @@ -1410,14 +1387,6 @@ pub struct RevealOp {
pub sig: Signature,
}

#[derive(Serialize, Deserialize, PartialEq, Eq, Hash, Clone, Debug, ShortName, FromTextual)]
pub struct RepFixedpointMeanOp {
pub sig: Signature,
pub axis: Option<u32>,
pub scaling_base: u64,
pub scaling_exp: u32,
}

#[derive(
Serialize, Deserialize, PartialEq, Eq, Hash, Clone, Debug, ShortName, ToTextual, FromTextual,
)]
Expand Down
4 changes: 0 additions & 4 deletions moose/src/execution/asynchronous.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,6 @@ impl DispatchKernel<AsyncSession> for Operator {
TruncPr(op) => DispatchKernel::compile(op, plc),
Msb(op) => DispatchKernel::compile(op, plc),
RepToAdt(op) => DispatchKernel::compile(op, plc),
RepFixedpointMean(op) => DispatchKernel::compile(op, plc),
BitDecompose(op) => DispatchKernel::compile(op, plc),
BitCompose(op) => DispatchKernel::compile(op, plc),
AdtToRep(op) => DispatchKernel::compile(op, plc),
Expand All @@ -368,14 +367,11 @@ impl DispatchKernel<AsyncSession> for Operator {
Input(op) => DispatchKernel::compile(op, plc),
Output(op) => DispatchKernel::compile(op, plc),
AtLeast2D(op) => DispatchKernel::compile(op, plc),
HostMean(op) => DispatchKernel::compile(op, plc),
FixedpointEncode(op) => DispatchKernel::compile(op, plc),
FixedpointDecode(op) => DispatchKernel::compile(op, plc),
FixedpointMean(op) => DispatchKernel::compile(op, plc),
Sign(op) => DispatchKernel::compile(op, plc),
Transpose(op) => DispatchKernel::compile(op, plc),
Squeeze(op) => DispatchKernel::compile(op, plc),
FloatingpointMean(op) => DispatchKernel::compile(op, plc),
Identity(op) => DispatchKernel::compile(op, plc),
Cast(op) => DispatchKernel::compile(op, plc),
Reshape(op) => DispatchKernel::compile(op, plc),
Expand Down
12 changes: 6 additions & 6 deletions moose/src/execution/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -690,16 +690,16 @@ mod tests {
false,
true
)]
#[case("HostMean", None, "Float32(2.5) @Host(alice)", true, true)]
#[case("Mean", None, "Float32(2.5) @Host(alice)", true, true)]
#[case(
"HostMean",
"Mean",
Some(0),
"HostFloat32Tensor([2.0, 3.0]) @Host(alice)",
false,
true
)]
#[case(
"HostMean",
"Mean",
Some(1),
"HostFloat32Tensor([1.5, 3.5]) @Host(alice)",
false,
Expand All @@ -720,16 +720,16 @@ mod tests {
false,
false
)]
#[case("HostMean", None, "Float32(2.5) @Host(alice)", true, false)]
#[case("Mean", None, "Float32(2.5) @Host(alice)", true, false)]
#[case(
"HostMean",
"Mean",
Some(0),
"HostFloat32Tensor([2.0, 3.0]) @Host(alice)",
false,
false
)]
#[case(
"HostMean",
"Mean",
Some(1),
"HostFloat32Tensor([1.5, 3.5]) @Host(alice)",
false,
Expand Down
4 changes: 0 additions & 4 deletions moose/src/execution/symbolic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,6 @@ impl SymbolicStrategy for DefaultSymbolicStrategy {
Share(op) => DispatchKernel::compile(&op, plc)?(sess, operands),
Reveal(op) => DispatchKernel::compile(&op, plc)?(sess, operands),
TruncPr(op) => DispatchKernel::compile(&op, plc)?(sess, operands),
RepFixedpointMean(op) => DispatchKernel::compile(&op, plc)?(sess, operands),
AddN(op) => DispatchKernel::compile(&op, plc)?(sess, operands),
Shl(op) => DispatchKernel::compile(&op, plc)?(sess, operands),
Shr(op) => DispatchKernel::compile(&op, plc)?(sess, operands),
Expand All @@ -317,7 +316,6 @@ impl SymbolicStrategy for DefaultSymbolicStrategy {
BitCompose(op) => DispatchKernel::compile(&op, plc)?(sess, operands),
ShlDim(op) => DispatchKernel::compile(&op, plc)?(sess, operands),
AdtToRep(op) => DispatchKernel::compile(&op, plc)?(sess, operands),
HostMean(op) => DispatchKernel::compile(&op, plc)?(sess, operands),
Sqrt(op) => DispatchKernel::compile(&op, plc)?(sess, operands),
Sign(op) => DispatchKernel::compile(&op, plc)?(sess, operands),
Pow2(op) => DispatchKernel::compile(&op, plc)?(sess, operands),
Expand All @@ -331,9 +329,7 @@ impl SymbolicStrategy for DefaultSymbolicStrategy {
GreaterThan(op) => DispatchKernel::compile(&op, plc)?(sess, operands),
FixedpointEncode(op) => DispatchKernel::compile(&op, plc)?(sess, operands),
FixedpointDecode(op) => DispatchKernel::compile(&op, plc)?(sess, operands),
FixedpointMean(op) => DispatchKernel::compile(&op, plc)?(sess, operands),
Identity(op) => DispatchKernel::compile(&op, plc)?(sess, operands),
FloatingpointMean(op) => DispatchKernel::compile(&op, plc)?(sess, operands),
AtLeast2D(op) => DispatchKernel::compile(&op, plc)?(sess, operands),
IndexAxis(op) => DispatchKernel::compile(&op, plc)?(sess, operands),
Slice(op) => DispatchKernel::compile(&op, plc)?(sess, operands),
Expand Down
4 changes: 0 additions & 4 deletions moose/src/execution/synchronous.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,6 @@ impl Session for SyncSession {
Msb(op) => DispatchKernel::compile(&op, plc)?(self, operands)?,
Abs(op) => DispatchKernel::compile(&op, plc)?(self, operands)?,
RepToAdt(op) => DispatchKernel::compile(&op, plc)?(self, operands)?,
RepFixedpointMean(op) => DispatchKernel::compile(&op, plc)?(self, operands)?,
AddN(op) => DispatchKernel::compile(&op, plc)?(self, operands)?,
Index(op) => DispatchKernel::compile(&op, plc)?(self, operands)?,
BitDecompose(op) => DispatchKernel::compile(&op, plc)?(self, operands)?,
Expand All @@ -215,16 +214,13 @@ impl Session for SyncSession {
Constant(op) => DispatchKernel::compile(&op, plc)?(self, operands)?,
Input(op) => DispatchKernel::compile(&op, plc)?(self, operands)?,
Output(op) => DispatchKernel::compile(&op, plc)?(self, operands)?,
HostMean(op) => DispatchKernel::compile(&op, plc)?(self, operands)?,
Sqrt(op) => DispatchKernel::compile(&op, plc)?(self, operands)?,
FixedpointEncode(op) => DispatchKernel::compile(&op, plc)?(self, operands)?,
FixedpointDecode(op) => DispatchKernel::compile(&op, plc)?(self, operands)?,
FixedpointMean(op) => DispatchKernel::compile(&op, plc)?(self, operands)?,
Diag(op) => DispatchKernel::compile(&op, plc)?(self, operands)?,
ExpandDims(op) => DispatchKernel::compile(&op, plc)?(self, operands)?,
Squeeze(op) => DispatchKernel::compile(&op, plc)?(self, operands)?,
Sign(op) => DispatchKernel::compile(&op, plc)?(self, operands)?,
FloatingpointMean(op) => DispatchKernel::compile(&op, plc)?(self, operands)?,
Identity(op) => DispatchKernel::compile(&op, plc)?(self, operands)?,
Cast(op) => DispatchKernel::compile(&op, plc)?(self, operands)?,
AtLeast2D(op) => DispatchKernel::compile(&op, plc)?(self, operands)?,
Expand Down
2 changes: 1 addition & 1 deletion moose/src/fixedpoint/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1123,7 +1123,7 @@ impl ShapeOp {
}
}

impl FixedpointMeanOp {
impl MeanOp {
pub(crate) fn fixed_host_kernel<S: Session, HostFixedT, MirFixedT, RepFixedT>(
sess: &S,
plc: &HostPlacement,
Expand Down
2 changes: 1 addition & 1 deletion moose/src/floatingpoint/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ impl IdentityOp {
}
}

impl FloatingpointMeanOp {
impl MeanOp {
pub(crate) fn float_host_kernel<S: Session, HostFloatT, MirroredT>(
sess: &S,
plc: &HostPlacement,
Expand Down
8 changes: 4 additions & 4 deletions moose/src/host/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -719,8 +719,8 @@ impl BitDecomposeOp {
}
}

impl HostMeanOp {
pub(crate) fn kernel<S: RuntimeSession, T: LinalgScalar + FromPrimitive>(
impl MeanOp {
pub(crate) fn host_kernel<S: RuntimeSession, T: LinalgScalar + FromPrimitive>(
_sess: &S,
plc: &HostPlacement,
axis: Option<u32>,
Expand All @@ -734,7 +734,7 @@ impl HostMeanOp {
let reduced: Option<ArrayD<T>> = x.0.mean_axis(Axis(i as usize));
if reduced.is_none() {
return Err(Error::KernelError(
"HostMeanOp cannot reduce over an empty axis.".to_string(),
"MeanOp cannot reduce over an empty axis.".to_string(),
));
};
Ok(HostTensor::place(plc, reduced.unwrap().into_shared()))
Expand All @@ -743,7 +743,7 @@ impl HostMeanOp {
let mean = x.0.mean();
if mean.is_none() {
return Err(Error::KernelError(
"HostMeanOp cannot reduce over an empty tensor.".to_string(),
"MeanOp cannot reduce over an empty tensor.".to_string(),
));
};
let out = Array::from_elem([], mean.unwrap())
Expand Down
53 changes: 21 additions & 32 deletions moose/src/kernels/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,38 +377,33 @@ pub trait PlacementMean<S: Session, T, O> {
fn mean(&self, sess: &S, axis: Option<u32>, x: &T) -> O;
}

kernel! {
MeanOp, [
(HostPlacement, (Tensor) -> Tensor => [concrete] attributes[sig, axis] Self::host_kernel),
(ReplicatedPlacement, (Tensor) -> Tensor => [concrete] attributes[sig, axis] Self::rep_kernel),
]
}

modelled_kernel! {
PlacementMean::mean, HostMeanOp{axis: Option<u32>},
[
(HostPlacement, (HostFloat32Tensor) -> HostFloat32Tensor => [runtime] Self::kernel),
(HostPlacement, (HostFloat64Tensor) -> HostFloat64Tensor => [runtime] Self::kernel),
]
}

modelled_kernel! {
PlacementMean::mean, FloatingpointMeanOp{axis: Option<u32>},
PlacementMean::mean, MeanOp{axis: Option<u32>},
[
(HostPlacement, (Tensor) -> Tensor => [concrete] custom |op| {
let sig = op.sig;
let axis = op.axis;
Ok(Box::new(move |sess, plc, x| {
Self::logical_host_kernel(sess, plc, sig, axis, x)
}))
}),
(HostPlacement, (HostFloat32Tensor) -> HostFloat32Tensor => [runtime] Self::host_kernel),
(HostPlacement, (HostFloat64Tensor) -> HostFloat64Tensor => [runtime] Self::host_kernel),
(HostPlacement, (Float32Tensor) -> Float32Tensor => [concrete] Self::float_host_kernel),
(HostPlacement, (Float64Tensor) -> Float64Tensor => [concrete] Self::float_host_kernel),
]
}

modelled_kernel! {
PlacementMean::mean, FixedpointMeanOp{axis: Option<u32>},
[
(HostPlacement, (Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_host_kernel),
(HostPlacement, (Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_host_kernel),
(ReplicatedPlacement, (Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_rep_kernel),
(HostPlacement, (HostFixed64Tensor) -> HostFixed64Tensor => [concrete] Self::hostfixed_kernel),
(HostPlacement, (HostFixed128Tensor) -> HostFixed128Tensor => [concrete] Self::hostfixed_kernel),
(ReplicatedPlacement, (Tensor) -> Tensor => [concrete] custom |op| {
let sig = op.sig;
let axis = op.axis;
Ok(Box::new(move |sess, plc, x| {
Self::logical_rep_kernel(sess, plc, sig, axis, x)
}))
}),
(ReplicatedPlacement, (Fixed64Tensor) -> Fixed64Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (Fixed128Tensor) -> Fixed128Tensor => [concrete] Self::fixed_rep_kernel),
(ReplicatedPlacement, (ReplicatedFixed64Tensor) -> ReplicatedFixed64Tensor => [concrete] Self::repfixed_kernel),
(ReplicatedPlacement, (ReplicatedFixed128Tensor) -> ReplicatedFixed128Tensor => [concrete] Self::repfixed_kernel),
]
Expand All @@ -425,19 +420,13 @@ pub trait PlacementMeanAsFixedpoint<S: Session, T, O> {
) -> O;
}

modelled_kernel! {
PlacementMeanAsFixedpoint::mean_as_fixedpoint, RepFixedpointMeanOp{axis: Option<u32>, scaling_base: u64, scaling_exp: u32},
[
(ReplicatedPlacement, (ReplicatedRing64Tensor) -> ReplicatedRing64Tensor => [concrete] Self::kernel),
(ReplicatedPlacement, (ReplicatedRing128Tensor) -> ReplicatedRing128Tensor => [concrete] Self::kernel),
]
}

modelled_kernel! {
PlacementMeanAsFixedpoint::mean_as_fixedpoint, RingFixedpointMeanOp{axis: Option<u32>, scaling_base: u64, scaling_exp: u32},
[
(HostPlacement, (HostRing64Tensor) -> HostRing64Tensor => [runtime] Self::ring64_kernel),
(HostPlacement, (HostRing128Tensor) -> HostRing128Tensor => [runtime] Self::ring128_kernel),
(ReplicatedPlacement, (ReplicatedRing64Tensor) -> ReplicatedRing64Tensor => [concrete] Self::rep_kernel),
(ReplicatedPlacement, (ReplicatedRing128Tensor) -> ReplicatedRing128Tensor => [concrete] Self::rep_kernel),
]
}

Expand Down
Loading

0 comments on commit d2d56d2

Please sign in to comment.