Skip to content

Commit

Permalink
add annotate
Browse files Browse the repository at this point in the history
  • Loading branch information
jsxlei committed Nov 13, 2024
1 parent 01c8c41 commit 97854db
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 27 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@


## News
#### [2022-10-17] SCALEX is online at [Nature Communications](https://www.nature.com/articles/s41467-022-33758-z)

## [Documentation](https://scalex.readthedocs.io/en/latest/index.html)
## [Tutorial](https://scalex.readthedocs.io/en/latest/tutorial/index.html)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ requires = ["hatchling"]
[project]
name = "scalex"
authors = [{name = "Lei Xiong"}]
version = "1.0.4"
version = "1.0.5"
readme = "README.md"
requires-python = ">=3.7"
description = "Online single-cell data integration through projecting heterogeneous datasets into a common cell-embedding space"
Expand Down
45 changes: 25 additions & 20 deletions scalex/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,26 +97,31 @@ def annotate(
print('-'*20+'\n', n_top, '\n'+'-'*20)

if go:
go_results = enrich_analysis(marker.head(n_top))
go_results['cell_type'] = 'leiden_' + go_results['cell_type']
n = go_results['cell_type'].nunique()
ax = dotplot(go_results,
column="Adjusted P-value",
                    x='cell_type', # set x axis, so you could do a multi-sample/library comparison
# size=10,
top_term=10,
figsize=(0.7*n, 2*n),
title = "GO_BP",
xticklabels_rot=45, # rotate xtick labels
                    show_ring=False, # set to False to remove outer ring
marker='o',
cutoff=0.05,
cmap='viridis'
)
if out_dir is not None:
os.makedirs(out_dir, exist_ok=True)
go_results[['Gene_set','Term','Overlap', 'Adjusted P-value', 'Genes', 'cell_type']].to_csv(out_dir + f'/go_results_{n_top}.csv')
plt.show()
for option in ['pos', 'neg']:
if option == 'pos':
go_results = enrich_analysis(marker.head(n_top))
else:
go_results = enrich_analysis(marker.tail(n_top))

go_results['cell_type'] = 'leiden_' + go_results['cell_type']
n = go_results['cell_type'].nunique()
ax = dotplot(go_results,
column="Adjusted P-value",
                        x='cell_type', # set x axis, so you could do a multi-sample/library comparison
# size=10,
top_term=10,
figsize=(0.7*n, 2*n),
title = f"{option}_GO_BP_{n_top}",
xticklabels_rot=45, # rotate xtick labels
                        show_ring=False, # set to False to remove outer ring
marker='o',
cutoff=0.05,
cmap='viridis'
)
if out_dir is not None:
os.makedirs(out_dir, exist_ok=True)
go_results[['Gene_set','Term','Overlap', 'Adjusted P-value', 'Genes', 'cell_type']].to_csv(out_dir + f'/{option}_go_results_{n_top}.csv')
plt.show()

for pathway_name, pathways in additional.items():
try:
Expand Down
20 changes: 18 additions & 2 deletions scalex/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def preprocessing_rna(
target_sum: int = 10000,
n_top_features = 2000, # or gene list
chunk_size: int = CHUNK_SIZE,
min_cell_per_batch: int = 10,
keep_mt: bool = False,
backed: bool = False,
log=None
Expand Down Expand Up @@ -234,8 +235,11 @@ def preprocessing_rna(
batch_counts = adata.obs['batch'].value_counts()

    # Filter out batches with fewer than min_cell_per_batch cells
valid_batches = batch_counts[batch_counts > 10].index
print('min_cell_per_batch', min_cell_per_batch)
valid_batches = batch_counts[batch_counts >= min_cell_per_batch].index
adata = adata[adata.obs['batch'].isin(valid_batches)].copy()
# if log: log.info('There are {} batches under batch_name: {}'.format(len(adata.obs['batch'].cat.categories), batch_name))
if log: log.info(adata.obs['batch'].value_counts())

if log: log.info('Preprocessing')
# if not issparse(adata.X):
Expand Down Expand Up @@ -359,6 +363,7 @@ def preprocessing(
min_cells: int = 3,
target_sum: int = None,
n_top_features = None, # or gene list
min_cell_per_batch: int = 10,
keep_mt: bool = False,
backed: bool = False,
chunk_size: int = CHUNK_SIZE,
Expand Down Expand Up @@ -398,6 +403,7 @@ def preprocessing(
min_cells=min_cells,
target_sum=target_sum,
n_top_features=n_top_features,
min_cell_per_batch=min_cell_per_batch,
keep_mt=keep_mt,
backed=backed,
chunk_size=chunk_size,
Expand Down Expand Up @@ -562,10 +568,13 @@ def load_data(
join='inner',
batch_key='batch',
batch_name='batch',
groupby=None,
subsets=None,
min_features=600,
min_cells=3,
target_sum=None,
n_top_features=None,
min_cell_per_batch=10,
keep_mt=False,
backed=False,
batch_size=64,
Expand All @@ -592,6 +601,8 @@ def load_data(
Add the batch annotation to obs using this key. Default: 'batch'.
batch_name
Use this annotation in obs as batches for training model. Default: 'batch'.
subsets
    Subsets of `groupby` values to keep; applied only when `groupby` is also provided. Default: None.
min_features
Filtered out cells that are detected in less than min_features. Default: 600.
min_cells
Expand All @@ -615,6 +626,10 @@ def load_data(
An iterable over the given dataset for testing
"""
adata = concat_data(data_list, batch_categories, join=join, batch_key=batch_key)
if subsets is not None and groupby is not None:
adata = adata[adata.obs[groupby].isin(subsets)].copy()
if log: log.info('Subsets dataset shape: {}'.format(adata.shape))

if log: log.info('Raw dataset shape: {}'.format(adata.shape))
if batch_name!='batch':
if ',' in batch_name:
Expand All @@ -625,7 +640,6 @@ def load_data(
if 'batch' not in adata.obs:
adata.obs['batch'] = 'batch'
adata.obs['batch'] = adata.obs['batch'].astype('category')
if log: log.info('There are {} batches under batch_name: {}'.format(len(adata.obs['batch'].cat.categories), batch_name))

if isinstance(n_top_features, str):
if os.path.isfile(n_top_features):
Expand All @@ -644,6 +658,7 @@ def load_data(
min_cells=min_cells,
target_sum=target_sum,
n_top_features=n_top_features,
min_cell_per_batch=min_cell_per_batch,
keep_mt=keep_mt,
chunk_size=chunk_size,
backed=backed,
Expand All @@ -658,6 +673,7 @@ def load_data(
adata.obsm[use_layer] = MaxAbsScaler().fit_transform(adata.obsm[use_layer])
else:
raise ValueError("Not support use_layer: `{}` yet".format(use_layer))

scdata = SingleCellDataset(adata, use_layer=use_layer) # Wrap AnnData into Pytorch Dataset
trainloader = DataLoader(
scdata,
Expand Down
17 changes: 16 additions & 1 deletion scalex/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def SCALEX(
min_cells:int=3,
target_sum:int=None,
n_top_features:int=None,
min_cell_per_batch:int=10,
join:str='inner',
batch_key:str='batch',
processed:bool=False,
Expand All @@ -30,6 +31,8 @@ def SCALEX(
keep_mt:bool=False,
backed:bool=False,
batch_size:int=64,
groupby:str=None,
subsets:list=None,
lr:float=2e-4,
max_iteration:int=30000,
seed:int=124,
Expand Down Expand Up @@ -146,7 +149,10 @@ def SCALEX(
profile=profile,
target_sum=target_sum,
n_top_features=n_top_features,
min_cell_per_batch=min_cell_per_batch,
batch_size=batch_size,
groupby=groupby,
subsets=subsets,
chunk_size=chunk_size,
min_features=min_features,
min_cells=min_cells,
Expand Down Expand Up @@ -199,6 +205,7 @@ def SCALEX(
n_top_features=n_top_features,
min_cells=0,
min_features=min_features,
min_cell_per_batch=min_cell_per_batch,
processed=processed,
batch_name=batch_name,
batch_key=batch_key,
Expand Down Expand Up @@ -226,6 +233,9 @@ def SCALEX(
# adata.raw = concat([ref.raw.to_adata(), adata.raw.to_adata()], join='outer', label='projection', keys=['reference', 'query'])
if 'leiden' in adata.obs:
del adata.obs['leiden']
for col in adata.obs.columns:
if not pd.api.types.is_string_dtype(adata.obs[col]):
adata.obs[col] = adata.obs[col].astype(str)

# if outdir is not None:
# adata.write(os.path.join(outdir, 'adata.h5ad'), compression='gzip')
Expand Down Expand Up @@ -321,12 +331,14 @@ def main():
parser.add_argument('--batch_key', type=str, default='batch')
parser.add_argument('--batch_name', type=str, default='batch')
parser.add_argument('--profile', type=str, default='RNA')
parser.add_argument('--test_list', '-t', type=str, nargs='+', default=[])
parser.add_argument('--subsets', type=str, nargs='+', default=None)
parser.add_argument('--groupby', type=str, default=None)

parser.add_argument('--min_features', type=int, default=None)
parser.add_argument('--min_cells', type=int, default=3)
parser.add_argument('--n_top_features', default=None)
parser.add_argument('--target_sum', type=int, default=None)
parser.add_argument('--min_cell_per_batch', type=int, default=10)
parser.add_argument('--processed', action='store_true', default=False)
parser.add_argument('--fraction', type=float, default=None)
parser.add_argument('--n_obs', type=int, default=None)
Expand Down Expand Up @@ -365,10 +377,13 @@ def main():
profile=args.profile,
join=args.join,
batch_key=args.batch_key,
groupby=args.groupby,
subsets=args.subsets,
min_features=args.min_features,
min_cells=args.min_cells,
target_sum=args.target_sum,
n_top_features=args.n_top_features,
min_cell_per_batch=args.min_cell_per_batch,
fraction=args.fraction,
n_obs=args.n_obs,
processed=args.processed,
Expand Down
8 changes: 6 additions & 2 deletions scalex/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def plot_meta(
adata
AnnData
use_rep
the cell representations or embeddings used to calculate the correlations, default is `latent` generated by `SCALE v2`
the cell representations or embeddings used to calculate the correlations, default is `latent` generated by `SCALEX`
batch
the meta information based-on, default is batch
colors
Expand Down Expand Up @@ -235,7 +235,7 @@ def plot_meta2(
adata
AnnData
use_rep
the cell representations or embeddings used to calculate the correlations, default is `latent` generated by `SCALE v2`
the cell representations or embeddings used to calculate the correlations, default is `latent` generated by `SCALEX`
batch
the meta information based-on, default is batch
colors
Expand All @@ -255,6 +255,10 @@ def plot_meta2(
fontsize
font size
"""
import matplotlib as mpl
mpl.rcParams['axes.grid'] = False
# mpl.rcParams.update(mpl.rcParamsDefault)

meta = []
name = []

Expand Down

0 comments on commit 97854db

Please sign in to comment.