Merge pull request #1331 from pymc-devs/pep8
STY Ran autopep8 on full code-base.
springcoil authored Sep 6, 2016
2 parents d092dd8 + 7a0bdb5 · commit a29f7f1
Showing 104 changed files with 1,289 additions and 894 deletions.
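
The commit message only records that autopep8 was run; the exact command line is not preserved in the history. As a rough, hypothetical sketch of the kinds of rewrites visible in the hunks below (double-hash block comments normalized to a single hash, a blank line inserted after a class docstring, spaces added around operators, long lines wrapped), autopep8 can also be driven from Python. The snippet and options here are illustrative assumptions, not the invocation used for this commit:

# Hypothetical sketch (not the command recorded in this commit): applying
# autopep8 programmatically to a snippet with the same kinds of PEP 8
# violations that the hunks below correct.
import autopep8

messy = (
    "class BaseTrace(object):\n"
    '    """Example class."""\n'
    "    def __init__(self, name):\n"
    "        ## store the backend name\n"
    "        self.name=name\n"
)

# fix_code() returns the reformatted source as a string; the 'aggressive'
# option enables fixes beyond the default whitespace-only changes.
clean = autopep8.fix_code(messy, options={'aggressive': 1})
print(clean)

On the command line, the equivalent would be something like autopep8 --in-place --recursive pymc3/ (again an assumption; the actual flags are not recorded in the commit).
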
11 changes: 6 additions & 5 deletions pymc3/backends/base.py
@@ -24,6 +24,7 @@ class BaseTrace(object):
Sampling values will be stored for these variables. If None,
`model.unobserved_RVs` is used.
"""

def __init__(self, name, model=None, vars=None):
self.name = name

@@ -35,17 +36,16 @@ def __init__(self, name, model=None, vars=None):
self.varnames = [var.name for var in vars]
self.fn = model.fastfn(vars)


## Get variable shapes. Most backends will need this
## information.
# Get variable shapes. Most backends will need this
# information.
var_values = list(zip(self.varnames, self.fn(model.test_point)))
self.var_shapes = {var: value.shape
for var, value in var_values}
self.var_dtypes = {var: value.dtype
for var, value in var_values}
self.chain = None

## Sampling methods
# Sampling methods

def setup(self, draws, chain):
"""Perform chain-specific setup.
@@ -76,7 +76,7 @@ def close(self):
"""
pass

## Selection methods
# Selection methods

def __getitem__(self, idx):
if isinstance(idx, slice):
@@ -149,6 +149,7 @@ class MultiTrace(object):
of the MultiTrace instance, which returns the number of draws), the
trace with the highest chain number is always used.
"""

def __init__(self, straces):
self._straces = {}
for strace in straces:
9 changes: 5 additions & 4 deletions pymc3/backends/ndarray.py
@@ -19,13 +19,14 @@ class NDArray(base.BaseTrace):
Sampling values will be stored for these variables. If None,
`model.unobserved_RVs` is used.
"""

def __init__(self, name=None, model=None, vars=None):
super(NDArray, self).__init__(name, model, vars)
self.draw_idx = 0
self.draws = None
self.samples = {}

## Sampling methods
# Sampling methods

def setup(self, draws, chain):
"""Perform chain-specific setup.
@@ -70,12 +71,12 @@ def record(self, point):
def close(self):
if self.draw_idx == self.draws:
return
## Remove trailing zeros if interrupted before completed all
## draws.
# Remove trailing zeros if interrupted before completed all
# draws.
self.samples = {var: vtrace[:self.draw_idx]
for var, vtrace in self.samples.items()}

## Selection methods
# Selection methods

def __len__(self):
if not self.samples: # `setup` has not been called.
22 changes: 12 additions & 10 deletions pymc3/backends/sqlite.py
@@ -34,8 +34,8 @@
'WHERE chain = ?'),
'draw_count': ('SELECT COUNT(*) FROM [{table}] '
'WHERE chain = ?'),
## Named placeholders are used in the selection templates because
## some values occur more than once in the same template.
# Named placeholders are used in the selection templates because
# some values occur more than once in the same template.
'select': ('SELECT * FROM [{table}] '
'WHERE (chain = :chain)'),
'select_burn': ('SELECT * FROM [{table}] '
@@ -71,6 +71,7 @@ class SQLite(base.BaseTrace):
Sampling values will be stored for these variables. If None,
`model.unobserved_RVs` is used.
"""

def __init__(self, name, model=None, vars=None):
super(SQLite, self).__init__(name, model, vars)
self._var_cols = {}
@@ -80,13 +81,13 @@ def __init__(self, name, model=None, vars=None):
self._len = None

self.db = _SQLiteDB(name)
## Inserting sampling information is queued to avoid locks
## caused by hitting the database with transactions each
## iteration.
# Inserting sampling information is queued to avoid locks
# caused by hitting the database with transactions each
# iteration.
self._queue = {varname: [] for varname in self.varnames}
self._queue_limit = 5000

## Sampling methods
# Sampling methods

def setup(self, draws, chain):
"""Perform chain-specific setup.
@@ -127,7 +128,7 @@ def _create_table(self):
def _create_insert_queries(self, chain):
template = TEMPLATES['insert']
for varname, var_cols in self._var_cols.items():
## Create insert statement for each variable.
# Create insert statement for each variable.
var_str = ', '.join(var_cols)
placeholders = ', '.join(['?'] * len(var_cols))
statement = template.format(table=varname,
@@ -164,7 +165,7 @@ def close(self):
self._execute_queue()
self.db.close()

## Selection methods
# Selection methods

def __len__(self):
if not self._is_setup:
@@ -252,6 +253,7 @@ def point(self, idx):


class _SQLiteDB(object):

def __init__(self, name):
self.name = name
self.con = None
@@ -306,8 +308,8 @@ def load(name, model=None):

def _get_table_list(cursor):
"""Return a list of table names in the current database."""
## Modified from Django. Skips the sqlite_sequence system table used
## for autoincrement key generation.
# Modified from Django. Skips the sqlite_sequence system table used
# for autoincrement key generation.
cursor.execute("SELECT name FROM sqlite_master "
"WHERE type='table' AND NOT name='sqlite_sequence' "
"ORDER BY name")
8 changes: 5 additions & 3 deletions pymc3/backends/text.py
@@ -37,6 +37,7 @@ class Text(base.BaseTrace):
Sampling values will be stored for these variables. If None,
`model.unobserved_RVs` is used.
"""

def __init__(self, name, model=None, vars=None):
if not os.path.exists(name):
os.mkdir(name)
@@ -49,7 +50,7 @@ def __init__(self, name, model=None, vars=None):
self._fh = None
self.df = None

## Sampling methods
# Sampling methods

def setup(self, draws, chain):
"""Perform chain-specific setup.
@@ -96,7 +97,7 @@ def close(self):
self._fh.close()
self._fh = None # Avoid serialization issue.

## Selection methods
# Selection methods

def _load_df(self):
if self.df is None:
@@ -194,5 +195,6 @@ def dump(name, trace, chains=None):

for chain in chains:
filename = os.path.join(name, 'chain-{}.csv'.format(chain))
df = ttab.trace_to_dataframe(trace, chains=chain, flat_names=flat_names)
df = ttab.trace_to_dataframe(
trace, chains=chain, flat_names=flat_names)
df.to_csv(filename, index=False)
4 changes: 4 additions & 0 deletions pymc3/blocking.py
@@ -17,6 +17,7 @@ class ArrayOrdering(object):
"""
An ordering for an array space
"""

def __init__(self, vars):
self.vmap = []
dim = 0
@@ -33,6 +34,7 @@ class DictToArrayBijection(object):
"""
A mapping between a dict space and an array space
"""

def __init__(self, ordering, dpoint):
self.ordering = ordering
self.dpt = dpoint
@@ -85,6 +87,7 @@ class DictToVarBijection(object):
"""
A mapping between a dict space and the array space for one element within the dict space
"""

def __init__(self, var, idx, dpoint):
self.var = str(var)
self.idx = idx
@@ -111,6 +114,7 @@ class Compose(object):
"""
Compose two functions in a pickleable way
"""

def __init__(self, fa, fb):
self.fa = fa
self.fb = fb
4 changes: 2 additions & 2 deletions pymc3/data.py
@@ -3,9 +3,10 @@

__all__ = ['get_data_file']


def get_data_file(pkg, path):
"""Returns a file object for a package data file.
Parameters
----------
pkg : str
@@ -18,4 +19,3 @@ def get_data_file(pkg, path):
"""

return io.BytesIO(pkgutil.get_data(pkg, path))

22 changes: 12 additions & 10 deletions pymc3/diagnostics.py
@@ -75,7 +75,8 @@ def geweke(x, first=.1, last=.5, intervals=20):
last_start_idx = (1 - last) * end

# Calculate starting indices
start_indices = np.arange(0, int(last_start_idx), step=int((last_start_idx) / (intervals - 1)))
start_indices = np.arange(0, int(last_start_idx), step=int(
(last_start_idx) / (intervals - 1)))

# Loop over start indices
for start in start_indices:
@@ -151,9 +152,9 @@ def calc_rhat(x):
W = np.mean(np.var(x, axis=1, ddof=1))

# Estimate of marginal posterior variance
Vhat = W*(n - 1)/n + B/n
Vhat = W * (n - 1) / n + B / n

return np.sqrt(Vhat/W)
return np.sqrt(Vhat / W)

except ValueError:

@@ -223,7 +224,7 @@ def calc_vhat(x):
W = np.mean(np.var(x, axis=1, ddof=1))

# Estimate of marginal posterior variance
Vhat = W*(n - 1)/n + B/n
Vhat = W * (n - 1) / n + B / n

return Vhat

@@ -243,21 +244,22 @@ def calc_n_eff(x):

Vhat = calc_vhat(x)

variogram = lambda t: (sum(sum((x[j][i] - x[j][i-t])**2
for i in range(t,n)) for j in range(m)) / (m*(n - t)))
variogram = lambda t: (sum(sum((x[j][i] - x[j][i - t])**2
for i in range(t, n)) for j in range(m)) / (m * (n - t)))

rho = np.ones(n)
# Iterate until the sum of consecutive estimates of autocorrelation is negative
# Iterate until the sum of consecutive estimates of autocorrelation is
# negative
while not negative_autocorr and (t < n):

rho[t] = 1. - variogram(t)/(2.*Vhat)
rho[t] = 1. - variogram(t) / (2. * Vhat)

if not t % 2:
negative_autocorr = sum(rho[t-1:t+1]) < 0
negative_autocorr = sum(rho[t - 1:t + 1]) < 0

t += 1

return int(m*n / (1. + 2*rho[1:t].sum()))
return int(m * n / (1. + 2 * rho[1:t].sum()))

n_eff = {}
for var in mtrace.varnames:
98 changes: 48 additions & 50 deletions pymc3/distributions/__init__.py
@@ -62,53 +62,51 @@
from .transforms import sum_to_1

__all__ = ['Uniform',
'Flat',
'Normal',
'Beta',
'Exponential',
'Laplace',
'StudentT',
'Cauchy',
'HalfCauchy',
'Gamma',
'Weibull',
'Bound',
'StudentTpos',
'Lognormal',
'ChiSquared',
'HalfNormal',
'Wald',
'Pareto',
'InverseGamma',
'ExGaussian',
'VonMises',
'Binomial',
'BetaBinomial',
'Bernoulli',
'Poisson',
'NegativeBinomial',
'ConstantDist',
'ZeroInflatedPoisson',
'ZeroInflatedNegativeBinomial',
'DiscreteUniform',
'Geometric',
'Categorical',
'DensityDist',
'Distribution',
'Continuous',
'Discrete',
'NoDistribution',
'TensorType',
'MvNormal',
'MvStudentT',
'Dirichlet',
'Multinomial',
'Wishart',
'WishartBartlett',
'LKJCorr',
'AR1',
'GaussianRandomWalk',
'GARCH11'
]


'Flat',
'Normal',
'Beta',
'Exponential',
'Laplace',
'StudentT',
'Cauchy',
'HalfCauchy',
'Gamma',
'Weibull',
'Bound',
'StudentTpos',
'Lognormal',
'ChiSquared',
'HalfNormal',
'Wald',
'Pareto',
'InverseGamma',
'ExGaussian',
'VonMises',
'Binomial',
'BetaBinomial',
'Bernoulli',
'Poisson',
'NegativeBinomial',
'ConstantDist',
'ZeroInflatedPoisson',
'ZeroInflatedNegativeBinomial',
'DiscreteUniform',
'Geometric',
'Categorical',
'DensityDist',
'Distribution',
'Continuous',
'Discrete',
'NoDistribution',
'TensorType',
'MvNormal',
'MvStudentT',
'Dirichlet',
'Multinomial',
'Wishart',
'WishartBartlett',
'LKJCorr',
'AR1',
'GaussianRandomWalk',
'GARCH11'
]