Skip to content

Commit

Permalink
New coercion behaviour as per dfinity/candid#311
Browse files Browse the repository at this point in the history
  • Loading branch information
nomeata committed Apr 8, 2022
1 parent 022fc9e commit 5619727
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 184 deletions.
3 changes: 1 addition & 2 deletions src/Codec/Candid/Class.hs
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@ decode b = do
-- Decode
(ts, vs) <- decodeVals b
-- Coerce to expected type
c <- coerceSeqDesc ts (buildSeqDesc (asTypes @(AsTuple a)))
vs' <- c vs
vs' <- coerceSeqDesc vs ts (buildSeqDesc (asTypes @(AsTuple a)))
fromCandidVals vs'

-- | Decode (dynamic) values to Haskell type
Expand Down
270 changes: 89 additions & 181 deletions src/Codec/Candid/Coerce.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,255 +5,163 @@
{-# LANGUAGE FlexibleContexts #-}
module Codec.Candid.Coerce
( coerceSeqDesc
, SeqCoercion
, coerce
, Coercion
)
where

import Prettyprinter
import qualified Data.Vector as V
import qualified Data.ByteString.Lazy as BS
import qualified Data.Map as M
import Data.Bifunctor
import Data.List
import Data.Tuple
import Control.Monad.State.Lazy
import Control.Monad.Except

import Codec.Candid.FieldName
import Codec.Candid.Types
import Codec.Candid.TypTable
import Codec.Candid.Subtype

type SeqCoercion = [Value] -> Either String [Value]
type Coercion = Value -> Either String Value

coerceSeqDesc :: SeqDesc -> SeqDesc -> Either String SeqCoercion
coerceSeqDesc sd1 sd2 =
coerceSeqDesc :: [Value] -> SeqDesc -> SeqDesc -> Either String [Value]
coerceSeqDesc vs sd1 sd2 =
unrollTypeTable sd1 $ \ts1 ->
unrollTypeTable sd2 $ \ts2 ->
coerceSeq ts1 ts2
coerceSeq vs ts1 ts2

coerceSeq ::
(Pretty k1, Pretty k2, Ord k1, Ord k2) =>
[Value] ->
[Type (Ref k1 Type)] ->
[Type (Ref k2 Type)] ->
Either String SeqCoercion
coerceSeq t1 t2 = runM $ goSeq t1 t2
Either String [Value]
coerceSeq vs t1 t2 = runSubTypeM $ goSeq vs t1 t2

-- | This function implements the `V : T ~> V' : T'` relation from the Candid spec.
--
-- `C[<t> <: <t>]` coercion function from the
-- spec. It returns `Left` if no subtyping relation holds, or `Right c` if it
-- holds, together with a coercion function.
--
-- The coercion function itself is not total because the intput value isn’t
-- typed, so we have to cater for errors there. It should not fail if the
-- passed value really is inherently of the input type.
--
-- In a dependently typed language we’d maybe have something like
-- `coerce :: foreach t1 -> foreach t2 -> Either String (t1 -> t2)`
-- instead, and thus return a total function
-- Because values in this library are untyped, we have to pass what we know about
-- their type down, so that we can do the subtype check upon a reference.
-- The given type must match the value closely (as in the type description)
coerce ::
(Pretty k1, Pretty k2, Ord k1, Ord k2) =>
Value ->
Type (Ref k1 Type) ->
Type (Ref k2 Type) ->
Either String Coercion
coerce t1 t2 = runM $ memo t1 t2

type Memo k1 k2 =
(M.Map (Type (Ref k1 Type), Type (Ref k2 Type)) Coercion,
M.Map (Type (Ref k2 Type), Type (Ref k1 Type)) Coercion)
type M k1 k2 = ExceptT String (State (Memo k1 k2))

runM :: (Ord k1, Ord k2) => M k1 k2 a -> Either String a
runM act = evalState (runExceptT act) (mempty, mempty)
Either String Value
coerce v t1 t2 = runSubTypeM $ go v t1 t2

flipM :: M k1 k2 a -> M k2 k1 a
flipM (ExceptT (StateT f)) = ExceptT (StateT f')
where
f' (m1,m2) = second swap <$> f (m2,m1) -- f (m2,m1) >>= \case (r, (m2',m1')) -> pure (r, (m1', m2'))

memo, go ::
go ::
(Pretty k1, Pretty k2, Ord k1, Ord k2) =>
Value ->
Type (Ref k1 Type) ->
Type (Ref k2 Type) ->
M k1 k2 Coercion
SubTypeM k1 k2 Value

goSeq ::
(Pretty k1, Pretty k2, Ord k1, Ord k2) =>
[Value] ->
[Type (Ref k1 Type)] ->
[Type (Ref k2 Type)] ->
M k1 k2 SeqCoercion


-- Memoization uses lazyiness: When we see a pair for the first time,
-- we optimistically put the resulting coercion into the map.
-- Either the following recursive call will fail (but then this optimistic
-- value was never used), or it will succeed, but then the guess was correct.
memo t1 t2 = do
gets (M.lookup (t1,t2) . fst) >>= \case
Just c -> pure c
Nothing -> mdo
modify (first (M.insert (t1,t2) c))
c <- go t1 t2
return c
SubTypeM k1 k2 [Value]

-- Look through refs
go (RefT (Ref _ t1)) t2 = memo t1 t2
go t1 (RefT (Ref _ t2)) = memo t1 t2
go v (RefT (Ref _ t1)) t2 = go v t1 t2
go v t1 (RefT (Ref _ t2)) = go v t1 t2

-- Identity coercion for primitive values
go NatT NatT = pure pure
go Nat8T Nat8T = pure pure
go Nat16T Nat16T = pure pure
go Nat32T Nat32T = pure pure
go Nat64T Nat64T = pure pure
go IntT IntT = pure pure
go Int8T Int8T = pure pure
go Int16T Int16T = pure pure
go Int32T Int32T = pure pure
go Int64T Int64T = pure pure
go Float32T Float32T = pure pure
go Float64T Float64T = pure pure
go BoolT BoolT = pure pure
go TextT TextT = pure pure
go NullT NullT = pure pure
go PrincipalT PrincipalT = pure pure
go v NatT NatT = pure v
go v Nat8T Nat8T = pure v
go v Nat16T Nat16T = pure v
go v Nat32T Nat32T = pure v
go v Nat64T Nat64T = pure v
go v IntT IntT = pure v
go v Int8T Int8T = pure v
go v Int16T Int16T = pure v
go v Int32T Int32T = pure v
go v Int64T Int64T = pure v
go v Float32T Float32T = pure v
go v Float64T Float64T = pure v
go v BoolT BoolT = pure v
go v TextT TextT = pure v
go v NullT NullT = pure v
go v PrincipalT PrincipalT = pure v

-- Nat <: Int
go NatT IntT = pure $ \case
NatV n -> pure $ IntV (fromIntegral n)
v -> throwError $ show $ "Unexpected value" <+> pretty v <+> "while coercing nat <: int"
go (NatV n) NatT IntT = pure $ IntV (fromIntegral n)

-- t <: reserved
go _ ReservedT = pure (const (pure ReservedV))
go _ _ ReservedT = pure ReservedV

-- empty <: t
go EmptyT _ = pure $ \v ->
throwError $ show $ "Unexpected value" <+> pretty v <+> "while coercing empty"
-- empty <: t (actually just a special case of `v :/ t`)
go v EmptyT _ = throwError $ show $ "Unexpected value" <+> pretty v <+> "while coercing empty"

-- vec t1 <: vec t2
go (VecT t1) (VecT t2) = do
c <- memo t1 t2
pure $ \case
VecV vs -> VecV <$> mapM c vs
v -> throwError $ show $ "Unexpected value" <+> pretty v <+> "while coercing vector"
go (VecV vs) (VecT t1) (VecT t2) = VecV <$> mapM (\v -> go v t1 t2) vs

-- Option: The normal rule
go (OptT t1) (OptT t2) = lift (runExceptT (memo t1 t2)) >>= \case
Right c -> pure $ \case
OptV Nothing -> pure (OptV Nothing)
OptV (Just v) -> OptV . Just <$> c v
v -> throwError $ show $ "Unexpected value" <+> pretty v <+> "while coercing option"
Left _ -> pure (const (pure (OptV Nothing)))
go (OptV Nothing) (OptT _) (OptT _) = pure NullV
go (OptV (Just v)) (OptT t1) (OptT t2) =
lift (runExceptT (go v t1 t2)) >>= \case
Right v' -> pure (OptV (Just v'))
Left _ -> pure (OptV Nothing)

-- Option: The constituent rule
go t (OptT t2) | not (isOptLike t2) = lift (runExceptT (memo t t2)) >>= \case
Right c -> pure $ \v -> OptV . Just <$> c v
Left _ -> pure (const (pure (OptV Nothing)))
go v t1 (OptT t2) | not (isOptLike t2) =
lift (runExceptT (go v t1 t2)) >>= \case
Right v' -> pure (OptV (Just v'))
Left _ -> pure (OptV Nothing)

-- Option: The fallback rule
go _ (OptT _) = pure (const (pure (OptV Nothing)))
go _ _ (OptT _) = pure (OptV Nothing)

-- Records
go (RecT fs1) (RecT fs2) = do
let m1 = M.fromList fs1
let m2 = M.fromList fs2
new_fields <- sequence
[ case unRef t of
OptT _ -> pure (fn, OptV Nothing)
ReservedT -> pure (fn, ReservedV)
t -> throwError $ show $ "Missing record field" <+> pretty fn <+> "of type" <+> pretty t
| (fn, t) <- M.toList $ m2 M.\\ m1
]
field_coercions <- sequence
[ do c <- memo t1 t2
pure $ \vm -> case M.lookup fn vm of
Nothing -> throwError $ show $ "Record value lacks field" <+> pretty fn <+> "of type" <+> pretty t1
Just v -> (fn, ) <$> c v
| (fn, (t1, t2)) <- M.toList $ M.intersectionWith (,) m1 m2
]
pure $ \case
TupV ts -> do
let vm = M.fromList $ zip [hashedField n | n <- [0..]] ts
coerced_fields <- mapM ($ vm) field_coercions
return $ RecV $ sortOn fst $ coerced_fields <> new_fields
RecV fvs -> do
let vm = M.fromList fvs
coerced_fields <- mapM ($ vm) field_coercions
return $ RecV $ sortOn fst $ coerced_fields <> new_fields
go rv (RecT fs1) (RecT fs2) = do
vm <- case rv of
TupV ts -> pure $ M.fromList $ zip [hashedField n | n <- [0..]] ts
RecV fvs -> pure $ M.fromList fvs
v -> throwError $ show $ "Unexpected value" <+> pretty v <+> "while coercing record"

let m1 = M.fromList fs1
fmap RecV $ forM fs2 $ \(fn, t2) -> (fn,) <$>
case (M.lookup fn vm, M.lookup fn m1) of
(Just v, Just t1) -> go v t1 t2
_ -> case unRef t2 of
OptT _ -> pure (OptV Nothing)
ReservedT -> pure ReservedV
t -> throwError $ show $ "Missing record field" <+> pretty fn <+> "of type" <+> pretty t

-- Variants
go (VariantT fs1) (VariantT fs2) = do
go (VariantV fn v) (VariantT fs1) (VariantT fs2) = do
let m1 = M.fromList fs1
let m2 = M.fromList fs2
cm <- M.traverseWithKey (\fn t1 ->
case M.lookup fn m2 of
Just t2 -> memo t1 t2
Nothing -> throwError $ show $ "Missing variant field" <+> pretty fn <+> "of type" <+> pretty t1
) m1
pure $ \case
VariantV fn v | Just c <- M.lookup fn cm -> VariantV fn <$> c v
| otherwise -> throwError $ show $ "Unexpected variant field" <+> pretty fn
v -> throwError $ show $ "Unexpected value" <+> pretty v <+> "while coercing variant"
case (M.lookup fn m1, M.lookup fn m2) of
(Just t1, Just t2) -> VariantV fn <$> go v t1 t2
(Nothing, _) -> throwError $ show $ "Wrongly typed variant missing field " <+> pretty fn
(_, Nothing) -> throwError $ show $ "Unexpected variant field" <+> pretty fn

-- Reference types
go (FuncT mt1) (FuncT mt2) = goMethodType mt1 mt2 >> pure pure
go (ServiceT meths1) (ServiceT meths2) = do
let m1 = M.fromList meths1
forM_ meths2 $ \(m, mt2) -> case M.lookup m m1 of
Just mt1 -> goMethodType mt1 mt2
Nothing -> throwError $ show $ "Missing service method" <+> pretty m <+> "of type" <+> pretty mt2
pure pure
go v t1@(FuncT _) t2@(FuncT _) = isSubtypeOfM t1 t2 >> pure v
go v t1@(ServiceT _) t2@(ServiceT _) = isSubtypeOfM t1 t2 >> pure v

-- BlobT
go BlobT BlobT = pure pure
go (VecT t) BlobT | isNat8 t = pure $ \case
VecV vs -> BlobV . BS.pack . V.toList <$> mapM goNat8 vs
v -> throwError $ show $ "Unexpected value" <+> pretty v <+> "while coercing vec nat8 to blob"
go v BlobT BlobT = pure v
go (VecV vs) (VecT t) BlobT | isNat8 t = BlobV . BS.pack . V.toList <$> mapM goNat8 vs
where
goNat8 (Nat8V n) = pure n
goNat8 v = throwError $ show $ "Unexpected value" <+> pretty v <+> "while coercing vec nat8 to blob"
go BlobT (VecT t) | isNat8 t = pure $ \case
BlobV b -> return $ VecV $ V.fromList $ map (Nat8V . fromIntegral) $ BS.unpack b
v -> throwError $ show $ "Unexpected value" <+> pretty v <+> "while coercing blob to vec nat8"

go t1 t2 = throwError $ show $ "Type" <+> pretty t1 <+> "is not a subtype of" <+> pretty t2

goMethodType ::
(Pretty k2, Pretty k1, Ord k2, Ord k1) =>
MethodType (Ref k1 Type) ->
MethodType (Ref k2 Type) ->
M k1 k2 ()
goMethodType (MethodType ta1 tr1 q1 o1) (MethodType ta2 tr2 q2 o2) = do
unless (q1 == q2) $ throwError "Methods differ in query annotation"
unless (o1 == o2) $ throwError "Methods differ in oneway annotation"
void $ flipM $ goSeq ta2 ta1
void $ goSeq tr1 tr2

goSeq _ [] = pure (const (return []))
goSeq ts1 (RefT (Ref _ t) : ts) = goSeq ts1 (t:ts)
goSeq ts1@[] (NullT : ts) = do
cs2 <- goSeq ts1 ts
pure $ \_vs -> (NullV :) <$> cs2 []
goSeq ts1@[] (OptT _ : ts) = do
cs2 <- goSeq ts1 ts
pure $ \_vs -> (OptV Nothing :) <$> cs2 []
goSeq ts1@[] (ReservedT : ts) = do
cs2 <- goSeq ts1 ts
pure $ \_vs -> (ReservedV :) <$> cs2 []
goSeq [] ts =
throwError $ show $ "Argument type list too short, expecting types" <+> pretty ts
goSeq (t1:ts1) (t2:ts2) = do
c1 <- memo t1 t2
cs2 <- goSeq ts1 ts2
pure $ \case
[] -> throwError $ show $ "Expecting value of type:" <+> pretty t1
(v:vs) -> do
v' <- c1 v
vs' <- cs2 vs
return (v':vs')
go (BlobV b) BlobT (VecT t) | isNat8 t = pure $ VecV $ V.fromList $ map (Nat8V . fromIntegral) $ BS.unpack b

go v t1 t2 = throwError $ show $ "Cannot coerce " <+> pretty v <+> ":" <+> pretty t1 <+> "to type " <+> pretty t2

goSeq _ _ [] = pure []
goSeq vs ts1 (RefT (Ref _ t) : ts) = goSeq vs ts1 (t:ts)
goSeq vs@[] ts1@[] (NullT : ts) = (NullV :) <$> goSeq vs ts1 ts
goSeq vs@[] ts1@[] (OptT _ : ts) = (OptV Nothing :) <$> goSeq vs ts1 ts
goSeq vs@[] ts1@[] (ReservedT : ts) = (ReservedV :) <$> goSeq vs ts1 ts
goSeq [] [] ts = throwError $ show $ "Argument type list too short, expecting types" <+> pretty ts
goSeq (v:vs) (t1:ts1) (t2:ts2) = do
v' <- go v t1 t2
vs' <- goSeq vs ts1 ts2
pure $ v' : vs'
goSeq _ _ _ = throwError $ "Illtyped input to goSeq"

unRef :: Type (Ref a Type) -> Type (Ref a Type)
unRef (RefT (Ref _ t)) = unRef t
Expand Down
1 change: 0 additions & 1 deletion src/Codec/Candid/Subtype.hs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ goSeq ::
[Type (Ref k2 Type)] ->
SubTypeM k1 k2 ()


-- Memoization: When we see a pair for the first time,
-- we optimistically put 'True' into the map.
-- Either the following recursive call will fail (but then this optimistic
Expand Down

0 comments on commit 5619727

Please sign in to comment.