Skip to content

Latest commit

 

History

History
711 lines (594 loc) · 21.2 KB

openai.md

File metadata and controls

711 lines (594 loc) · 21.2 KB
title tags
Upgrade to OpenAI Python SDK v1.X
python
openai
migration

Convert code using the legacy OpenAI Python SDK (pre-v1) to the v1 client API.

engine marzano(0.1)
language python

// Maps a legacy module-level resource name (e.g. `openai.ChatCompletion`)
// to the attribute path used on an instantiated v1 client
// (e.g. `client.chat.completions`). The match arm also performs the rewrite.
pattern rename_resource() {
    or {
        `Audio` => `audio`,
        `ChatCompletion` => `chat.completions`,
        `Completion` => `completions`,
        `Edit` => `edits`,
        `Embedding` => `embeddings`,
        `File` => `files`,
        `FineTune` => `fine_tunes`,
        `FineTuningJob` => `fine_tuning`,
        `Image` => `images`,
        `Model` => `models`,
        `Moderation` => `moderations`,
    }
}

// String/regex variant of the resource rename, used where the resource
// name appears inside a string (mock patch targets, monkeypatch targets).
// Rewrites a legacy class name to its v1 `resources.*` class path.
pattern rename_resource_cls() {
    or {
        r"Audio" => `resources.Audio`,
        r"ChatCompletion" => `resources.chat.Completions`,
        r"Completion" => `resources.Completions`,
        r"Edit" => `resources.Edits`,
        r"Embedding" => `resources.Embeddings`,
        r"File" => `resources.Files`,
        r"FineTune" => `resources.FineTunes`,
        r"FineTuningJob" => `resources.FineTuning`,
        r"Image" => `resources.Images`,
        r"Model" => `resources.Models`,
        r"Moderation" => `resources.Moderations`,
    }
}

// Resources that were removed in SDK v1 and have no replacement.
// Usages of these are commented out with a TODO rather than rewritten.
pattern deprecated_resource() {
    or {
        `Customer`,
        `Deployment`,
        `Engine`,
        `ErrorObject`,
    }
}

// String/regex variant of the deprecated-resource list, for matches
// inside string literals (mock patch targets, monkeypatch targets).
pattern deprecated_resource_cls() {
    or {
        r"Customer",
        r"Deployment",
        r"Engine",
        r"ErrorObject",
    }
}


// Rewrites the method part of a legacy `openai.<Resource>.<func>(...)` call
// into the v1 client style, and records whether the file needs a sync
// and/or async client. When $client is undefined, the rewritten call uses
// the module-level `client` / `aclient` names that change_import creates;
// otherwise it uses the caller-supplied $client variable.
pattern rename_func($has_sync, $has_async, $res, $stmt, $params, $client) {
    $func where {
        // Legacy async methods were prefixed with `a` (e.g. `acreate`).
        // NOTE(review): this matches ANY method name starting with `a`,
        // not just known async variants — confirm no sync method names
        // beginning with `a` can reach this pattern.
        if ($func <: r"a([a-zA-Z0-9]+)"($func_rest)) {
            $has_async = `true`,
            $func => $func_rest,
            if ($client <: undefined) {
                $stmt => `aclient.$res.$func($params)`,
            } else {
                $stmt => `$client.$res.$func($params)`,
            }
        } else {
            $has_sync = `true`,
            if ($client <: undefined) {
                $stmt => `client.$res.$func($params)`,
            } else {
                $stmt => `$client.$res.$func($params)`,
            }
        },
        // Fix function renames
        // (v1 renamed `Image.create` to `images.generate`).
        if ($res <: `Image`) {
          $func => `generate`
        }
    }
}

// Replaces the original import statement ($stmt) with the v1 imports and
// client instantiation(s). Which client classes are imported depends on
// $azure; whether a sync client, an async client, or both are created
// depends on the $has_sync / $has_async flags accumulated while
// rewriting calls. $client_params holds `name=value` constructor
// arguments harvested from legacy `openai.<field> = ...` assignments.
pattern change_import($has_sync, $has_async, $need_openai_import, $azure, $client_params) {
    $stmt where {
        $imports_and_defs = [],

        // Keep `import openai` when the file still references the module
        // directly (e.g. exception types like `openai.RateLimitError`).
        if ($need_openai_import <:  `true`) {
            $imports_and_defs += `import openai`,
        },

        if ($azure <: true) {
          $client = `AzureOpenAI`,
          $aclient = `AsyncAzureOpenAI`,
        } else {
          $client = `OpenAI`,
          $aclient = `AsyncOpenAI`,
        },

        $formatted_params = join(list = $client_params, separator = `,\n`),

        if (and { $has_sync <: `true`, $has_async <: `true` }) {
            $imports_and_defs += `from openai import $client, $aclient`,
            $imports_and_defs += ``, // Blank line
            $imports_and_defs += `client = $client($formatted_params)`,
            $imports_and_defs += `aclient = $aclient($formatted_params)`,
        } else if ($has_sync <: `true`) {
            $imports_and_defs += `from openai import $client`,
            $imports_and_defs += ``, // Blank line
            $imports_and_defs += `client = $client($formatted_params)`,
        } else if ($has_async <: `true`) {
            $imports_and_defs += `from openai import $aclient`,
            $imports_and_defs += ``, // Blank line
            $imports_and_defs += `aclient = $aclient($formatted_params)`,
        },

        $formatted = join(list = $imports_and_defs, separator=`\n`),
        $stmt => `$formatted`,
    }
}

// Dispatches on the resource of a `<Resource>.<func>(...)` call:
// renameable resources are rewritten to the v1 client call (and, for
// Azure, the legacy `engine` kwarg becomes `model`); deprecated
// resources are commented out with a TODO — including the enclosing
// assignment when the call's result is bound to a variable.
pattern rewrite_whole_fn_call($import, $has_sync, $has_async, $res, $func, $params, $stmt, $body, $client, $azure) {
    or {
        rename_resource() where {
            $import = `true`,
            $func <: rename_func($has_sync, $has_async, $res, $stmt, $params, $client),
            if ($azure <: true) {
              $params <: maybe contains bubble `engine` => `model`
            }
        },
        deprecated_resource() as $dep_res where {
            // Widen the TODO target to the whole assignment, if any.
            $stmt_whole = $stmt,
            if ($body <: contains `$_ = $stmt` as $line) {
                $stmt_whole = $line,
            },
            $stmt_whole => todo(message=`The resource '$dep_res' has been deprecated`, target=$stmt_whole),
        }
    }
}

// Rewrites unittest.mock `patch` targets referencing legacy openai
// resource classes (e.g. `openai.Completion`) to their v1
// `openai.resources.*` equivalents; deprecated resources are commented
// out with a TODO instead.
pattern unittest_patch() {
    or {
        // Case 1: `@patch('openai.X...')` decorators on a definition.
        decorated_definition($decorators, definition=$_) where {
            $decorators <: contains bubble decorator(value=`patch($cls_path)`) as $stmt where {
                // Capture the resource name after `openai.`. The dot
                // starting the optional attribute tail is escaped so only
                // a literal `.` separator matches (previously `(?:.` let
                // any character follow the resource name).
                $cls_path <: contains r"openai\.([a-zA-Z0-9]+)(?:\.[^,]+)?"($res),
                if ($res <: rename_resource_cls()) {} else {
                    $res <: deprecated_resource_cls(),
                    $stmt => todo(message=`The resource '$res' has been deprecated`, target=$stmt),
                }
            }
        },
        // Case 2: `patch(...)` / `patch.object(...)` calls inside a
        // function body, typically used as context managers.
        function_definition($body) where {
            $body <: contains bubble($body) or {
                `patch.object($params)`,
                `patch($params)`,
            } as $stmt where {
                $params <: contains bubble($body, $stmt) r"openai\.([a-zA-Z0-9]+)(?:\.[^,]+)?"($res) where or {
                    $res <: rename_resource_cls(),
                    and {
                        $res <: deprecated_resource_cls(),
                        $line = $stmt,
                        // When the patch is used as a `with` statement,
                        // comment out the whole `with` block, not just
                        // the patch call.
                        if ($body <: contains or { `with $stmt: $_`, `with $stmt as $_: $_` } as $l) {
                            $line = $l,
                        },
                        $line => todo(message=`The resource '$res' has been deprecated`, target=$line),
                    }
                }
            },
        }
    }
}

// Rewrites `monkeypatch.setattr` / `monkeypatch.delattr` targets inside
// pytest fixtures that reference legacy openai resource classes to their
// v1 `openai.resources.*` equivalents; deprecated resources are
// commented out with a TODO instead.
pattern pytest_patch() {
    decorated_definition($decorators, $definition) where {
        $decorators <: contains decorator(value=`pytest.fixture`),
        $definition <: bubble function_definition($body, $parameters) where {
            // The monkeypatch fixture is the first parameter, whatever
            // the author named it.
            $parameters <: [$monkeypatch, ...],
            $body <: contains bubble($monkeypatch) or {
                `$monkeypatch.setattr($params)` as $stmt where {
                    // The dot before the optional attribute tail is
                    // escaped so only a literal `.` separator matches.
                    $params <: contains bubble($stmt) r"openai\.([a-zA-Z0-9]+)(?:\.[^,]+)?"($res) where or {
                        $res <: rename_resource_cls(),
                        $stmt => todo(message=`The resource '$res' has been deprecated`, target=$stmt),
                    }
                },
                // Use the captured fixture name here too (the original
                // hard-coded `monkeypatch`, so fixtures with a differently
                // named parameter never matched `delattr`).
                `$monkeypatch.delattr($params)` as $stmt where {
                    $params <: contains bubble($stmt) r"openai\.([a-zA-Z0-9]+)(?:\.[^,]+)?"($res) where or {
                        $res <: rename_resource_cls(),
                        $stmt => todo(message=`The resource '$res' has been deprecated`, target=$stmt),
                    }
                },
            }
        },
    },
}

// Rewrites dict-style access on an API response to attribute access
// (v1 returns pydantic objects, not dicts): `x['y']` -> `x.y` and
// `x.get("y")` -> `x.y`.
// NOTE(review): each arm matches only one quote style ('...' for
// subscripts, "..." for .get); the opposite quoting is not rewritten —
// confirm whether that is intentional.
pattern fix_object_accessing($var) {
  or {
    // The whole subscript expression must involve $var somewhere.
    `$x['$y']` as $sub => `$x.$y` where {
      $sub <: contains $var
    },
    // Here only the receiver is checked for $var.
    `$x.get("$y")` => `$x.$y` where {
      $x <: contains $var
    }
  }
}

// When there is a variable used by an openai call, make sure it isn't subscripted.
// Also handles streaming: if the variable is iterated with a `for` loop,
// fix subscript access on the per-chunk loop variable as well.
pattern fix_downstream_openai_usage() {
    $var where {
        $program <: maybe contains fix_object_accessing($var),
        $program <: maybe contains `for $chunk in $var: $body` where {
          $body <: maybe contains fix_object_accessing($chunk)
        }
    }
}

// Entry point for the migration. Walks a file body, rewriting legacy
// `openai.*` usage to the v1 client API. When $client is undefined, a
// module-level `client` / `aclient` is instantiated in place of the
// original import; when $azure is (or becomes) true, the Azure client
// variants are used and client parameters are harvested from legacy
// `openai.api_*` assignments.
pattern openai_main($client, $azure) {
    $body where {
        if ($client <: undefined) {
            $need_openai_import = `false`,
            $create_client = true,
        } else {
            $need_openai_import = `true`,
            $create_client = false,
        },
        if ($azure <: undefined) {
          $azure = false,
        },
        // State accumulated while scanning the body.
        $has_openai_import = `false`,
        $has_partial_import = `false`,
        $has_sync = `false`,
        $has_async = `false`,

        $client_params = [],

        $body <: any {
          // Mark all the places where they configure openai as something that requires manual intervention
          if ($client <: undefined) {
            contains bubble($need_openai_import, $azure, $client_params) `openai.$field = $val` as $setter where {
              $field <: or {
                // Known config fields become client constructor args and
                // the assignment is removed ($res = . deletes the line).
                `api_type` where {
                  $res = .,
                  if ($val <: or {`"azure"`, `"azure_ad"`}) {
                    $azure = true
                  },
                },
                `api_base` where {
                  $azure <: true,
                  $client_params += `azure_endpoint=$val`,
                  $res = .,
                },
                `api_key` where {
                  $res = .,
                  $client_params += `api_key=$val`,
                },
                `api_version` where {
                  $res = .,
                  // Only Azure has api_version
                  $azure = true,
                  $client_params += `api_version=$val`,
                },
                // Anything else cannot be migrated automatically: leave a
                // TODO telling the user how to pass it to the client.
                $_ where {
                  // Rename the field, if necessary
                  if ($field <: `api_base`) {
                    $new_name = `base_url`,
                  } else {
                    $new_name = $field
                  },
                  $res = todo(message=`The 'openai.$field' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI($new_name=$val)'`, target=$setter),
                  $need_openai_import = `true`,
                }
              }
            } => $res
          },
          // Remap errors
          // (v1 moved exceptions from `openai.error.*` to `openai.*`).
          contains bubble($need_openai_import) `openai.error.$exp` => `openai.$exp` where {
            $need_openai_import = `true`,
          },
          // Calls through the module: `openai.<Res>.<func>(...)`.
          contains `import openai` as $import_stmt where {
              $body <: contains bubble($has_sync, $has_async, $has_openai_import, $body, $client, $azure) `openai.$res.$func($params)` as $stmt where {
                  $res <: rewrite_whole_fn_call(import = $has_openai_import, $has_sync, $has_async, $res, $func, $params, $stmt, $body, $client, $azure),
                  // Fix subscript access on the variable the result is
                  // assigned to (v1 responses are objects, not dicts).
                  $stmt <: maybe within bubble($stmt) `$var = $stmt` where {
                      $var <: fix_downstream_openai_usage()
                  }
              },
          },
          // Calls through partially-imported names:
          // `from openai import X` then `X.<func>(...)`.
          contains `from openai import $resources` as $partial_import_stmt where {
            $has_partial_import = `true`,
            $body <: contains bubble($has_sync, $has_async, $resources, $client, $azure) `$res.$func($params)` as $stmt where {
                $resources <: contains $res,
                $res <: rewrite_whole_fn_call($import, $has_sync, $has_async, $res, $func, $params, $stmt, $body, $client, $azure),
            }
          },
          contains unittest_patch(),
          contains pytest_patch(),
        },

        // Replace the original import statement(s) with the v1 imports
        // and client instantiation.
        if ($create_client <: true) {
            if ($has_openai_import <: `true`) {
                $import_stmt <: change_import($has_sync, $has_async, $need_openai_import, $azure, $client_params),
                if ($has_partial_import <: `true`) {
                    $partial_import_stmt => .,
                },
            } else if ($has_partial_import <: `true`) {
                $partial_import_stmt <: change_import($has_sync, $has_async, $need_openai_import, $azure, $client_params),
            },
        },
    }
}

// Root of the migration: apply openai_main to every file body.
file($body) where {
  // No client means instantiate one per file
  $body <: openai_main()
}

Change openai import to Sync

import openai

completion = openai.Completion.create(model="davinci-002", prompt="Hello world")
chat_completion = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hello world"}])
from openai import OpenAI

client = OpenAI()

completion = client.completions.create(model="davinci-002", prompt="Hello world")
chat_completion = client.chat.completions.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hello world"}])

Change openai import to Async

import openai

completion = await openai.Completion.acreate(model="davinci-002", prompt="Hello world")
chat_completion = await openai.ChatCompletion.acreate(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hello world"}])
from openai import AsyncOpenAI

aclient = AsyncOpenAI()

completion = await aclient.completions.create(model="davinci-002", prompt="Hello world")
chat_completion = await aclient.chat.completions.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hello world"}])

Change openai import to Both

import openai

completion = openai.Completion.create(model="davinci-002", prompt="Hello world")
chat_completion = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hello world"}])

a_completion = await openai.Completion.acreate(model="davinci-002", prompt="Hello world")
a_chat_completion = await openai.ChatCompletion.acreate(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hello world"}])
from openai import OpenAI, AsyncOpenAI

client = OpenAI()
aclient = AsyncOpenAI()

completion = client.completions.create(model="davinci-002", prompt="Hello world")
chat_completion = client.chat.completions.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hello world"}])

a_completion = await aclient.completions.create(model="davinci-002", prompt="Hello world")
a_chat_completion = await aclient.chat.completions.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hello world"}])

Change different kinds of import

import openai
from openai import ChatCompletion

completion = openai.Completion.create(model="davinci-002", prompt="Hello world")
chat_completion = await ChatCompletion.acreate(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hello world"}])
from openai import OpenAI, AsyncOpenAI

client = OpenAI()
aclient = AsyncOpenAI()

completion = client.completions.create(model="davinci-002", prompt="Hello world")
chat_completion = await aclient.chat.completions.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hello world"}])

Manual config required

import openai

if openai_proxy:
    openai.proxy = openai_proxy
    openai.api_base = self.openai_api_base
import openai

if openai_proxy:
    # TODO: The 'openai.proxy' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(proxy=openai_proxy)'
    # openai.proxy = openai_proxy
    # TODO: The 'openai.api_base' option isn't read in the client API. You will need to pass it when you instantiate the client, e.g. 'OpenAI(base_url=self.openai_api_base)'
    # openai.api_base = self.openai_api_base

Remap errors

import openai

try:
    completion = openai.Completion.create(model="davinci-002", prompt="Hello world")
    chat_completion = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hello world"}])
except openai.error.RateLimitError as err:
    pass
import openai
from openai import OpenAI

client = OpenAI()

try:
    completion = client.completions.create(model="davinci-002", prompt="Hello world")
    chat_completion = client.chat.completions.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hello world"}])
except openai.RateLimitError as err:
    pass

Mark deprecated api usage

import openai

completion = openai.Customer.create(model="davinci-002", prompt="Hello world")
chat_completion = openai.Deployment.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hello world"}])
import openai

# TODO: The resource 'Customer' has been deprecated
# completion = openai.Customer.create(model="davinci-002", prompt="Hello world")
# TODO: The resource 'Deployment' has been deprecated
# chat_completion = openai.Deployment.create(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hello world"}])

Migrate unittest

@patch('openai.Completion')
@patch('openai.Customer')
def test(MockClass1, MockClass2):
    with patch.object(openai.Completion, 'method', return_value=None):
        pass
    with patch.object(openai.Customer, 'method', return_value=None):
        pass
    with patch("openai.Engine.list"):
        pass
    pass
@patch('openai.resources.Completions')
# TODO: The resource 'Customer' has been deprecated
# @patch('openai.Customer')
def test(MockClass1, MockClass2):
    with patch.object(openai.resources.Completions, 'method', return_value=None):
        pass
    # TODO: The resource 'Customer' has been deprecated
    # with patch.object(openai.Customer, 'method', return_value=None):
    #         pass
    # TODO: The resource 'Engine' has been deprecated
    # with patch("openai.Engine.list"):
    #         pass
    pass

Migrate pytest

@pytest.fixture
def mocked_GET_pos(monkeypatch):
    monkeypatch.setattr(openai.Completion, 'GET', lambda: True)
    monkeypatch.delattr(openai.Completion, 'PUT', lambda: True)

@pytest.fixture
def mocked_GET_neg(monkeypatch):
    monkeypatch.setattr(openai.Customer, 'GET', lambda: False)

@pytest.fixture
def mocked_GET_raises(monkeypatch, other):
    def raise_():
        raise Exception()
    monkeypatch.setattr(openai.Engine.list, 'GET', raise_)
    monkeypatch.delattr(openai.Engine.list, 'PUT', lambda: True)
@pytest.fixture
def mocked_GET_pos(monkeypatch):
    monkeypatch.setattr(openai.resources.Completions, 'GET', lambda: True)
    monkeypatch.delattr(openai.resources.Completions, 'PUT', lambda: True)

@pytest.fixture
def mocked_GET_neg(monkeypatch):
    # TODO: The resource 'Customer' has been deprecated
    # monkeypatch.setattr(openai.Customer, 'GET', lambda: False)

@pytest.fixture
def mocked_GET_raises(monkeypatch, other):
    def raise_():
        raise Exception()
    # TODO: The resource 'Engine' has been deprecated
    # monkeypatch.setattr(openai.Engine.list, 'GET', raise_)
    # TODO: The resource 'Engine' has been deprecated
    # monkeypatch.delattr(openai.Engine.list, 'PUT', lambda: True)

Image creation has been renamed

The Image.create method has been renamed to images.generate.

import openai

openai.Image.create(file=file)
from openai import OpenAI

client = OpenAI()

client.images.generate(file=file)

Use Azure OpenAI

If api_type was previously set to azure, you should now use the AzureOpenAI client.

import os
import openai

openai.api_type = "azure"
openai.api_base = os.getenv("AZURE_OPENAI_ENDPOINT")
openai.api_key = os.getenv("AZURE_OPENAI_KEY")
openai.api_version = "2023-05-15"

response = openai.ChatCompletion.create(
    engine="gpt-35-turbo",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
    ]
)
import os
from openai import AzureOpenAI

client = AzureOpenAI(
  azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
  api_key=os.getenv("AZURE_OPENAI_KEY"),
  api_version="2023-05-15"
)


response = client.chat.completions.create(
  model="gpt-35-turbo",
  messages=[
    {"role": "system", "content": "You are a helpful assistant."},
  ]
)

Fix subscripting

The new API does not support subscripting on the outputs.

import openai

model, token_limit, prompt_cost, comp_cost = 'gpt-4-32k', 32_768, 0.06, 0.12

completion = openai.ChatCompletion.create(
    model=model,
    messages=[
        {"role": "system", "content": system},
        {"role": "user", "content":
         user + text},
    ]
)
output = completion['choices'][0]['message']['content']

prom = completion['usage']['prompt_tokens']
comp = completion['usage']['completion_tokens']

# unrelated variable
foo = something['else']
from openai import OpenAI

client = OpenAI()

model, token_limit, prompt_cost, comp_cost = 'gpt-4-32k', 32_768, 0.06, 0.12

completion = client.chat.completions.create(model=model,
messages=[
    {"role": "system", "content": system},
    {"role": "user", "content":
     user + text},
])
output = completion.choices[0].message.content

prom = completion.usage.prompt_tokens
comp = completion.usage.completion_tokens

# unrelated variable
foo = something['else']

Fix completion streaming

import openai

completion = openai.ChatCompletion.create(
    model=model,
    messages=[
        {"role": "system", "content": system},
        {"role": "user", "content":
         user + text},
    ],
    stream=True
)

for chunk in completion:
    print(chunk)
    print(chunk.choices[0].delta.get("content"))
    print("****************")
from openai import OpenAI

client = OpenAI()

completion = client.chat.completions.create(model=model,
  messages=[
      {"role": "system", "content": system},
      {"role": "user", "content":
      user + text},
  ],
  stream=True
)

for chunk in completion:
    print(chunk)
    print(chunk.choices[0].delta.content)
    print("****************")

Fix multiple exceptions

Repair openai/openai-python#1165, ensure we fix all exceptions in one pass.

try:
   # Some completions handler
   pass
except openai.error.RateLimitError as e:
   print(e)
except openai.error.AuthenticationError as e:
   print(e)
except openai.error.InvalidRequestError as e:
    print(e)

Fixed:

try:
   # Some completions handler
   pass
except openai.RateLimitError as e:
   print(e)
except openai.AuthenticationError as e:
   print(e)
except openai.InvalidRequestError as e:
    print(e)