diff --git a/astroid/interpreter/_import/spec.py b/astroid/interpreter/_import/spec.py index 93096e54e6..efdb8cb923 100644 --- a/astroid/interpreter/_import/spec.py +++ b/astroid/interpreter/_import/spec.py @@ -16,6 +16,8 @@ import warnings import zipimport from collections.abc import Iterator, Sequence +from dataclasses import dataclass, field +from functools import lru_cache from pathlib import Path from typing import Any, Literal, NamedTuple, Protocol @@ -423,7 +425,19 @@ def _find_spec_with_path( raise ImportError(f"No module named {'.'.join(module_parts)}") +@dataclass(frozen=True) +class SpecArgs: + key: tuple + modpath: list[str] = field(compare=False, hash=False) + path: Sequence[str] | None = field(compare=False, hash=False) + + def find_spec(modpath: list[str], path: Sequence[str] | None = None) -> ModuleSpec: + return _find_spec(SpecArgs(tuple(modpath), modpath, path)) + + +@lru_cache(maxsize=1024) +def _find_spec(args: SpecArgs) -> ModuleSpec: """Find a spec for the given module. :type modpath: list or tuple @@ -440,10 +454,10 @@ def find_spec(modpath: list[str], path: Sequence[str] | None = None) -> ModuleSp :return: A module spec, which describes how the module was found and where. """ - _path = path or sys.path + _path = args.path or sys.path # Need a copy for not mutating the argument. - modpath = modpath[:] + modpath = args.modpath[:] submodule_path = None module_parts = modpath[:] @@ -452,7 +466,7 @@ def find_spec(modpath: list[str], path: Sequence[str] | None = None) -> ModuleSp while modpath: modname = modpath.pop(0) finder, spec = _find_spec_with_path( - _path, modname, module_parts, processed, submodule_path or path + _path, modname, module_parts, processed, submodule_path or args.path ) processed.append(modname) if modpath: @@ -468,3 +482,7 @@ def find_spec(modpath: list[str], path: Sequence[str] | None = None) -> ModuleSp spec = spec._replace(submodule_search_locations=submodule_path) return spec + + +def clear_spec_cache() -> None: + _find_spec.cache_clear() diff --git a/astroid/manager.py b/astroid/manager.py index a7a51f19c5..195ac66459 100644 --- a/astroid/manager.py +++ b/astroid/manager.py @@ -442,10 +442,12 @@ def clear_cache(self) -> None: # pylint: disable=import-outside-toplevel from astroid.brain.helpers import register_all_brains from astroid.inference_tip import clear_inference_tip_cache + from astroid.interpreter._import.spec import clear_spec_cache from astroid.interpreter.objectmodel import ObjectModel from astroid.nodes._base_nodes import LookupMixIn from astroid.nodes.scoped_nodes import ClassDef + clear_spec_cache() clear_inference_tip_cache() _invalidate_cache() # inference context cache diff --git a/tests/test_manager.py b/tests/test_manager.py index 7861927930..160fa944bb 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -23,6 +23,7 @@ AttributeInferenceError, ) from astroid.interpreter._import import util +from astroid.interpreter._import.spec import clear_spec_cache from astroid.modutils import EXT_LIB_DIRS, module_in_path from astroid.nodes import Const from astroid.nodes.scoped_nodes import ClassDef, Module @@ -41,6 +42,7 @@ class AstroidManagerTest( ): def setUp(self) -> None: super().setUp() + clear_spec_cache() self.manager = test_utils.brainless_manager() def test_ast_from_file(self) -> None: diff --git a/tests/test_modutils.py b/tests/test_modutils.py index 929c58992c..be7095e2dd 100644 --- a/tests/test_modutils.py +++ b/tests/test_modutils.py @@ -22,6 +22,7 @@ from astroid import modutils from astroid.const import PY310_PLUS from astroid.interpreter._import import spec +from astroid.interpreter._import.spec import clear_spec_cache from . import resources @@ -41,6 +42,7 @@ class ModuleFileTest(unittest.TestCase): package = "mypypa" def tearDown(self) -> None: + clear_spec_cache() for k in list(sys.path_importer_cache): if "MyPyPa" in k: del sys.path_importer_cache[k]