diff --git a/Project.toml b/Project.toml index 21b3563..8ad3aef 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ ACSets = "227ef7b5-1206-438b-ac65-934d6da304b8" Catlab = "134e5e36-593f-5add-ad60-77f754baafbe" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078" +SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" [compat] @@ -16,5 +17,6 @@ ACSets = "0.2" Catlab = "0.15, 0.16" DataStructures = "0.18.13" MLStyle = "0.4.17" +SymbolicUtils = "3.1.2" Unicode = "1.6" julia = "1.6" diff --git a/src/DiagrammaticEquations.jl b/src/DiagrammaticEquations.jl index 0d0de25..d226390 100644 --- a/src/DiagrammaticEquations.jl +++ b/src/DiagrammaticEquations.jl @@ -61,6 +61,8 @@ include("colanguage.jl") include("openoperators.jl") include("deca/Deca.jl") include("learn/Learn.jl") +include("ThDEC.jl") +include("decasymbolic.jl") using .Deca diff --git a/src/ThDEC.jl b/src/ThDEC.jl new file mode 100644 index 0000000..52f5198 --- /dev/null +++ b/src/ThDEC.jl @@ -0,0 +1,270 @@ +module ThDEC +using MLStyle + +import Base: +, -, * + +struct SortError <: Exception + message::String +end + +@data Sort begin + Scalar() + Form(dim::Int, isdual::Bool) + VField(isdual::Bool) +end +export Sort, Scalar, Form, VField + +const SORT_LOOKUP = Dict( + :Form0 => Form(0, false), + :Form1 => Form(1, false), + :Form2 => Form(2, false), + :DualForm0 => Form(0, true), + :DualForm1 => Form(1, true), + :DualForm2 => Form(2, true), + :Constant => Scalar() +) + +function Base.nameof(s::Scalar) + :Constant +end + +function Base.nameof(f::Form) + dual = isdual(f) ? "Dual" : "" + Symbol("$(dual)Form$(dim(f))") +end + +const VF = VField + +dim(ω::Form) = ω.dim +isdual(ω::Form) = ω.isdual + +isdual(v::VField) = v.isdual + +# convenience functions +PrimalForm(i::Int) = Form(i, false) +export PrimalForm + +DualForm(i::Int) = Form(i, true) +export DualForm + +PrimalVF() = VF(false) +export PrimalVF + +DualVF() = VF(true) +export DualVF + +# show methods +show_duality(ω::Form) = isdual(ω) ? "dual" : "primal" + +function Base.show(io::IO, ω::Form) + print(io, isdual(ω) ? "DualForm($(dim(ω)))" : "PrimalForm($(dim(ω)))") +end + +@nospecialize +function +(s1::Sort, s2::Sort) + @match (s1, s2) begin + (Scalar(), Scalar()) => Scalar() + (Scalar(), Form(i, isdual)) || + (Form(i, isdual), Scalar()) => Form(i, isdual) + (Form(i1, isdual1), Form(i2, isdual2)) => + if (i1 == i2) && (isdual1 == isdual2) + Form(i1, isdual1) + else + throw(SortError("Cannot add two forms of different dimensions/dualities: $((i1,isdual1)) and $((i2,isdual2))")) + end + end +end + +# Type-checking inverse of addition follows addition +-(s1::Sort, s2::Sort) = +(s1, s2) + +# TODO error for Forms + +# Negation is always valid +-(s::Sort) = s + +@nospecialize +function *(s1::Sort, s2::Sort) + @match (s1, s2) begin + (Scalar(), Scalar()) => Scalar() + (Scalar(), Form(i, isdual)) || + (Form(i, isdual), Scalar()) => Form(i, isdual) + (Form(_, _), Form(_, _)) => throw(SortError("Cannot scalar multiply a form with a form. Maybe try `∧`??")) + end +end + +const SUBSCRIPT_DIGIT_0 = '₀' + +function as_sub(n::Int) + join(map(d -> SUBSCRIPT_DIGIT_0 + d, digits(n))) +end + +@nospecialize +function ∧(s1::Sort, s2::Sort) + @match (s1, s2) begin + (Form(i, isdual), Scalar()) || (Scalar(), Form(i, isdual)) => Form(i, isdual) + (Form(i1, isdual), Form(i2, isdual)) => + if i1 + i2 <= 2 + Form(i1 + i2, isdual) + else + throw(SortError("Can only take a wedge product when the dimensions of the forms add to less than 2: tried to wedge product $i1 and $i2")) + end + _ => throw(SortError("Can only take a wedge product of two forms of the same duality")) + end +end + +function Base.nameof(::typeof(∧), s1, s2) + Symbol("∧$(as_sub(dim(s1)))$(as_sub(dim(s2)))") +end + +@nospecialize +∂ₜ(s::Sort) = s + +@nospecialize +function d(s::Sort) + @match s begin + Scalar() => throw(SortError("Cannot take exterior derivative of a scalar")) + Form(i, isdual) => + if i <= 1 + Form(i + 1, isdual) + else + throw(SortError("Cannot take exterior derivative of a n-form for n >= 1")) + end + end +end + +function Base.nameof(::typeof(d), s) + Symbol("d$(as_sub(dim(s)))") +end + +@nospecialize +function ★(s::Sort) + @match s begin + Scalar() => throw(SortError("Cannot take Hodge star of a scalar")) + Form(i, isdual) => Form(2 - i, !isdual) + end +end + +function Base.nameof(::typeof(★), s) + inv = isdual(s) ? "⁻¹" : "" + Symbol("★$(as_sub(isdual(s) ? 2 - dim(s) : dim(s)))$(inv)") +end + +@nospecialize +function ι(s1::Sort, s2::Sort) + @match (s1, s2) begin + (VF(true), Form(i, true)) => PrimalForm() # wrong + (VF(true), Form(i, false)) => DualForm() + _ => throw(SortError("Can only define the discrete interior product on: + PrimalVF, DualForm(i) + DualVF(), PrimalForm(i) + .")) + end +end + +# in practice, a scalar may be treated as a constant 0-form. +function ♯(s::Sort) + @match s begin + Scalar() => PrimalVF() + Form(1, isdual) => VF(isdual) + _ => throw(SortError("Can only take ♯ to 1-forms")) + end +end +# musical isos may be defined for any combination of (primal/dual) form -> (primal/dual) vf. + +function ♭(s::Sort) + @match s begin + VF(true) => PrimalForm(1) + _ => throw(SortError("Can only apply ♭ to dual vector fields")) + end +end + +# OTHER + +function ♭♯(s::Sort) + @match s begin + Form(i, isdual) => Form(i, !isdual) + _ => throw(SortError("♭♯ is only defined on forms.")) + end +end + +# Δ = ★d⋆d, but we check signature here to throw a more helpful error +function Δ(s::Sort) + @match s begin + Form(0, isdual) => Form(0, isdual) + _ => throw(SortError("Δ is not defined for $s")) + end +end + +const OPERATOR_LOOKUP = Dict( + :⋆₀ => ★, + :⋆₁ => ★, + :⋆₂ => ★, + + # Inverse Hodge Stars + :⋆₀⁻¹ => ★, + :⋆₁⁻¹ => ★, + :⋆₂⁻¹ => ★, + + # Differentials + :d₀ => d, + :d₁ => d, + + # Dual Differentials + :dual_d₀ => d, + :d̃₀ => d, + :dual_d₁ => d, + :d̃₁ => d, + + # Wedge Products + :∧₀₁ => ∧, + :∧₁₀ => ∧, + :∧₀₂ => ∧, + :∧₂₀ => ∧, + :∧₁₁ => ∧, + + # Primal-Dual Wedge Products + :∧ᵖᵈ₁₁ => ∧, + :∧ᵖᵈ₀₁ => ∧, + :∧ᵈᵖ₁₁ => ∧, + :∧ᵈᵖ₁₀ => ∧, + + # Dual-Dual Wedge Products + :∧ᵈᵈ₁₁ => ∧, + :∧ᵈᵈ₁₀ => ∧, + :∧ᵈᵈ₀₁ => ∧, + + # Dual-Dual Interior Products + :ι₁₁ => ι, + :ι₁₂ => ι, + + # Dual-Dual Lie Derivatives + # :ℒ₁ => ℒ, + + # Dual Laplacians + # :Δᵈ₀ => Δ, + # :Δᵈ₁ => Δ, + + # Musical Isomorphisms + :♯ => ♯, + :♯ᵈ => ♯, :♭ => ♭, + + # Averaging Operator + # :avg₀₁ => avg, + + # Negatives + :neg => -, + + # Basics + + :- => -, + :+ => +, + :* => *, + :/ => /, + :.- => .-, + :.+ => .+, + :.* => .*, + :./ => ./, +) + +end diff --git a/src/decasymbolic.jl b/src/decasymbolic.jl new file mode 100644 index 0000000..3b86fa6 --- /dev/null +++ b/src/decasymbolic.jl @@ -0,0 +1,194 @@ +module SymbolicUtilInterop + +using ..ThDEC +using MLStyle +import ..ThDEC: Sort, dim, isdual +using ..decapodes +using SymbolicUtils +using SymbolicUtils: Symbolic, BasicSymbolic + +abstract type DECType <: Number end + +""" +i: dimension: 0,1,2, etc. +d: duality: true = dual, false = primal +""" +struct FormT{i,d} <: DECType +end + +struct VFieldT{d} <: DECType +end + +dim(::Type{<:FormT{d}}) where {d} = d +isdual(::Type{FormT{i,d}}) where {i,d} = d + +# convenience functions +const PrimalFormT{i} = FormT{i,false} +export PrimalFormT + +const DualFormT{i} = FormT{i,true} +export DualFormT + +const PrimalVFT = VFieldT{false} +export PrimalVFT + +const DualVFT = VFieldT{true} +export DualVFT + +function Sort(::Type{FormT{i,d}}) where {i,d} + Form(i, d) +end + +function Number(f::Form) + FormT{dim(f),isdual(f)} +end + +function Sort(::Type{VFieldT{d}}) where {d} + VField(d) +end + +function Number(v::VField) + VFieldT{isdual(v)} +end + +function Sort(::Type{<:Real}) + Scalar() +end + +function Number(s::Scalar) + Real +end + +function Sort(::BasicSymbolic{T}) where {T} + Sort(T) +end + +function Sort(::Real) + Scalar() +end + +unop_dec = [:∂ₜ, :d, :★, :♯, :♭, :-] +for unop in unop_dec + @eval begin + @nospecialize + function ThDEC.$unop( + v::BasicSymbolic{T} + ) where {T<:DECType} + s = ThDEC.$unop(Sort(T)) + SymbolicUtils.Term{Number(s)}(ThDEC.$unop, [v]) + end + end +end + +binop_dec = [:+, :-, :*, :∧] +for binop in binop_dec + @eval begin + @nospecialize + function ThDEC.$binop( + v::BasicSymbolic{T1}, + w::BasicSymbolic{T2} + ) where {T1<:DECType,T2<:DECType} + s = ThDEC.$binop(Sort(T1), Sort(T2)) + SymbolicUtils.Term{Number(s)}(ThDEC.$binop, [v, w]) + end + + @nospecialize + function ThDEC.$binop( + v::BasicSymbolic{T1}, + w::BasicSymbolic{T2} + ) where {T1<:DECType,T2<:Real} + s = ThDEC.$binop(Sort(T1), Sort(T2)) + SymbolicUtils.Term{Number(s)}(ThDEC.$binop, [v, w]) + end + + @nospecialize + function ThDEC.$binop( + v::BasicSymbolic{T1}, + w::BasicSymbolic{T2} + ) where {T1<:Real,T2<:DECType} + s = ThDEC.$binop(Sort(T1), Sort(T2)) + SymbolicUtils.Term{Number(s)}(ThDEC.$binop, [v, w]) + end + end +end + +struct Equation{E} + lhs::E + rhs::E +end + +struct DecaSymbolic + vars::Vector{Symbolic} + equations::Vector{Equation{Symbolic}} +end + +function decapodes.Term(t::SymbolicUtils.BasicSymbolic) + if SymbolicUtils.issym(t) + decapodes.Var(nameof(t)) + else + op = SymbolicUtils.head(t) + args = SymbolicUtils.arguments(t) + termargs = Term.(args) + sorts = ThDEC.Sort.(args) + if op == + + decapodes.Plus(termargs) + elseif op == * + decapodes.Mult(termargs) + elseif op == ThDEC.∂ₜ + decapodes.Tan(only(termargs)) + elseif length(args) == 1 + decapodes.App1(nameof(op, sorts...), termargs...) + elseif length(args) == 2 + decapodes.App2(nameof(op, sorts...), termargs...) + else + error("was unable to convert $t into a Term") + end + end +end + +function decapodes.Term(x::Real) + decapodes.Lit(Symbol(x)) +end + +function decapodes.DecaExpr(d::DecaSymbolic) + context = map(d.vars) do var + decapodes.Judgement(nameof(var), nameof(Sort(var)), :I) + end + equations = map(d.equations) do eq + decapodes.Eq(decapodes.Term(eq.lhs), decapodes.Term(eq.rhs)) + end + decapodes.DecaExpr(context, equations) +end + +function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,Sort}, t::decapodes.Term) + @match t begin + Var(name) => SymbolicUtils.Sym{Number(context[name])}(name) + Lit(v) => Meta.parse(string(v)) # YOLO + AppCirc1(fs, arg) => foldr( + (f, x) -> ThDEC.OPERATOR_LOOKUP[f](x), + fs; + init=BasicSymbolic(context, arg) + ) + App1(f, x) => ThDEC.OPERATOR_LOOKUP[f](BasicSymbolic(context, x)) + App2(f, x, y) => ThDEC.OPERATOR_LOOKUP[f](BasicSymbolic(context, x), BasicSymbolic(context, y)) + Plus(xs) => +(BasicSymbolic.(Ref(context), xs)...) + Mult(xs) => *(BasicSymbolic.(Ref(context), xs)...) + Tan(x) => ThDEC.∂ₜ(BasicSymbolic(context, x)) + end +end + +function DecaSymbolic(d::decapodes.DecaExpr) + context = map(d.context) do j + j.var => ThDEC.SORT_LOOKUP[j.dim] + end + vars = map(context) do (v, s) + SymbolicUtils.Sym{Number(s)}(v) + end + context = Dict{Symbol,Sort}(context) + eqs = map(d.equations) do eq + Equation{Symbolic}(BasicSymbolic(context, eq.lhs), BasicSymbolic(context, eq.rhs)) + end + DecaSymbolic(vars, eqs) +end + +end