Skip to content

Commit

Permalink
Tiny kernel cleanup (#843)
Browse files Browse the repository at this point in the history
* add_n

* concatenate
  • Loading branch information
mortendahl authored Feb 8, 2022
1 parent d900545 commit 165b4ec
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 46 deletions.
23 changes: 2 additions & 21 deletions moose/src/kernels/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,27 +275,8 @@ pub trait PlacementAddN<S: Session, T, O> {
fn add_n(&self, sess: &S, x: &[T]) -> O;
}

modelled!(PlacementAddN::add_n, HostPlacement, vec[Tensor] -> Tensor, AddNOp);
modelled!(PlacementAddN::add_n, HostPlacement, vec[Float32Tensor] -> Float32Tensor, AddNOp);
modelled!(PlacementAddN::add_n, HostPlacement, vec[Float64Tensor] -> Float64Tensor, AddNOp);
modelled!(PlacementAddN::add_n, HostPlacement, vec[HostFloat32Tensor] -> HostFloat32Tensor, AddNOp);
modelled!(PlacementAddN::add_n, HostPlacement, vec[HostFloat64Tensor] -> HostFloat64Tensor, AddNOp);
modelled!(PlacementAddN::add_n, HostPlacement, vec[Fixed64Tensor] -> Fixed64Tensor, AddNOp);
modelled!(PlacementAddN::add_n, HostPlacement, vec[Fixed128Tensor] -> Fixed128Tensor, AddNOp);
modelled!(PlacementAddN::add_n, HostPlacement, vec[HostFixed64Tensor] -> HostFixed64Tensor, AddNOp);
modelled!(PlacementAddN::add_n, HostPlacement, vec[HostFixed128Tensor] -> HostFixed128Tensor, AddNOp);
modelled!(PlacementAddN::add_n, HostPlacement, vec[HostRing64Tensor] -> HostRing64Tensor, AddNOp);
modelled!(PlacementAddN::add_n, HostPlacement, vec[HostRing128Tensor] -> HostRing128Tensor, AddNOp);
modelled!(PlacementAddN::add_n, ReplicatedPlacement, vec[ReplicatedRing64Tensor] -> ReplicatedRing64Tensor, AddNOp);
modelled!(PlacementAddN::add_n, ReplicatedPlacement, vec[ReplicatedRing128Tensor] -> ReplicatedRing128Tensor, AddNOp);
modelled!(PlacementAddN::add_n, ReplicatedPlacement, vec[ReplicatedFixed64Tensor] -> ReplicatedFixed64Tensor, AddNOp);
modelled!(PlacementAddN::add_n, ReplicatedPlacement, vec[ReplicatedFixed128Tensor] -> ReplicatedFixed128Tensor, AddNOp);
modelled!(PlacementAddN::add_n, ReplicatedPlacement, vec[Fixed64Tensor] -> Fixed64Tensor, AddNOp);
modelled!(PlacementAddN::add_n, ReplicatedPlacement, vec[Fixed128Tensor] -> Fixed128Tensor, AddNOp);
modelled!(PlacementAddN::add_n, ReplicatedPlacement, vec[Tensor] -> Tensor, AddNOp);

kernel! {
AddNOp,
modelled_kernel! {
PlacementAddN::add_n, AddNOp,
[
(HostPlacement, vec[Tensor] -> Tensor => [concrete] Self::host_logical_kernel),
(HostPlacement, vec[Float32Tensor] -> Float32Tensor => [concrete] Self::float_kernel),
Expand Down
24 changes: 3 additions & 21 deletions moose/src/kernels/shapes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,27 +100,9 @@ pub trait PlacementConcatenate<S: Session, TS, O> {
fn concatenate(&self, sess: &S, axis: u32, xs: &[TS]) -> O;
}

modelled!(PlacementConcatenate::concatenate, HostPlacement, attributes[axis: u32] vec[Tensor] -> Tensor, ConcatOp);
modelled!(PlacementConcatenate::concatenate, HostPlacement, attributes[axis: u32] vec[Float32Tensor] -> Float32Tensor, ConcatOp);
modelled!(PlacementConcatenate::concatenate, HostPlacement, attributes[axis: u32] vec[Float64Tensor] -> Float64Tensor, ConcatOp);
modelled!(PlacementConcatenate::concatenate, HostPlacement, attributes[axis: u32] vec[HostFloat32Tensor] -> HostFloat32Tensor, ConcatOp);
modelled!(PlacementConcatenate::concatenate, HostPlacement, attributes[axis: u32] vec[HostFloat64Tensor] -> HostFloat64Tensor, ConcatOp);
modelled!(PlacementConcatenate::concatenate, HostPlacement, attributes[axis: u32] vec[HostInt8Tensor] -> HostInt8Tensor, ConcatOp);
modelled!(PlacementConcatenate::concatenate, HostPlacement, attributes[axis: u32] vec[HostInt16Tensor] -> HostInt16Tensor, ConcatOp);
modelled!(PlacementConcatenate::concatenate, HostPlacement, attributes[axis: u32] vec[HostInt32Tensor] -> HostInt32Tensor, ConcatOp);
modelled!(PlacementConcatenate::concatenate, HostPlacement, attributes[axis: u32] vec[HostInt64Tensor] -> HostInt64Tensor, ConcatOp);
modelled!(PlacementConcatenate::concatenate, HostPlacement, attributes[axis: u32] vec[HostRing64Tensor] -> HostRing64Tensor, ConcatOp);
modelled!(PlacementConcatenate::concatenate, HostPlacement, attributes[axis: u32] vec[HostRing128Tensor] -> HostRing128Tensor, ConcatOp);
modelled!(PlacementConcatenate::concatenate, ReplicatedPlacement, attributes[axis: u32] vec[Tensor] -> Tensor, ConcatOp);
modelled!(PlacementConcatenate::concatenate, ReplicatedPlacement, attributes[axis: u32] vec[Fixed64Tensor] -> Fixed64Tensor, ConcatOp);
modelled!(PlacementConcatenate::concatenate, ReplicatedPlacement, attributes[axis: u32] vec[Fixed128Tensor] -> Fixed128Tensor, ConcatOp);
modelled!(PlacementConcatenate::concatenate, ReplicatedPlacement, attributes[axis: u32] vec[ReplicatedFixed64Tensor] -> ReplicatedFixed64Tensor, ConcatOp);
modelled!(PlacementConcatenate::concatenate, ReplicatedPlacement, attributes[axis: u32] vec[ReplicatedFixed128Tensor] -> ReplicatedFixed128Tensor, ConcatOp);
modelled!(PlacementConcatenate::concatenate, ReplicatedPlacement, attributes[axis: u32] vec[ReplicatedRing64Tensor] -> ReplicatedRing64Tensor, ConcatOp);
modelled!(PlacementConcatenate::concatenate, ReplicatedPlacement, attributes[axis: u32] vec[ReplicatedRing128Tensor] -> ReplicatedRing128Tensor, ConcatOp);

kernel! {
ConcatOp, [
modelled_kernel! {
PlacementConcatenate::concatenate, ConcatOp{axis: u32},
[
(HostPlacement, vec[Tensor] -> Tensor => [concrete] attributes[axis] Self::logical_host_kernel),
(HostPlacement, vec[Float32Tensor] -> Float32Tensor => [concrete] attributes[axis] Self::float_host_kernel),
(HostPlacement, vec[Float64Tensor] -> Float64Tensor => [concrete] attributes[axis] Self::float_host_kernel),
Expand Down
7 changes: 3 additions & 4 deletions moose/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ macro_rules! derive_runtime_kernel {
}))
}
};

(variadic, attributes[$($attr:ident$(: $prim_ty:ident)?),+] $k:expr, $self:ident) => {
{
$(
Expand Down Expand Up @@ -4548,7 +4547,7 @@ macro_rules! modelled_kernel {
}
};

(__variadic runtime, $trait:ident, $trait_fn:ident, $op:ident, $plc:ty, $([$($attr_id:ident: $attr_ty:ty),+])? ($t0:ty) -> $u:ty => $($kp:tt)+) => {
(__variadic runtime, $trait:ident, $trait_fn:ident, $op:ident, $plc:ty, $([$($attr_id:ident: $attr_ty:ty),+])? vec[$ts:ty] -> $u:ty => $($kp:tt)+) => {
impl crate::kernels::VariadicKernel<
crate::execution::SymbolicSession,
$plc,
Expand Down Expand Up @@ -4602,11 +4601,11 @@ macro_rules! modelled_kernel {
$($($attr_id:$attr_ty),*,)?
xs: &[<$ts as crate::computation::SymbolicType>::Type]
) -> <$u as crate::computation::SymbolicType>::Type {
use crate::computation::{KnownType, UnarySignature};
use crate::computation::{KnownType, VariadicSignature};
use crate::execution::{Session, SymbolicSession};
use std::convert::TryInto;

let sig = UnarySignature {
let sig = VariadicSignature {
args: <$ts as KnownType<SymbolicSession>>::TY,
ret: <$u as KnownType<SymbolicSession>>::TY,
};
Expand Down

0 comments on commit 165b4ec

Please sign in to comment.