diff --git a/src/api.jl b/src/api.jl index ff389aaf6..4dc261fa7 100644 --- a/src/api.jl +++ b/src/api.jl @@ -60,3 +60,24 @@ function substitute(expr, dict; fold=true) expr end end + +""" + occursin(needle::Symbolic, haystack::Symbolic) + +Determine whether the second argument contains the first argument. Note that +this function doesn't handle associativity, commutativity, or distributivity. +""" +Base.occursin(needle::Symbolic, haystack::Symbolic) = _occursin(needle, haystack) +Base.occursin(needle, haystack::Symbolic) = _occursin(needle, haystack) +Base.occursin(needle::Symbolic, haystack) = _occursin(needle, haystack) +function _occursin(needle, haystack) + isequal(needle, haystack) && return true + + if istree(haystack) + args = arguments(haystack) + for arg in args + occursin(needle, arg) && return true + end + end + return false +end diff --git a/test/basics.jl b/test/basics.jl index 4f4c5f72a..1b6b4bb6c 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -127,6 +127,13 @@ end @test substitute(exp(a), Dict(a=>2)) ≈ exp(2) end +@testset "occursin" begin + @syms a b c + @test occursin(a, a + b) + @test !occursin(sin(a), a + b + c) + @test occursin(sin(a), a * b + c + sin(a^2 * sin(a))) +end + @testset "printing" begin @syms a b c @test repr(a+b) == "a + b"