Skip to content

Commit

Permalink
feat: foundation-model sdk
Browse files Browse the repository at this point in the history
  • Loading branch information
GermanVor committed Dec 27, 2024
1 parent f982538 commit 48e0018
Show file tree
Hide file tree
Showing 7 changed files with 358 additions and 1 deletion.
30 changes: 30 additions & 0 deletions clients/ai-foundation_models-v1/sdk/embeddingSdk.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import { Client } from 'nice-grpc';
import { embeddingService } from '..';
import {
EmbeddingsServiceService,
TextEmbeddingRequest,
} from '../generated/yandex/cloud/ai/foundation_models/v1/embedding/embedding_service';
import { ClientCallArgs, SessionArg, TypeFromProtoc } from './types';

export type TextEmbeddingProps = Omit<TypeFromProtoc<TextEmbeddingRequest, 'text'>, 'modelUri'> & {
modelId: string;
folderId: string;
};

export class EmbeddingSdk {
private embeddingClient: Client<typeof EmbeddingsServiceService, ClientCallArgs>;

constructor(session: SessionArg) {
this.embeddingClient = session.client(embeddingService.EmbeddingsServiceClient);
}

textEmbedding(params: TextEmbeddingProps, args?: ClientCallArgs) {
const { modelId, folderId, ...restParams } = params;
const modelUri = `gpt://${folderId}/${modelId}`;

return this.embeddingClient.textEmbedding(
embeddingService.TextEmbeddingRequest.fromPartial({ ...restParams, modelUri }),
args,
);
}
}
38 changes: 38 additions & 0 deletions clients/ai-foundation_models-v1/sdk/imageGenerationSdk.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import { Client } from 'nice-grpc';
import { imageGenerationService } from '..';
import {
ImageGenerationAsyncServiceService,
ImageGenerationRequest,
} from '../generated/yandex/cloud/ai/foundation_models/v1/image_generation/image_generation_service';
import { ClientCallArgs, SessionArg, TypeFromProtoc } from './types';

export type GenerateImageProps = Omit<
TypeFromProtoc<ImageGenerationRequest, 'messages'>,
'modelUri'
> & {
modelId: string;
folderId: string;
};

export class ImageGenerationSdk {
private imageGenerationClient: Client<
typeof ImageGenerationAsyncServiceService,
ClientCallArgs
>;

constructor(session: SessionArg) {
this.imageGenerationClient = session.client(
imageGenerationService.ImageGenerationAsyncServiceClient,
);
}

generateImage(params: GenerateImageProps, args?: ClientCallArgs) {
const { modelId, folderId, ...restParams } = params;
const modelUri = `gpt://${folderId}/${modelId}`;

return this.imageGenerationClient.generate(
imageGenerationService.ImageGenerationRequest.fromPartial({ ...restParams, modelUri }),
args,
);
}
}
4 changes: 4 additions & 0 deletions clients/ai-foundation_models-v1/sdk/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
export * from './embeddingSdk';
export * from './imageGenerationSdk';
export * from './textClassificationSdk';
export * from './textGenerationSdk';
63 changes: 63 additions & 0 deletions clients/ai-foundation_models-v1/sdk/textClassificationSdk.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import { Client } from 'nice-grpc';
import { ClientCallArgs, SessionArg, TypeFromProtoc } from './types';
import {
FewShotTextClassificationRequest,
TextClassificationRequest,
TextClassificationServiceService,
} from '../generated/yandex/cloud/ai/foundation_models/v1/text_classification/text_classification_service';
import { textClassificationService } from '..';

export type TextClassificationProps = Omit<
TypeFromProtoc<TextClassificationRequest, 'text'>,
'modelUri'
> & {
modelId: string;
folderId: string;
};

export type FewShotTextClassificationProps = Omit<
TypeFromProtoc<FewShotTextClassificationRequest, 'text'>,
'modelUri'
> & {
modelId: string;
folderId: string;
};

export class TextClassificationSdk {
private textClassificationClient: Client<
typeof TextClassificationServiceService,
ClientCallArgs
>;

constructor(session: SessionArg) {
this.textClassificationClient = session.client(
textClassificationService.TextClassificationServiceClient,
);
}

classifyText(params: TextClassificationProps, args?: ClientCallArgs) {
const { modelId, folderId, ...restParams } = params;
const modelUri = `gpt://${folderId}/${modelId}`;

return this.textClassificationClient.classify(
textClassificationService.TextClassificationRequest.fromPartial({
...restParams,
modelUri,
}),
args,
);
}

classifyTextFewShort(params: FewShotTextClassificationProps, args?: ClientCallArgs) {
const { modelId, folderId, ...restParams } = params;
const modelUri = `gpt://${folderId}/${modelId}`;

return this.textClassificationClient.fewShotClassify(
textClassificationService.FewShotTextClassificationRequest.fromPartial({
...restParams,
modelUri,
}),
args,
);
}
}
84 changes: 84 additions & 0 deletions clients/ai-foundation_models-v1/sdk/textGenerationSdk.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import { Client } from 'nice-grpc';
import { textGenerationService } from '..';

import { ClientCallArgs, SessionArg, TypeFromProtoc } from './types';
import {
CompletionRequest,
TextGenerationAsyncServiceService,
TextGenerationServiceService,
TokenizeRequest,
TokenizerServiceService,
} from '../generated/yandex/cloud/ai/foundation_models/v1/text_generation/text_generation_service';

export type CompletionProps = Omit<TypeFromProtoc<CompletionRequest, 'messages'>, 'modelUri'> & {
modelId: string;
folderId: string;
};

export type TokenizeProps = Omit<TypeFromProtoc<TokenizeRequest, 'text'>, 'modelUri'> & {
modelId: string;
folderId: string;
};

export class TextGenerationSdk {
private textGenerationClient: Client<typeof TextGenerationServiceService, ClientCallArgs>;
private tokenizerClient: Client<typeof TokenizerServiceService, ClientCallArgs>;
private textGenerationAsyncClient: Client<
typeof TextGenerationAsyncServiceService,
ClientCallArgs
>;

constructor(session: SessionArg) {
this.textGenerationClient = session.client(
textGenerationService.TextGenerationServiceClient,
);

this.tokenizerClient = session.client(textGenerationService.TokenizerServiceClient);

this.textGenerationAsyncClient = session.client(
textGenerationService.TextGenerationAsyncServiceClient,
);
}

tokenize(params: TokenizeProps, args?: ClientCallArgs) {
const { modelId, folderId, ...restParams } = params;
const modelUri = `gpt://${folderId}/${modelId}`;

return this.tokenizerClient.tokenize(
textGenerationService.TokenizeRequest.fromPartial({ ...restParams, modelUri }),
args,
);
}

tokenizeCompletion(params: CompletionProps, args?: ClientCallArgs) {
const { modelId, folderId, ...restParams } = params;
const modelUri = `gpt://${folderId}/${modelId}`;

return this.tokenizerClient.tokenizeCompletion(
textGenerationService.CompletionRequest.fromPartial({ ...restParams, modelUri }),
args,
);
}

completion(params: CompletionProps, args?: ClientCallArgs) {
const { modelId, folderId, ...restParams } = params;
const modelUri = `gpt://${folderId}/${modelId}`;

return this.textGenerationClient.completion(
textGenerationService.CompletionRequest.fromPartial({ ...restParams, modelUri }),
args,
);
}

completionAsOperation(params: CompletionProps, args?: ClientCallArgs) {
const { modelId, folderId, ...restParams } = params;
const modelUri = `gpt://${folderId}/${modelId}`;

const operationP = this.textGenerationAsyncClient.completion(
textGenerationService.CompletionRequest.fromPartial({ ...restParams, modelUri }),
args,
);

return operationP;
}
}
138 changes: 138 additions & 0 deletions clients/ai-foundation_models-v1/sdk/types.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import { ChannelCredentials, ChannelOptions, Client, ServiceDefinition } from '@grpc/grpc-js';
import { ClientError, RawClient, Status } from 'nice-grpc';
import { DeadlineOptions } from 'nice-grpc-client-middleware-deadline';
import { NormalizedServiceDefinition } from 'nice-grpc/lib/service-definitions';

import { DeepPartial } from '../generated/typeRegistry';

type RetryOptions = {
/**
* Boolean indicating whether retries are enabled.
*
* If the method is marked as idempotent in Protobuf, i.e. has
*
* option idempotency_level = IDEMPOTENT;
*
* then the default is `true`. Otherwise, the default is `false`.
*
* Method options currently work only when compiling with `ts-proto`.
*/
retry?: boolean;
/**
* Base delay between retry attempts in milliseconds.
*
* Defaults to 1000.
*
* Example: if `retryBaseDelayMs` is 100, then retries will be attempted in
* 100ms, 200ms, 400ms etc. (not counting jitter).
*/
retryBaseDelayMs?: number;
/**
* Maximum delay between attempts in milliseconds.
*
* Defaults to 15 seconds.
*
* Example: if `retryBaseDelayMs` is 1000 and `retryMaxDelayMs` is 3000, then
* retries will be attempted in 1000ms, 2000ms, 3000ms, 3000ms etc (not
* counting jitter).
*/
retryMaxDelayMs?: number;
/**
* Maximum for the total number of attempts. `Infinity` is supported.
*
* Defaults to 1, i.e. a single retry will be attempted.
*/
retryMaxAttempts?: number;
/**
* Array of retryable status codes.
*
* Default is `[UNKNOWN, RESOURCE_EXHAUSTED, INTERNAL, UNAVAILABLE]`.
*/
retryableStatuses?: Status[];
/**
* Called after receiving error with retryable status code before setting
* backoff delay timer.
*
* If the error code is not retryable, or the maximum attempts exceeded, this
* function will not be called and the error will be thrown from the client
* method.
*/
onRetryableError?(error: ClientError, attempt: number, delayMs: number): void;
};

export interface TokenService {
getToken: () => Promise<string>;
}

export interface GeneratedServiceClientCtor<T extends ServiceDefinition> {
service: T;

new (
address: string,
credentials: ChannelCredentials,
options?: Partial<ChannelOptions>,
): Client;
}

export interface IIAmCredentials {
serviceAccountId: string;
accessKeyId: string;
privateKey: Buffer | string;
}

export interface ISslCredentials {
rootCertificates?: Buffer;
clientPrivateKey?: Buffer;
clientCertChain?: Buffer;
}

export interface ChannelSslOptions {
rootCerts?: Buffer;
privateKey?: Buffer;
certChain?: Buffer;
}

export interface GenericCredentialsConfig {
pollInterval?: number;
ssl?: ChannelSslOptions;
headers?: Record<string, string>;
}

export interface OAuthCredentialsConfig extends GenericCredentialsConfig {
oauthToken: string;
}

export interface IamTokenCredentialsConfig extends GenericCredentialsConfig {
iamToken: string;
}

export interface ServiceAccountCredentialsConfig extends GenericCredentialsConfig {
serviceAccountJson: IIAmCredentials;
}

export type SessionConfig =
| OAuthCredentialsConfig
| IamTokenCredentialsConfig
| ServiceAccountCredentialsConfig
| GenericCredentialsConfig;

export type ClientCallArgs = DeadlineOptions & RetryOptions;

export type WrappedServiceClientType<S extends ServiceDefinition> = RawClient<
NormalizedServiceDefinition<S>,
ClientCallArgs
>;

export type ClientType = <S extends ServiceDefinition>(
clientClass: GeneratedServiceClientCtor<S>,
customEndpoint?: string,
) => WrappedServiceClientType<S>;

export type SessionArg = { client: ClientType };

export type TypeFromProtoc<
T extends { $type: string },
NotPartialKey extends keyof Omit<T, '$type'> = never,
> = {
[Key in NotPartialKey]: T[Key];
} & DeepPartial<T>;
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@
},
"scripts": {
"test": "cross-env NODE_OPTIONS=\"--max-old-space-size=4096\" jest -c config/jest.ts --passWithNoTests '.*\\.test\\.ts$'",
"lint": "eslint src config",
"lint": "eslint src config clients",
"generate-code": "ts-node scripts/generate-code.ts",
"check-endpoints": "ts-node scripts/check-endpoints.ts",
"prettier:fix:clients": "prettier clients/ --write",
Expand Down

0 comments on commit 48e0018

Please sign in to comment.