Skip to content

Commit

Permalink
Implement Astroid workaround bug
Browse files Browse the repository at this point in the history
See workaround in pylint-dev/astroid#1015
  • Loading branch information
mscuthbert committed Apr 30, 2022
1 parent 47fcf2f commit 448b35f
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 48 deletions.
2 changes: 1 addition & 1 deletion music21/meter/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def testSlowSixEight(self):
m.timeSignature = ts
n = note.Note(quarterLength=0.5)
m.repeatAppend(n, 6)
match = [n.beatStr for n in m.notes]
match = [n.beatStr for n in m.iter().getElementsByClass(note.NotRest)]
self.assertEqual(match, ['1', '2', '3', '4', '5', '6'])
m.makeBeams(inPlace=True)
# m.show()
Expand Down
63 changes: 41 additions & 22 deletions music21/stream/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@

from music21.common.numberTools import opFrac
from music21.common.enums import GatherSpanners, OffsetSpecial
from music21.common.types import StreamType, M21ObjType, OffsetQL, ClassListType
from music21.common.types import StreamType, M21ObjType, OffsetQL

from music21 import environment

Expand All @@ -80,7 +80,7 @@

T = TypeVar('T')
# we sometimes need to return a different type.
M21ObjType2 = TypeVar('M21ObjType2', bound=base.Music21Object)
ChangedM21ObjType = TypeVar('ChangedM21ObjType', bound=base.Music21Object)

BestQuantizationMatch = namedtuple('BestQuantizationMatch',
['error', 'tick', 'match', 'signedError', 'divisor'])
Expand Down Expand Up @@ -405,24 +405,32 @@ def iter(self) -> iterator.StreamIterator[M21ObjType]:
return self.__iter__()

@overload
def __getitem__(self, k: str) -> iterator.RecursiveIterator:
...
def __getitem__(self, k: str) -> iterator.RecursiveIterator[M21ObjType]:
# Remove this code and replace with ... once Astroid #1015 is fixed.
x: iterator.RecursiveIterator[M21ObjType] = self.recurse()
return x

@overload
def __getitem__(self, k: int) -> M21ObjType:
...
return self[k] # dummy code

@overload
def __getitem__(self, k: slice) -> List[M21ObjType]:
...
return list(self.elements) # dummy code

@overload
def __getitem__(self, k: Type[M21ObjType2]) -> iterator.RecursiveIterator[M21ObjType2]:
...
def __getitem__(self,
k: Type[ChangedM21ObjType]
) -> iterator.RecursiveIterator[ChangedM21ObjType]:
x: iterator.RecursiveIterator[ChangedM21ObjType] = self.recurse()
return x # dummy code

def __getitem__(self,
k: Union[str, int, slice, Type[M21ObjType]]
) -> Union[iterator.RecursiveIterator, M21ObjType, List[M21ObjType]]:
k: Union[str, int, slice, Type[ChangedM21ObjType]]
) -> Union[iterator.RecursiveIterator[M21ObjType],
iterator.RecursiveIterator[ChangedM21ObjType],
M21ObjType,
List[M21ObjType]]:
'''
Get a Music21Object from the Stream using a variety of keys or indices.

Expand Down Expand Up @@ -3355,23 +3363,34 @@ def addGroupForElements(self, group, classFilter=None, *, recurse=False):
@overload
def getElementsByClass(self,
classFilterList: Union[str, Iterable[str]]
) -> iterator.StreamIterator:
...
) -> iterator.StreamIterator[M21ObjType]:
# Remove all dummy code once Astroid #1015 is fixed
x: iterator.StreamIterator[M21ObjType] = self.iter()
return x # dummy code

@overload
def getElementsByClass(self,
classFilterList: Type[M21ObjType]
) -> iterator.StreamIterator[M21ObjType]:
...
classFilterList: Type[ChangedM21ObjType]
) -> iterator.StreamIterator[ChangedM21ObjType]:
x: iterator.StreamIterator[ChangedM21ObjType] = self.iter()
return x # dummy code

@overload
def getElementsByClass(self,
classFilterList: Iterable[Type[M21ObjType]]
) -> iterator.StreamIterator:
...
classFilterList: Iterable[Type[ChangedM21ObjType]]
) -> iterator.StreamIterator[M21ObjType]:
x: iterator.StreamIterator[M21ObjType] = self.iter()
return x # dummy code

def getElementsByClass(self, classFilterList: ClassListType
) -> iterator.StreamIterator[ClassListType]:
def getElementsByClass(self,
classFilterList: Union[
str,
Type[ChangedM21ObjType],
Iterable[str],
Iterable[Type[ChangedM21ObjType]],
],
) -> Union[iterator.StreamIterator[M21ObjType],
iterator.StreamIterator[ChangedM21ObjType]]:
'''
Return a StreamIterator that will iterate over Elements that match one
or more classes in the `classFilterList`. A single class
Expand Down Expand Up @@ -4293,7 +4312,7 @@ def hasMeasureNumberInformation(measureIterator):
returnObj = self.cloneEmpty(derivationMethod='measures')
srcObj = self

mStreamIter = self.getElementsByClass(Measure)
mStreamIter: iterator.StreamIterator[Measure] = self.getElementsByClass(Measure)

# FIND THE CORRECT ORIGINAL MEASURE OBJECTS
# for indicesNotNumbers, this is simple...
Expand Down Expand Up @@ -7535,7 +7554,7 @@ def recurse(self,
streamsOnly=False,
restoreActiveSites=True,
classFilter=(),
includeSelf=None) -> iterator.RecursiveIterator[Any]:
includeSelf=None) -> iterator.RecursiveIterator[M21ObjType]:
'''
`.recurse()` is a fundamental method of music21 for getting into
elements contained in a Score, Part, or Measure, where elements such as
Expand Down
66 changes: 41 additions & 25 deletions music21/stream/iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
'''
from __future__ import annotations

import collections.abc
import copy
from typing import (TypeVar, List, Union, Callable, Optional, Literal, Any,
TypedDict, Generic, overload, Iterable, Type, cast)
Expand Down Expand Up @@ -63,9 +62,7 @@ class ActiveInformation(TypedDict, total=False):


# -----------------------------------------------------------------------------


class StreamIterator(prebase.ProtoM21Object, Generic[M21ObjType], collections.abc.Sequence):
class StreamIterator(prebase.ProtoM21Object, Generic[M21ObjType]):
'''
An Iterator object used to handle getting items from Streams.
The :meth:`~music21.stream.Stream.__iter__` method
Expand Down Expand Up @@ -751,11 +748,14 @@ def _newBaseStream(self) -> 'music21.stream.Stream':

@overload
def stream(self, returnStreamSubClass: Literal[False] = True) -> 'music21.stream.Stream':
...
# ignore this code -- just here until Astroid bug #1015 is fixed
x: 'music21.stream.Stream' = self.streamObj
return x

@overload
def stream(self, returnStreamSubClass: Literal[True] = True) -> StreamType:
...
x: StreamType = self.streamObj
return x

def stream(self, returnStreamSubClass=True) -> Union['music21.stream.Stream', StreamType]:
'''
Expand Down Expand Up @@ -943,33 +943,40 @@ def getElementById(self, elementId) -> Optional[M21ObjType]:
return e
return None

# Replace all code in overload statements once
# https://github.com/PyCQA/astroid/issues/1015
# is fixed and deployed
@overload
def getElementsByClass(self,
classFilterList: str,
*,
returnClone: bool = True) -> StreamIterator[base.Music21Object]:
...
returnClone: bool = True) -> StreamIterator[M21ObjType]:
x: StreamIterator[M21ObjType] = self.__class__(self.streamObj)
return x

@overload
def getElementsByClass(self,
classFilterList: Iterable[str],
*,
returnClone: bool = True) -> StreamIterator[M21ObjType]:
...
x: StreamIterator[M21ObjType] = self.__class__(self.streamObj)
return x

@overload
def getElementsByClass(self,
classFilterList: Type[ChangedM21ObjType],
*,
returnClone: bool = True) -> StreamIterator[ChangedM21ObjType]:
...
x: StreamIterator[ChangedM21ObjType] = self.__class__(self.streamObj)
return x

@overload
def getElementsByClass(self,
classFilterList: Iterable[Type[ChangedM21ObjType]],
*,
returnClone: bool = True) -> StreamIterator[M21ObjType]:
...
x: StreamIterator[M21ObjType] = self.__class__(self.streamObj)
return x


def getElementsByClass(self,
Expand Down Expand Up @@ -1569,34 +1576,39 @@ def reset(self):
# NOTE: these getElementsByClass are the same as the one in StreamIterator, but
# for now it needs to be duplicated until changing a Generic's argument type
# can be done with inheritance.
# TODO: remove code and replace with ... when Astroid bug #1015 is fixed.

@overload
def getElementsByClass(self,
classFilterList: str,
*,
returnClone: bool = True) -> OffsetIterator[M21ObjType]:
...
x: OffsetIterator[M21ObjType] = self.__class__(self.streamObj)
return x

@overload
def getElementsByClass(self,
classFilterList: Iterable[str],
*,
returnClone: bool = True) -> OffsetIterator[M21ObjType]:
...
x: OffsetIterator[M21ObjType] = self.__class__(self.streamObj)
return x

@overload
def getElementsByClass(self,
classFilterList: Type[ChangedM21ObjType],
*,
returnClone: bool = True) -> OffsetIterator[ChangedM21ObjType]:
...
x: OffsetIterator[ChangedM21ObjType] = self.__class__(self.streamObj)
return x

@overload
def getElementsByClass(self,
classFilterList: Iterable[Type[ChangedM21ObjType]],
*,
returnClone: bool = True) -> OffsetIterator[base.Music21Object]:
...
returnClone: bool = True) -> OffsetIterator[M21ObjType]:
x: OffsetIterator[M21ObjType] = self.__class__(self.streamObj)
return x


def getElementsByClass(self,
Expand All @@ -1619,7 +1631,7 @@ def getElementsByClass(self,


# -----------------------------------------------------------------------------
class RecursiveIterator(StreamIterator[M21ObjType], collections.abc.Sequence):
class RecursiveIterator(StreamIterator[M21ObjType]):
'''
One of the most powerful iterators in music21. Generally not called
directly, but created by being invoked on a stream with `Stream.recurse()`
Expand Down Expand Up @@ -1971,29 +1983,33 @@ def getElementsByOffsetInHierarchy(
def getElementsByClass(self,
classFilterList: str,
*,
returnClone: bool = True) -> RecursiveIterator[base.Music21Object]:
...
returnClone: bool = True) -> RecursiveIterator[M21ObjType]:
x: RecursiveIterator[M21ObjType] = self.__class__(self.streamObj)
return x # dummy code remove when Astroid #1015 is fixed.

@overload
def getElementsByClass(self,
classFilterList: Iterable[str],
*,
returnClone: bool = True) -> RecursiveIterator[base.Music21Object]:
...
returnClone: bool = True) -> RecursiveIterator[M21ObjType]:
x: RecursiveIterator[M21ObjType] = self.__class__(self.streamObj)
return x # dummy code

@overload
def getElementsByClass(self,
classFilterList: Type[ChangedM21ObjType],
*,
returnClone: bool = True) -> RecursiveIterator[ChangedM21ObjType]:
...
x: RecursiveIterator[ChangedM21ObjType] = self.__class__(self.streamObj)
return x # dummy code

@overload
def getElementsByClass(self,
classFilterList: Iterable[Type[ChangedM21ObjType]],
*,
returnClone: bool = True) -> RecursiveIterator[base.Music21Object]:
...
returnClone: bool = True) -> RecursiveIterator[M21ObjType]:
x: RecursiveIterator[M21ObjType] = self.__class__(self.streamObj)
return x # dummy code


def getElementsByClass(self: _SIter,
Expand All @@ -2005,7 +2021,7 @@ def getElementsByClass(self: _SIter,
],
*,
returnClone: bool = True
) -> Union[RecursiveIterator[base.Music21Object],
) -> Union[RecursiveIterator[M21ObjType],
RecursiveIterator[ChangedM21ObjType]]:
out = super().getElementsByClass(classFilterList, returnClone=returnClone)
if isinstance(classFilterList, type) and issubclass(classFilterList, base.Music21Object):
Expand Down

0 comments on commit 448b35f

Please sign in to comment.