diff --git a/server/clip_server/model/model.py b/server/clip_server/model/model.py
index e4cfe3c3a..12de55e5b 100644
--- a/server/clip_server/model/model.py
+++ b/server/clip_server/model/model.py
@@ -464,7 +464,9 @@ def load_openai_model(
     # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
     model = model.to(device)
 
-    if dtype == torch.float32 or dtype.startswith('amp'):
+    if dtype == torch.float32 or (
+        isinstance(dtype, str) and dtype.startswith('amp')
+    ):
         model.float()
     elif dtype == torch.bfloat16:
         convert_weights_to_lp(model, dtype=torch.bfloat16)
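
For reviewers, a minimal standalone sketch (not part of the patch) of the failure mode the guard addresses, assuming `dtype` can arrive here either as a `torch.dtype` or as a precision string such as `'amp'`:

```python
import torch

dtype = torch.float16  # assumed example: dtype passed as a torch.dtype rather than a string

# Old check: breaks for torch.dtype values, since torch.dtype has no .startswith()
try:
    dtype == torch.float32 or dtype.startswith('amp')
except AttributeError as e:
    print(f'old check fails: {e}')

# New check: isinstance() short-circuits the string test for torch.dtype values
ok = dtype == torch.float32 or (isinstance(dtype, str) and dtype.startswith('amp'))
print(f'new check evaluates safely: {ok}')
```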