Skip to content

Commit

Permalink
Merge pull request #3233 from JuliaLang/anj/psifn
Browse files Browse the repository at this point in the history
Add Julia translation of psifn from SLATEC and some unicode constants
  • Loading branch information
johnmyleswhite committed Jun 4, 2013
2 parents 0db20cb + b19013d commit b51ad62
Show file tree
Hide file tree
Showing 4 changed files with 279 additions and 69 deletions.
5 changes: 5 additions & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,11 @@ export
VERSION,
WORD_SIZE,
e,
eulergamma,
γ,
im,
pi,
π,

# Operators
!,
Expand Down Expand Up @@ -413,6 +416,7 @@ export
tanh,
trailing_ones,
trailing_zeros,
trigamma,
trunc,
uint,
uint128,
Expand Down Expand Up @@ -448,6 +452,7 @@ export
beta,
lbeta,
eta,
polygamma,
zeta,

# arrays
Expand Down
3 changes: 3 additions & 0 deletions base/float.jl
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,9 @@ sizeof(::Type{Float64}) = 8

const e = 2.71828182845904523536
const pi = 3.14159265358979323846
const π = pi
const euler_mascheroni = 0.57721566490153286061
const γ = euler_mascheroni

## byte order swaps for arbitrary-endianness serialization/deserialization ##
bswap(x::Float32) = box(Float32,bswap_int(unbox(Float32,x)))
Expand Down
314 changes: 252 additions & 62 deletions base/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ export sin, cos, tan, sinh, cosh, tanh, asin, acos, atan,
airy, airyai, airyprime, airyaiprime, airybi, airybiprime,
besselj0, besselj1, besselj, bessely0, bessely1, bessely,
hankelh1, hankelh2, besseli, besselk, besselh,
beta, lbeta, eta, zeta, digamma,
beta, lbeta, eta, zeta, polygamma, digamma, trigamma,
erfinv, erfcinv

import Base.log, Base.exp, Base.sin, Base.cos, Base.tan, Base.sinh, Base.cosh,
Expand Down Expand Up @@ -492,79 +492,269 @@ end

gamma(z::Complex) = exp(lgamma(z))

# Translation of psi.c from cephes
const digamma_EUL = 0.57721566490153286061
const digamma_coefs = [8.33333333333333333333e-2,-2.10927960927960927961e-2, 7.57575757575757575758e-3,
-4.16666666666666666667e-3, 3.96825396825396825397e-3,-8.33333333333333333333e-3,
8.33333333333333333333e-2]

function digamma(x::Float64)
negative = false
nz = 0.0

if x <= 0.0
negative = true
q = x
p = floor(q)
if p == q
return NaN
end

nz = q - p
if nz != 0.5
if nz > 0.5
p += 1.0
nz = q - p
# Derivatives of the digamma function
i1mach(i::Int) = ccall(("i1mach_", Math.openlibm_extras), Int32, (Ptr{Int32},), &i)
d1mach(i::Int) = ccall(("d1mach_", Math.openlibm_extras), Float64, (Ptr{Int32},), &i)
function psifn(x::Float64, n::Int, kode::Int, m::Int)
# Translated from http://www.netlib.org/slatec/src/dpsifn.f
# Note: Underflow handling at 380 in original is skipped
const nmax = 100
ans = Array(Float64, m)
#-----------------------------------------------------------------------
# bernoulli numbers
#-----------------------------------------------------------------------
const b = [1.00000000000000000e+00,
-5.00000000000000000e-01,1.66666666666666667e-01,
-3.33333333333333333e-02,2.38095238095238095e-02,
-3.33333333333333333e-02,7.57575757575757576e-02,
-2.53113553113553114e-01,1.16666666666666667e+00,
-7.09215686274509804e+00,5.49711779448621554e+01,
-5.29124242424242424e+02,6.19212318840579710e+03,
-8.65802531135531136e+04,1.42551716666666667e+06,
-2.72982310678160920e+07,6.01580873900642368e+08,
-1.51163157670921569e+10,4.29614643061166667e+11,
-1.37116552050883328e+13,4.88332318973593167e+14,
-1.92965793419400681e+16]
trm = Array(Float64, 22)
trmr = Array(Float64, 100)
#***first executable statement dpsifn
if x <= 0.0 throw(DomainError()) end
if n < 0 error("n must be non-negative") end
if kode < 1 | kode > 2 error("kode must be one or two") end
if m < 1 error("m must be larger than one") end
mm = m
const nx = min(-i1mach(15),i1mach(16))
const r1m5 = d1mach(5)
const r1m4 = d1mach(4)*0.5
const wdtol = max(r1m4, 0.5e-18)
#-----------------------------------------------------------------------
# elim = approximate exponential over and underflow limit
#-----------------------------------------------------------------------
const elim = 2.302*(nx*r1m5 - 3.0)
xln = log(x)
nn = n + mm - 1
fn = nn
t = (fn + 1)*xln
#-----------------------------------------------------------------------
# overflow and underflow test for small and large x
#-----------------------------------------------------------------------
if abs(t) > elim
if t <= 0.0 error("n too large") end
error("Overflow, x too small or n+m-1 too large or both")
end
if x < wdtol
ans[1] = x^(-n - 1)
if mm != 1
k = 1
for i = 2:mm
ans[k + 1] = ans[k]/x
k += 1
end
nz = pi / tan(pi * nz)
else
nz = 0.0
end
x = 1.0 - x
if n != 0 return ans end
if kode == 2 ans[1] = ans[1] + xln end
return ans
end

if x <= 10.0 && x == floor(x)
y = 0.0
for i = 1:x-1
y += 1.0 / i
end
y -= digamma_EUL

if negative
y -= nz
#-----------------------------------------------------------------------
# compute xmin and the number of terms of the series, fln+1
#-----------------------------------------------------------------------
rln = r1m5*i1mach(14)
rln = min(rln, 18.06)
fln = max(rln, 3.0) - 3.0
yint = 3.50 + 0.40*fln
slope = 0.21 + fln*(0.0006038*fln + 0.008677)
xm = yint + slope*fn
mx = itrunc(xm) + 1
xmin = mx
if n != 0
xm = -2.302*rln - min(0.0,xln)
arg = xm/n
arg = min(0.0,arg)
eps = exp(arg)
xm = 1.0 - eps
if abs(arg) < 1.0e-3 xm = -arg end
fln = x*xm/eps
xm = xmin - x
if (xm > 7.0) & (fln < 15.0)
nn = itrunc(fln) + 1
np = n + 1
t1 = (n + 1)*xln
t = exp(-t1)
s = t
den = x
for i = 1:nn
den += 1.0
trm[i] = den^(-np)
s += trm[i]
end
ans[1] = s
if n == 0
if kode == 2 ans[1] = s + xln end
end
if mm == 1 return ans end
#-----------------------------------------------------------------------
# generate higher derivatives, j.gt.n
#-----------------------------------------------------------------------
tol = wdtol/5.0
for j = 2:mm
t = t/x
s = t
tols = t*tol
den = x
for i = 1:nn
den += 1.0
trm[i] = trm[i]/den
s += trm[i]
if trm[i] < tols break end
end
ans[j] = s
end
return ans
end
return y
end

w = 0.0
while x < 10.0
w += 1.0 / x
x += 1.0

xdmy = x
xdmln = xln
xinc = 0.0
if x < xmin
nx = itrunc(x)
xinc = xmin - nx
xdmy = x + xinc
xdmln = log(xdmy)
end

if x < 1.0e17
z = 1.0 / (x*x)
y = digamma_coefs[1]
for j = 2:7
y = y*z + digamma_coefs[j]
#-----------------------------------------------------------------------
# generate w(n+mm-1,x) by the asymptotic expansion
#-----------------------------------------------------------------------
t = fn*xdmln
t1 = xdmln + xdmln
t2 = t + xdmln
tk = max(abs(t), abs(t1), abs(t2))
if tk > elim error("Underflow") end
tss = exp(-t)
tt = 0.5/xdmy
t1 = tt
tst = wdtol*tt
if nn != 0 t1 = tt + 1.0/fn end
rxsq = 1.0/(xdmy*xdmy)
ta = 0.5*rxsq
t = (fn + 1)*ta
s = t*b[3]
if abs(s) >= tst
tk = 2.0
for k = 4:22
t = t*((tk + fn + 1)/(tk + 1.0))*((tk + fn)/(tk + 2.0))*rxsq
trm[k] = t*b[k]
if abs(trm[k]) < tst break end
s += trm[k]
tk += 2.0
end
y *= z
else
y = 0.0
end

y = log(x) - 0.5/x - y - w

if negative
y -= nz
s = (s + t1)*tss
while true
if xinc != 0.0
#-----------------------------------------------------------------------
# backward recur from xdmy to x
#-----------------------------------------------------------------------
nx = itrunc(xinc)
np = nn + 1
if nx > nmax error("n too large") end
if nn == 0 break end
xm = xinc - 1.0
fx = x + xm
#-----------------------------------------------------------------------
# this loop should not be changed. fx is accurate when x is small
#-----------------------------------------------------------------------
for i = 1:nx
trmr[i] = fx^(-np)
s += trmr[i]
xm -= 1.0
fx = x + xm
end
end
ans[mm] = s
if fn == 0
if kode != 2
ans[1] = s - xdmln
return ans
end
if xdmy == x return ans end
xq = xdmy/x
ans[1] = s - log(xq)
return ans
end
#-----------------------------------------------------------------------
# generate lower derivatives, j.lt.n+mm-1
#-----------------------------------------------------------------------
if mm == 1 return ans end
for j = 2:mm
fn -= 1
tss *= xdmy
t1 = tt
if fn != 0 t1 = tt + 1.0/fn end
t = (fn + 1)*ta
s = t*b[3]
if abs(s) >= tst
tk = 4 + fn
for k = 4:22 #110
trm[k] = trm[k]*(fn + 1)/tk
if abs(trm[k]) < tst break end
s += trm[k]
tk += 2.0
end
end
s = (s + t1)*tss
if xinc != 0.0
if fn == 0 break end
xm = xinc - 1.0
fx = x + xm
for i = 1:nx
trmr[i] = trmr[i]*fx
s += trmr[i]
xm -= 1.0
fx = x + xm
end
end
mx = mm - j + 1
ans[mx] = s
if fn == 0
if kode != 2
ans[1] = s - xdmln
return ans
end
if xdmy == x return ans end
xq = xdmy/x
ans[1] = s - log(xq)
return ans
end
end
if fn == 0 break end
return ans
end

return y
#-----------------------------------------------------------------------
# recursion for n = 0
#-----------------------------------------------------------------------
for i = 1:nx
s += 1.0/(x + nx - i)
end
if kode != 2
ans[1] = s - xdmln
return ans
end
if xdmy == x return ans end
xq = xdmy/x
ans[1] = s - log(xq)
return ans
end
digamma(x::Float32) = float32(digamma(float64(x)))
digamma(x::Real) = digamma(float64(x))
polygamma(k::Int, x::Float64) = (2rem(k,2) - 1)*psifn(x, k, 1, 1)[1]/gamma(k + 1)
polygamma(k::Int, x::Float32) = float32(polygamma(k, float64(x)))
polygamma(k::Int, x::Real) = polygamma(k, float64(x))

digamma(x::Real) = polygamma(0, x)
@vectorize_1arg Real digamma

trigamma(x::Real) = polygamma(1, x)
@vectorize_1arg Real trigamma

beta(x::Number, w::Number) = exp(lgamma(x)+lgamma(w)-lgamma(x+w))
lbeta(x::Number, w::Number) = lgamma(x)+lgamma(w)-lgamma(x+w)
@vectorize_2arg Number beta
Expand Down
Loading

0 comments on commit b51ad62

Please sign in to comment.