-
Notifications
You must be signed in to change notification settings - Fork 7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
More Multiweight support cleanups (#4948)
* Updated the link for densenet recipe. * Set default value of `num_classes` and `num_keypoints` to `None` * Provide helper methods for parameter checks to reduce duplicate code. * Throw errors on silent config overwrites from weight meta-data and legacy builders. * Changing order of arguments + fixing mypy. * Make the builders fully BC. * Add "default" weights support that returns always the best weights.
- Loading branch information
Showing
32 changed files
with
703 additions
and
420 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import warnings | ||
from typing import Any, Dict, Optional, TypeVar | ||
|
||
from ._api import Weights | ||
|
||
|
||
W = TypeVar("W", bound=Weights) | ||
V = TypeVar("V") | ||
|
||
|
||
def _deprecated_param( | ||
kwargs: Dict[str, Any], deprecated_param: str, new_param: str, default_value: Optional[W] | ||
) -> Optional[W]: | ||
warnings.warn(f"The parameter '{deprecated_param}' is deprecated, please use '{new_param}' instead.") | ||
if kwargs.pop(deprecated_param): | ||
if default_value is not None: | ||
return default_value | ||
else: | ||
raise ValueError("No checkpoint is available for model.") | ||
else: | ||
return None | ||
|
||
|
||
def _deprecated_positional(kwargs: Dict[str, Any], deprecated_param: str, new_param: str, default_value: V) -> None: | ||
warnings.warn( | ||
f"The positional parameter '{deprecated_param}' is deprecated, please use keyword parameter '{new_param}'" | ||
+ " instead." | ||
) | ||
kwargs[deprecated_param] = default_value | ||
|
||
|
||
def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> None: | ||
if param in kwargs: | ||
if kwargs[param] != new_value: | ||
raise ValueError(f"The parameter '{param}' expected value {new_value} but got {kwargs[param]} instead.") | ||
else: | ||
kwargs[param] = new_value | ||
|
||
|
||
def _ovewrite_value_param(param: Optional[V], new_value: V) -> V: | ||
if param is not None: | ||
if param != new_value: | ||
raise ValueError(f"The parameter '{param}' expected value {new_value} but got {param} instead.") | ||
return new_value |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.