Skip to content

Commit

Permalink
WIP moshi
Browse files Browse the repository at this point in the history
  • Loading branch information
akirakyle committed Nov 20, 2024
1 parent 4d86901 commit fc753e0
Show file tree
Hide file tree
Showing 15 changed files with 437 additions and 199 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07"
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Moshi = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d"
MultivariatePolynomials = "102ac46a-7ee4-5c85-9060-abc95bfdeaa3"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Expand All @@ -25,7 +26,6 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
Unityper = "a7c27f48-0311-42f6-a7f8-2c11e75eb415"
WeakValueDicts = "897b6980-f191-5a31-bcb0-bf3c4585e0c1"

[weakdeps]
Expand All @@ -48,6 +48,7 @@ DocStringExtensions = "0.8, 0.9"
DynamicPolynomials = "0.5, 0.6"
IfElse = "0.1"
LabelledArrays = "1.5"
Moshi = "0.3.5"
MultivariatePolynomials = "0.5"
NaNMath = "0.3, 1"
ReverseDiff = "1"
Expand All @@ -57,7 +58,6 @@ StaticArrays = "0.12, 1.0"
SymbolicIndexingInterface = "0.3"
TermInterface = "2.0"
TimerOutputs = "0.5"
Unityper = "0.1.2"
WeakValueDicts = "0.1.0"
julia = "1.3"

Expand Down
3 changes: 2 additions & 1 deletion src/SymbolicUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ using DocStringExtensions

export @syms, term, showraw, hasmetadata, getmetadata, setmetadata

using Unityper
using Moshi.Data: @data, data_type_name, variant_name
using Moshi.Match: @match
using TermInterface
using DataStructures
using Setfield
Expand Down
10 changes: 6 additions & 4 deletions src/code.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr,

import ..SymbolicUtils
import ..SymbolicUtils.Rewriters
import SymbolicUtils: @matchable, BasicSymbolic, Sym, Term, iscall, operation, arguments, issym,
symtype, sorted_arguments, metadata, isterm, term, maketerm
import SymbolicUtils: @matchable, BasicSymbolicType, Sym, Term, iscall, operation, arguments, issym,
isconst, symtype, sorted_arguments, metadata, isterm, term, maketerm
import SymbolicIndexingInterface: symbolic_type, NotSymbolic

##== state management ==##
Expand Down Expand Up @@ -156,7 +156,7 @@ function function_to_expr(::typeof(SymbolicUtils.ifelse), O, st)
:($(toexpr(args[1], st)) ? $(toexpr(args[2], st)) : $(toexpr(args[3], st)))
end

function function_to_expr(x::BasicSymbolic, O, st)
function function_to_expr(x::BasicSymbolicType, O, st)
issym(x) ? get(st.rewrites, O, nothing) : nothing
end

Expand All @@ -182,6 +182,8 @@ function toexpr(O, st)
if issym(O)
O = substitute_name(O, st)
return issym(O) ? nameof(O) : toexpr(O, st)
elseif isconst(O)
return toexpr(O.val, st)
end
O = substitute_name(O, st)

Expand Down Expand Up @@ -766,7 +768,7 @@ end
function cse_block(state, t, name=Symbol("var-", hash(t)))
assignments = Assignment[]
counter = Ref{Int}(1)
names = Dict{Any, BasicSymbolic}()
names = Dict{Any, BasicSymbolicType}()
Let(assignments, cse_block!(assignments, counter, names, name, state, t))
end

Expand Down
2 changes: 1 addition & 1 deletion src/inspect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ function AbstractTrees.nodevalue(x::Symbolic)
iscall(x) ? operation(x) : isexpr(x) ? head(x) : x
end

function AbstractTrees.nodevalue(x::BasicSymbolic)
function AbstractTrees.nodevalue(x::BasicSymbolicType)
str = if !iscall(x)
string(exprtype(x), "(", x, ")")
elseif isadd(x)
Expand Down
18 changes: 16 additions & 2 deletions src/matchers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,23 @@
# 3. Callback: takes arguments Dictionary × Number of elements matched
#
function matcher(val::Any)
iscall(val) && return term_matcher(val)
if isconst(val)
slot = val.val
return matcher(slot)
elseif iscall(val)
return term_matcher(val)
end
function literal_matcher(next, data, bindings)
islist(data) && isequal(car(data), val) ? next(bindings, 1) : nothing
if islist(data)
cd = car(data)
if isconst(cd)
cd = cd.val
end
if isequal(cd, val)
return next(bindings, 1)
end
end
nothing
end
end

Expand Down
6 changes: 3 additions & 3 deletions src/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ macro number_methods(T, rhs1, rhs2, options=nothing)
number_methods(T, rhs1, rhs2, options) |> esc
end

@number_methods(BasicSymbolic{<:Number}, term(f, a), term(f, a, b), skipbasics)
@number_methods(BasicSymbolic{<:LiteralReal}, term(f, a), term(f, a, b), onlybasics)
@number_methods(BasicSymbolicType{<:Number}, term(f, a), term(f, a, b), skipbasics)
@number_methods(BasicSymbolicType{<:LiteralReal}, term(f, a), term(f, a, b), onlybasics)

for f in vcat(diadic, [+, -, *, \, /, ^])
@eval promote_symtype(::$(typeof(f)),
Expand Down Expand Up @@ -188,7 +188,7 @@ end
for f in [!, ~]
@eval begin
promote_symtype(::$(typeof(f)), ::Type{<:Bool}) = Bool
(::$(typeof(f)))(s::Symbolic{Bool}) = Term{Bool}(!, [s])
(::$(typeof(f)))(s::Symbolic{Bool}) = isconst(s) ? !s.val : Term{Bool}(!, [s])
end
end

Expand Down
9 changes: 6 additions & 3 deletions src/ordering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ function get_degrees(expr)
elseif iscall(expr)
op = operation(expr)
args = sorted_arguments(expr)
if op == (^) && args[2] isa Number
if op == (^) && (args[2] isa Number || (isconst(args[2]) && args[2].val isa Number))
return map(get_degrees(args[1])) do (base, pow)
(base => pow * args[2])
end
Expand Down Expand Up @@ -78,13 +78,16 @@ function <ₑ(a::Tuple, b::Tuple)
return length(a) < length(b)
end

function <(a::BasicSymbolic, b::BasicSymbolic)
function <(a::BasicSymbolicType, b::BasicSymbolicType)
isconst(a) && isconst(b) && return a.val <ₑ b.val
isconst(a) && return a.val <ₑ b
isconst(b) && return a <ₑ b.val
da, db = get_degrees(a), get_degrees(b)
fw = monomial_lt(da, db)
bw = monomial_lt(db, da)
if fw === bw && !isequal(a, b)
if _arglen(a) == _arglen(b)
return (operation(a), arguments(a)...,) <ₑ (operation(b), arguments(b)...,)
return (operation(a), arguments(a)...) <ₑ (operation(b), arguments(b)...)
else
return _arglen(a) < _arglen(b)
end
Expand Down
5 changes: 3 additions & 2 deletions src/polyform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ PolyForm(sin((x+y)^2), recurse=true) #=> sin((x^2 + (2x)y + y^2))
struct PolyForm{T} <: Symbolic{T}
p::MP.AbstractPolynomialLike
pvar2sym::Bijection{Any,Any} # @polyvar x --> @sym x etc.
sym2term::Dict{BasicSymbolic,Any} # Symbol("sin-$hash(sin(x+y))") --> sin(x+y) => sin(PolyForm(...))
sym2term::Dict{BasicSymbolicType,Any} # Symbol("sin-$hash(sin(x+y))") --> sin(x+y) => sin(PolyForm(...))
metadata
function (::Type{PolyForm{T}})(p, d1, d2, m=nothing) where {T}
p isa Number && return p
Expand Down Expand Up @@ -63,7 +63,7 @@ end
function get_sym2term()
v = SYM2TERM[].value
if v === nothing
d = Dict{BasicSymbolic,Any}()
d = Dict{BasicSymbolicType,Any}()
SYM2TERM[] = WeakRef(d)
return d
else
Expand Down Expand Up @@ -95,6 +95,7 @@ end
_isone(p::PolyForm) = isone(p.p)

function polyize(x, pvar2sym, sym2term, vtype, pow, Fs, recurse)
x = isconst(x) ? x.val : x
if x isa Number
return x
elseif iscall(x)
Expand Down
2 changes: 2 additions & 0 deletions src/substitute.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ function substitute(expr, dict; fold=true)
canfold = !(op isa Symbolic)
args = map(arguments(expr)) do x
x′ = substitute(x, dict; fold=fold)
x′ = isconst(x) ? x′.val : x′
canfold = canfold && !(x′ isa Symbolic)
x′
end
Expand Down Expand Up @@ -54,6 +55,7 @@ function _occursin(needle, haystack)
if iscall(haystack)
args = arguments(haystack)
for arg in args
arg = isconst(arg) ? arg.val : arg
if needle isa Integer || needle isa AbstractFloat
isequal(needle, arg) && return true
else
Expand Down
Loading

0 comments on commit fc753e0

Please sign in to comment.