Skip to content

Commit

Permalink
Add suppport for azure openai model
Browse files Browse the repository at this point in the history
  • Loading branch information
ignorejjj committed Jul 27, 2024
1 parent 96b20c8 commit cbd4c8d
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions flashrag/generator/openai_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np

import asyncio
from openai import AsyncOpenAI
from openai import AsyncOpenAI, AsyncAzureOpenAI
import tiktoken

class OpenaiGenerator:
Expand All @@ -20,10 +20,12 @@ def __init__(self, config):
if self.openai_setting['api_key'] is None:
self.openai_setting['api_key'] = os.getenv('OPENAI_API_KEY')

if 'api_type' in self.openai_setting and self.openai_setting['api_type'] == "azure":
del self.openai_setting["api_type"]
self.client = AsyncAzureOpenAI(**self.openai_setting)
else:
self.client = AsyncOpenAI(**self.openai_setting)

self.client = AsyncOpenAI(
**self.openai_setting
)
self.tokenizer = tiktoken.encoding_for_model(self.model_name)

async def get_response(self, input: List, **params):
Expand Down Expand Up @@ -56,6 +58,15 @@ def generate(self, input_list: List[List], batch_size=None, return_scores=False,
# deal with generation params
generation_params = deepcopy(self.generation_params)
generation_params.update(params)
if 'do_sample' in generation_params:
generation_params.pop('do_sample')

max_tokens = params.pop('max_tokens', None) or params.pop('max_new_tokens', None)
if max_tokens is not None:
generation_params['max_tokens'] = max_tokens
else:
generation_params['max_tokens'] = generation_params.get('max_tokens', generation_params.pop('max_new_tokens', None))
generation_params.pop('max_new_tokens', None)

if return_scores:
if generation_params.get('logprobs') is not None:
Expand Down

0 comments on commit cbd4c8d

Please sign in to comment.