Skip to content

Commit

Permalink
adding @alias and @register macros to make DecaSymbolic function work…
Browse files Browse the repository at this point in the history
… in test/klausmeier
  • Loading branch information
quffaro committed Aug 27, 2024
1 parent 090ddfe commit d8be4ae
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 132 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Catlab = "134e5e36-593f-5add-ad60-77f754baafbe"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
StructEquality = "6ec83bb0-ed9f-11e9-3b4c-2b04cb4e219c"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
Expand All @@ -19,8 +20,8 @@ ACSets = "0.2"
Catlab = "0.15, 0.16"
DataStructures = "0.18.13"
MLStyle = "0.4.17"
Reexport = "1.2.2"
StructEquality = "2.1.0"
SymbolicUtils = "3.1.2"
Unicode = "1.6"
Reexport = "1.2.2"
julia = "1.6"
129 changes: 39 additions & 90 deletions src/ThDEC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,31 @@ using StructEquality

import Base: +, -, *

"""
Given a tuple of symbols ("aliases") and their canonical name (or "rep"), produces
for each alias typechecking and nameof methods which call those for their rep.
Example:
@alias d₀, d₁, += d
"""
macro alias(body)
(rep, aliases) = @match body begin
Expr(:tuple, rep, Expr(:tuple, aliases...)) => (rep, aliases)
_ => nothing
end
result = quote end
foreach(aliases) do alias
push!(result.args,
quote
function $(esc(alias))(s...)
$(esc(rep))(s...)
end
export $(esc(alias))
Base.nameof(::typeof($alias), s) = nameof($rep, s)
end)
end
result
end

struct SortError <: Exception
message::String
end
Expand All @@ -24,7 +49,9 @@ struct SpaceLookup
end
export SpaceLookup

SpaceLookup(default::Space) = SpaceLookup(default, Dict{Symbol, Space}(nameof(default) => default))
function SpaceLookup(default::Space)
SpaceLookup(default, Dict{Symbol, Space}(nameof(default) => default))
end

@data Sort begin
Scalar()
Expand Down Expand Up @@ -138,9 +165,7 @@ end

const SUBSCRIPT_DIGIT_0 = ''

function as_sub(n::Int)
join(map(d -> SUBSCRIPT_DIGIT_0 + d, digits(n)))
end
as_sub(n::Int) = join(map(d -> SUBSCRIPT_DIGIT_0 + d, digits(n)))

# TODO: VField
@nospecialize
Expand Down Expand Up @@ -180,20 +205,23 @@ function d(s::Sort)
end
end

function Base.nameof(::typeof(d), s)
Symbol("d$(as_sub(dim(s)))")
end
@alias d, (d₀, d₁)

Base.nameof(::typeof(d), s) = Symbol("d$(as_sub(dim(s)))")

@nospecialize
function (s::Sort)
function (s::Sort)
@match s begin
Scalar() => throw(SortError("Cannot take Hodge star of a scalar"))
VF(isdual, space) => throw(SortError("Cannot take the Hodge star of a vector field"))
Form(i, isdual, space) => Form(dim(space) - i, !isdual, space)
end
end
export

@alias , (₀, ₁, ₂, ₀⁻¹, ₁⁻¹, ₂⁻¹)

function Base.nameof(::typeof(), s)
function Base.nameof(::typeof(), s)
inv = isdual(s) ? "⁻¹" : ""
Symbol("$(as_sub(isdual(s) ? dim(space(s)) - dim(s) : dim(s)))$(inv)")
end
Expand All @@ -220,11 +248,7 @@ function ♯(s::Sort)
end
# musical isos may be defined for any combination of (primal/dual) form -> (primal/dual) vf.

# TODO
function Base.nameof(::typeof(♯), s)
Symbol("$s")
end

Base.nameof(::typeof(♯), s) = Symbol("$s")

function (s::Sort)
@match s begin
Expand All @@ -233,10 +257,7 @@ function ♭(s::Sort)
end
end

# TODO
function Base.nameof(::typeof(♭), s)
Symbol("$s")
end
Base.nameof(::typeof(♭), s) = Symbol("$s")

# OTHER

Expand All @@ -259,76 +280,4 @@ end

Base.nameof(::typeof(Δ), s) = Symbol("Δ")

const OPERATOR_LOOKUP = Dict(
:=> ★,
:=> ★,
:=> ★,

# Inverse Hodge Stars
:₀⁻¹ => ★,
:₁⁻¹ => ★,
:₂⁻¹ => ★,

# Differentials
:d₀ => d,
:d₁ => d,

# Dual Differentials
:dual_d₀ => d,
:d̃₀ => d,
:dual_d₁ => d,
:d̃₁ => d,

# Wedge Products
:₀₁ => ,
:₁₀ => ,
:₀₂ => ,
:₂₀ => ,
:₁₁ => ,

# Primal-Dual Wedge Products
:ᵖᵈ₁₁ => ,
:ᵖᵈ₀₁ => ,
:ᵈᵖ₁₁ => ,
:ᵈᵖ₁₀ => ,

# Dual-Dual Wedge Products
:ᵈᵈ₁₁ => ,
:ᵈᵈ₁₀ => ,
:ᵈᵈ₀₁ => ,

# Dual-Dual Interior Products
:ι₁₁ => ι,
:ι₁₂ => ι,

# Dual-Dual Lie Derivatives
# :ℒ₁ => ℒ,
# :L => ℒ,

# Dual Laplacians
# :Δᵈ₀ => Δ,
# :Δᵈ₁ => Δ,

# Musical Isomorphisms
:♯ => ♯,
:♯ᵈ => ♯, :♭ => ♭,

# Averaging Operator
# :avg₀₁ => avg,

# Negatives
:neg => -,

# Basics

:- => -,
:+ => +,
:* => *,
:/ => /,
:.- => .-,
:.+ => .+,
:.* => .*,
:./ => ./,
)

end
97 changes: 78 additions & 19 deletions src/decasymbolic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ module SymbolicUtilsInterop
using ..DiagrammaticEquations: AbstractDecapode
import ..DiagrammaticEquations: eval_eq!, SummationDecapode
using ..ThDEC
using MLStyle
import ..ThDEC: Sort, dim, isdual
using ..decapodes
using SymbolicUtils

using SymbolicUtils: Symbolic, BasicSymbolic
using MLStyle
using SymbolicUtils
using SymbolicUtils: Symbolic, BasicSymbolic, FnType, Sym

# ##########################
# DECType
Expand Down Expand Up @@ -65,17 +65,6 @@ end

Sort(::BasicSymbolic{T}) where {T} = Sort(T)

# converts a sort to its Julia symbol
function to_symb(sort::Sort)
@match sort begin
Scalar() => :Constant
Form(i, isdual, X) =>
Symbol("$(isdual ? "Dual" : "")Form$i")
VField(isdual, X) =>
Symbol("$(isdual ? "Dual" : "")VF")
end
end

"""
converts ThDEC Sorts into DecaSymbolic types
"""
Expand All @@ -85,8 +74,13 @@ Number(f::Form) = FormT{dim(f),isdual(f), nameof(space(f)), dim(space(f))}

Number(v::VField) = VFieldT{isdual(v), nameof(space(v)), dim(space(v))}

# HERE WE DEFINE THE SYMBOLICUTILS

# for every unary operator in our theory, take a BasicSymbolic type, convert its type parameter to a Sort in our theory, and return a term
unop_dec = [:∂ₜ, :d, :★, :♯, :♭, :-]
unop_dec = [:∂ₜ, :d, :d₀, :d₁
, :, :₀, :₁, :₂, :₀⁻¹, :₁⁻¹, :₂⁻¹
, :♯, :♭, :-]

for unop in unop_dec
@eval begin
@nospecialize
Expand All @@ -99,6 +93,8 @@ for unop in unop_dec
end
end

# BasicSymbolic{FnType{Tuple{PrimalFormT{0}}}, PrimalFormT{0}}

binop_dec = [:+, :-, :*, :, :^]
export +,-,*,,^

Expand Down Expand Up @@ -211,6 +207,8 @@ Example:
```
"""
function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,Sort}, t::decapodes.Term)
# user must import symbols into scope
! = (f -> getfield(Main, f))
@match t begin
Var(name) => SymbolicUtils.Sym{Number(context[name])}(name)
Lit(v) => Meta.parse(string(v))
Expand All @@ -219,12 +217,13 @@ function SymbolicUtils.BasicSymbolic(context::Dict{Symbol,Sort}, t::decapodes.Te
AppCirc1(fs, arg) => foldr(
# panics with constants like :k
# see test/language.jl
(f, x) -> ThDEC.OPERATOR_LOOKUP[f](x),
(f, x) -> (!(f))(x),
fs;
init=BasicSymbolic(context, arg)
)
App1(f, x) => ThDEC.OPERATOR_LOOKUP[f](BasicSymbolic(context, x))
App2(f, x, y) => ThDEC.OPERATOR_LOOKUP[f](BasicSymbolic(context, x), BasicSymbolic(context, y))
# getfield(Main,
App1(f, x) => (!(f))(BasicSymbolic(context, x))
App2(f, x, y) => (!(f))(BasicSymbolic(context, x), BasicSymbolic(context, y))
Plus(xs) => +(BasicSymbolic.(Ref(context), xs)...)
Mult(xs) => *(BasicSymbolic.(Ref(context), xs)...)
Tan(x) => ThDEC.∂ₜ(BasicSymbolic(context, x))
Expand Down Expand Up @@ -258,7 +257,7 @@ function SummationDecapode(e::DecaSymbolic)

foreach(e.vars) do var
# convert Sort(var)::PrimalForm0 --> :Form0
var_id = add_part!(d, :Var, name=var.name, type=to_symb(Sort(var)))
var_id = add_part!(d, :Var, name=var.name, type=nameof(Sort(var)))
symbol_table[var.name] = var_id
end

Expand All @@ -276,6 +275,66 @@ function SummationDecapode(e::DecaSymbolic)
return d
end

"""
Registers a new function
```
@register Δ(s::Sort) begin
@match s begin
::Scalar => error("Invalid")
::VField => error("Invalid")
::Form => ⋆(d(⋆(d(s))))
end
end
```
will create an additional method for Δ for operating on BasicSymbolic
"""
macro register(head, body)
# parse head
parsehead = begin
Expr(:call, f, types...) => (f, parsehead.(types))
Expr(:(::), var, type) => (var, type)
s => s
end
(f, args) = parsehead(head)
matchargs = [:($(x[1])::$(x[2])) for x in args]

result = quote end
push!(result.args,
esc(quote
function $f($(matchargs...))
$body
end
end))

# e.g., given [(:x, :Scalar), (:ω, :Form)]...
vs = enumerate(unique(getindex.(args, 2)))
theargs =
Dict{Symbol,Symbol}(
[v => Symbol("T$k") for (k,v) in vs]
)
# ...[(Scalar=>:T1, :Form=>:T2)]

# reassociate vars with their BasicSymbolic Generic Types
binding = map(args) do (var, type)
(var, :(BasicSymbolic{$(theargs[type])}))
end
newargs = [:($(x[1])::$(x[2])) for x in binding]
constraints = [:($T<:DECType) for T in values(theargs)]
innerargs = [:(Sort($T)) for T in values(theargs)]

push!(result.args,
quote
@nospecialize
function $(esc(f))($(newargs...)) where $(constraints...)
s = $(esc(f))($(innerargs...))
SymbolicUtils.Term{Number(s)}($(esc(f)), [$(getindex.(binding, 1)...)])
end
end)

return result
end
export @register

end
Loading

0 comments on commit d8be4ae

Please sign in to comment.