Skip to content

Commit

Permalink
chore: fix type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
jolars committed Dec 14, 2023
1 parent 3b92010 commit fa09bca
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
intersphinx_mapping = {
"sklearn": ("https://scikit-learn.org/stable", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"scipy": ("https://docs.scipy.org/doc/scipy/", None),
}

# Myst
Expand Down
6 changes: 3 additions & 3 deletions sortedl1/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
self.max_iter = max_iter
self.tol = tol

def fit(self, X: ArrayLike, y: ArrayLike):
def fit(self, X: np.ndarray | sparse.csc_array, y: ArrayLike):
"""
Fit the model according to the given training data.
Expand Down Expand Up @@ -112,11 +112,11 @@ def fit(self, X: ArrayLike, y: ArrayLike):
self.lambda_ = result[2]
self.alpha_ = result[3]
self.n_iter_ = result[4]
self.n_features_in_ = np.shape(X)[1]
self.n_features_in_ = X.shape[1]

return self

def predict(self, X: ArrayLike) -> np.ndarray:
def predict(self, X: ArrayLike | sparse.sparray) -> np.ndarray:
"""
Generate predictions for new data.
Expand Down
4 changes: 2 additions & 2 deletions tests/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Basic tests for the sortedl1 package."""
import numpy as np
from numpy.random import default_rng
from scipy.sparse import random
from scipy.sparse import csc_array, random

from sortedl1 import Slope

Expand Down Expand Up @@ -37,7 +37,7 @@ def test_simple_sparse_problem():
p = 3

rng = np.random.default_rng(4)
x = random(n, p, density=0.5, random_state=rng)
x = csc_array(random(n, p, density=0.5, random_state=rng))
beta = np.array([1.0, 2, -0.9])
y = x @ beta

Expand Down

0 comments on commit fa09bca

Please sign in to comment.