Skip to content

Commit

Permalink
use fixtures
Browse files Browse the repository at this point in the history
  • Loading branch information
AleksanderWWW committed Jan 18, 2024
1 parent 203a340 commit 04eb8b5
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 28 deletions.
44 changes: 44 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from dataclasses import dataclass
from typing import Optional

import numpy as np
from pytest import fixture
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.utils import Bunch


@dataclass
class Dataset:
x: np.ndarray
y: np.ndarray
x_train: np.ndarray
x_test: np.ndarray
y_train: np.ndarray
y_test: np.ndarray


_IRIS_DATASET: Optional[Bunch] = None
_DIABETES_DATASET: Optional[Bunch] = None


@fixture(scope="session")
def iris() -> Dataset:
global _IRIS_DATASET
if _IRIS_DATASET is None:
_IRIS_DATASET = datasets.load_iris()
x = _IRIS_DATASET.data[:, :2]
y = _IRIS_DATASET.target
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.5)
return Dataset(x, y, x_train, x_test, y_train, y_test)


@fixture(scope="session")
def diabetes() -> Dataset:
global _DIABETES_DATASET
if _DIABETES_DATASET is None:
_DIABETES_DATASET = datasets.load_diabetes(return_X_y=True)
x, y = _DIABETES_DATASET

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.5)
return Dataset(x, y, x_train, x_test, y_train, y_test)
47 changes: 19 additions & 28 deletions tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,51 +21,42 @@
import neptune_sklearn as npt_utils


def test_classifier_summary():
def test_classifier_summary(iris):
with init_run() as run:

iris = datasets.load_iris()
X = iris.data[:, :2]
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5)

model = DummyClassifier()
model.fit(X_train, y_train)
model.fit(iris.x_train, iris.y_train)

run["summary"] = npt_utils.create_classifier_summary(model, X_train, X_test, y_train, y_test)
run["summary"] = npt_utils.create_classifier_summary(
model, iris.x_train, iris.x_test, iris.y_train, iris.y_test
)

run.wait()
validate_run(run, log_charts=True)
run.wait()
validate_run(run, log_charts=True)


def test_regressor_summary():
def test_regressor_summary(diabetes):
with init_run() as run:

X, y = datasets.load_diabetes(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5)

model = DummyRegressor()
model.fit(X_train, y_train)
model.fit(diabetes.x_train, diabetes.y_train)

run["summary"] = npt_utils.create_regressor_summary(model, X_train, X_test, y_train, y_test)
run["summary"] = npt_utils.create_regressor_summary(
model, diabetes.x_train, diabetes.x_test, diabetes.y_train, diabetes.y_test
)

run.wait()
validate_run(run, log_charts=True)
run.wait()
validate_run(run, log_charts=True)


def test_kmeans_summary():
def test_kmeans_summary(iris):
with init_run() as run:

iris = datasets.load_iris()
X = iris.data[:, :2]

model = KMeans()
model.fit(X)
model.fit(iris.x)

run["summary"] = npt_utils.create_kmeans_summary(model, X, n_clusters=3)
run["summary"] = npt_utils.create_kmeans_summary(model, iris.x, n_clusters=3)

run.wait()
validate_run(run, log_charts=True)
run.wait()
validate_run(run, log_charts=True)


@pytest.mark.filterwarnings("error::neptune.common.warnings.NeptuneUnsupportedType")
Expand Down

0 comments on commit 04eb8b5

Please sign in to comment.