Skip to content

Commit

Permalink
Merge pull request #1345 from n0rbed/tests
Browse files Browse the repository at this point in the history
Tests for solve_interms_ofvar and bug fixes
  • Loading branch information
ChrisRackauckas authored Nov 5, 2024
2 parents 705b61b + 62d8d0e commit d16bc03
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/solver/ia_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ function solve_interms_ofvar(eq, s; dropmultiplicity=true, warns=true)
coeffs, constant = polynomial_coeffs(eq, [s])
eqs = wrap.(collect(values(coeffs)))

solve_multivar(eqs, vars, dropmultiplicity=dropmultiplicity, warns=warns)
symbolic_solve(eqs, vars, dropmultiplicity=dropmultiplicity, warns=warns)
end

# an attempt at using ia_solve recursively.
Expand Down
8 changes: 3 additions & 5 deletions src/solver/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ function solve_univar(expression, x; dropmultiplicity=true)
factors_subbed = map(factor -> ssubs(factor, subs), factors)
arr_roots = []

if degree < 5 && length(factors) == 1
if degree < 5 && isequal(factors_subbed[1], wrap(expression))
arr_roots = get_roots(expression, x)

# multiplicities (repeated roots)
Expand All @@ -296,10 +296,8 @@ function solve_univar(expression, x; dropmultiplicity=true)
append!(arr_roots, og_arr_roots)
end
end
end

if length(factors) != 1
for i in eachindex(factors_subbed)
elseif length(factors) > 1 || (length(factors) == 1 && !isequal(factors_subbed[1], wrap(expression)))
for i in eachindex(factors_subbed)
if !any(isequal(x, var) for var in get_variables(factors[i]))
continue
end
Expand Down
19 changes: 18 additions & 1 deletion test/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,24 @@ function check_approx(arr1, arr2)
return true
end

@variables x y z a b c d e
@variables x y z a b c d e s

@testset "Solving in terms of a constant var" begin
eq = ((s^2 + 1)/(s^2 + 2*s + 1)) - ((s^2 + a)/(b*c*s^2 + (b+c)*s + d))
calcd_roots = sort_arr(Symbolics.solve_interms_ofvar(eq, s), [a,b,c,d])
known_roots = sort_arr([Dict(a=>1, b=>1, c=>1, d=>1)], [a,b,c,d])
@test check_approx(calcd_roots, known_roots)

eq = (a+b)*s^2 - 2s^2 + 2*b*s - 3*s
calcd_roots = sort_arr(Symbolics.solve_interms_ofvar(eq, s), [a,b])
known_roots = sort_arr([Dict(a=>1/2, b=>3/2)], [a,b])
@test check_approx(calcd_roots, known_roots)

eq = (a*x^2+b)*s^2 - 2s^2 + 2*b*s - 3*s + 2(x^2)*(s^3) + 10*s^3
calcd_roots = sort_arr(Symbolics.solve_interms_ofvar(eq, s), [a,b])
known_roots = sort_arr([Dict(a=>-1/10, b=>3/2, x=>-im*sqrt(5)), Dict(a=>-1/10, b=>3/2, x=>im*sqrt(5))], [a,b,x])
@test check_approx(calcd_roots, known_roots)
end

@testset "Invalid input" begin
@test_throws AssertionError symbolic_solve(x, x^2)
Expand Down

0 comments on commit d16bc03

Please sign in to comment.