Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ddp unused param #218

Merged
merged 4 commits into from
Jan 11, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion TinyCLIP/src/open_clip/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import functools
import inspect
from copy import deepcopy
import os
import random
Expand All @@ -12,6 +13,7 @@
from argparse import Namespace

from dataclasses import dataclass
import functools
import logging
import math
from typing import Tuple, Union, Callable, Optional
Expand All @@ -21,11 +23,13 @@
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint
# apply the non-reentrant variant of checkpoint
if 'use_reentrant' in inspect.signature(checkpoint).parameters:
checkpoint = functools.partial(checkpoint, use_reentrant=False)

from .timm_model import TimmModel
from .utils import freeze_batch_norm_2d, to_2tuple
from .resnet import ModifiedResNet
from .weight_inherit import weight_inherit
from .l0module import L0Module


Expand Down Expand Up @@ -901,6 +905,26 @@ def image_encoder_without_ddp(self, encoder):
self._image_encoder = encoder
self._without_ddp[0] = self._image_encoder

@property
def text_encoder_without_ddp(self):
return self._without_ddp[1]

@text_encoder_without_ddp.setter
def text_encoder_without_ddp(self, encoder):
assert self.used_ddp is False
self._text_encoder = encoder
self._without_ddp[1] = self._text_encoder

@property
def logit_scale_without_ddp(self):
return self._without_ddp[2]

@logit_scale_without_ddp.setter
def logit_scale_without_ddp(self, logit_scale):
assert self.used_ddp is False
self._logit_scale = logit_scale
self._without_ddp[2] = self._logit_scale

@property
def visual(self):
return self.image_encoder_without_ddp.visual
Expand Down