Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add API key to gRPC server and client #394

Merged
merged 3 commits into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/grpc/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@walmartlabs/cookie-cutter-grpc",
"version": "1.6.0-beta.2",
"version": "1.6.0-beta.3",
"license": "Apache-2.0",
"main": "dist/index.js",
"types": "dist/index.d.ts",
Expand Down
86 changes: 73 additions & 13 deletions packages/grpc/src/__test__/grpc.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@ import {
GrpcMetadata,
grpcSource,
IGrpcClientConfiguration,
IGrpcClientOptions,
IGrpcConfiguration,
IGrpcServerOptions,
IResponseStream,
} from "..";
import { sample } from "./Sample";

const apiKey = "token";
let nextPort = 56011;

export interface ISampleService {
Expand Down Expand Up @@ -78,16 +81,23 @@ export const SampleServiceDefinition = {
},
};

function testApp(handler: any, host?: string): CancelablePromise<void> {
function testApp(
handler: any,
host?: string,
options?: IGrpcServerOptions
): CancelablePromise<void> {
return Application.create()
.input()
.add(
grpcSource({
port: nextPort,
host,
definitions: [SampleServiceDefinition],
skipNoStreamingValidation: true,
})
grpcSource(
{
port: nextPort,
host,
definitions: [SampleServiceDefinition],
skipNoStreamingValidation: true,
},
options
)
)
.done()
.dispatch(handler)
Expand All @@ -96,13 +106,17 @@ function testApp(handler: any, host?: string): CancelablePromise<void> {

async function createClient(
host?: string,
config?: Partial<IGrpcClientConfiguration & IGrpcConfiguration>
config?: Partial<IGrpcClientConfiguration & IGrpcConfiguration>,
options?: string | IGrpcClientOptions
): Promise<ISampleService & IRequireInitialization & IDisposable> {
const client = grpcClient<ISampleService & IRequireInitialization & IDisposable>({
endpoint: `${host || "localhost"}:${nextPort++}`,
definition: SampleServiceDefinition,
...config,
});
const client = grpcClient<ISampleService & IRequireInitialization & IDisposable>(
{
endpoint: `${host || "localhost"}:${nextPort++}`,
definition: SampleServiceDefinition,
...config,
},
options
);
return client;
}

Expand All @@ -126,6 +140,29 @@ describe("gRPC source", () => {
}
});

it("serves requests with api key validation", async () => {
const app = testApp(
{
onNoStreaming: async (
request: sample.ISampleRequest,
_: IDispatchContext
): Promise<sample.ISampleResponse> => {
return { name: request.id.toString() };
},
},
undefined,
{ apiKey }
);
try {
const client = await createClient(undefined, undefined, { apiKey });
const response = await client.NoStreaming({ id: 15 });
expect(response).toMatchObject({ name: "15" });
} finally {
app.cancel();
await app;
}
});

it("serves response streams", async () => {
const app = testApp({
onStreamingOut: async (
Expand Down Expand Up @@ -235,6 +272,29 @@ describe("gRPC source", () => {
}
});

it("throws error for missing/invalid api key", async () => {
const app = testApp(
{
onNoStreaming: async (
request: sample.ISampleRequest,
_: IDispatchContext
): Promise<sample.ISampleResponse> => {
return { name: request.id.toString() };
},
},
undefined,
{ apiKey }
);
try {
const client = await createClient();
const response = client.NoStreaming({ id: 15 });
await expect(response).rejects.toThrowError(/Invalid API Key/);
} finally {
app.cancel();
await app;
}
});

it("validates that no streaming operations are exposed", () => {
const a = () =>
grpcSource({
Expand Down
21 changes: 17 additions & 4 deletions packages/grpc/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ export interface IGrpcServerConfiguration {
readonly skipNoStreamingValidation?: boolean;
}

export interface IGrpcServerOptions {
readonly apiKey?: string;
}

export interface IGrpcClientConfiguration {
readonly endpoint: string;
readonly definition: IGrpcServiceDefinition;
Expand All @@ -66,6 +70,11 @@ export interface IGrpcClientConfiguration {
readonly behavior?: Required<IComponentRuntimeBehavior>;
}

export interface IGrpcClientOptions {
readonly certPath?: string;
readonly apiKey?: string;
}

export enum GrpcMetadata {
OperationPath = "grpc.OperationPath",
ResponseStream = "grpc.ResponseStream",
Expand All @@ -80,7 +89,8 @@ export interface IResponseStream<TResponse> {
}

export function grpcSource(
configuration: IGrpcServerConfiguration & IGrpcConfiguration
configuration: IGrpcServerConfiguration & IGrpcConfiguration,
options?: IGrpcServerOptions
): IInputSource & IRequireInitialization {
configuration = config.parse<IGrpcServerConfiguration & IGrpcConfiguration>(
GrpcSourceConfiguration,
Expand All @@ -90,7 +100,7 @@ export function grpcSource(
allocator: Buffer,
}
);
return new GrpcInputSource(configuration);
return new GrpcInputSource(configuration, options);
}

export function grpcMsg(operation: IGrpcServiceMethod, request: any): IMessage {
Expand All @@ -102,7 +112,7 @@ export function grpcMsg(operation: IGrpcServiceMethod, request: any): IMessage {

export function grpcClient<T>(
configuration: IGrpcClientConfiguration & IGrpcConfiguration,
certPath?: string
certPathOrOptions?: string | IGrpcClientOptions
): T & IRequireInitialization & IDisposable {
configuration = config.parse<IGrpcClientConfiguration & IGrpcConfiguration>(
GrpcClientConfiguration,
Expand All @@ -122,5 +132,8 @@ export function grpcClient<T>(
},
}
);
return createGrpcClient<T>(configuration, certPath);
if (typeof certPathOrOptions === "string") {
certPathOrOptions = { certPath: certPathOrOptions };
}
return createGrpcClient<T>(configuration, certPathOrOptions);
}
29 changes: 25 additions & 4 deletions packages/grpc/src/internal/GrpcInputSource.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import {
OpenTracingTagKeys,
} from "@walmartlabs/cookie-cutter-core";
import {
Metadata,
sendUnaryData,
Server,
ServerCredentials,
Expand All @@ -38,7 +39,7 @@ import {
GrpcResponseStream,
GrpcStreamHandler,
} from ".";
import { GrpcMetadata, IGrpcConfiguration, IGrpcServerConfiguration } from "..";
import { GrpcMetadata, IGrpcConfiguration, IGrpcServerConfiguration, IGrpcServerOptions } from "..";
import { GrpcOpenTracingTagKeys } from "./helper";

enum GrpcMetrics {
Expand All @@ -59,7 +60,10 @@ export class GrpcInputSource implements IInputSource, IRequireInitialization {
private tracer: Tracer;
private metrics: IMetrics;

constructor(private readonly config: IGrpcServerConfiguration & IGrpcConfiguration) {
constructor(
private readonly config: IGrpcServerConfiguration & IGrpcConfiguration,
private readonly options?: IGrpcServerOptions
) {
if (!config.skipNoStreamingValidation) {
for (const def of config.definitions) {
for (const key of Object.keys(def)) {
Expand Down Expand Up @@ -180,7 +184,11 @@ export class GrpcInputSource implements IInputSource, IRequireInitialization {
if (value !== undefined) {
callback(undefined, value);
} else if (error !== undefined) {
callback(this.createError(error), null);
if ((error as ServerErrorResponse).code !== undefined) {
callback(error, null);
} else {
callback(this.createError(error), null);
}
} else {
callback(
this.createError("not implemented", status.UNIMPLEMENTED),
Expand All @@ -198,7 +206,15 @@ export class GrpcInputSource implements IInputSource, IRequireInitialization {
path: method.path,
});
});

if (this.options?.apiKey) {
if (!this.isApiKeyValid(call.metadata)) {
await msgRef.release(
undefined,
this.createError("Invalid API Key", status.UNAUTHENTICATED)
);
return;
}
}
if (!(await this.queue.enqueue(msgRef))) {
await msgRef.release(undefined, new Error("service unavailable"));
}
Expand Down Expand Up @@ -239,4 +255,9 @@ export class GrpcInputSource implements IInputSource, IRequireInitialization {
message: error.toString(),
};
}

private isApiKeyValid(meta: Metadata) {
const headerValue = meta.get("authorization");
return headerValue?.[0]?.toString() === this.options.apiKey;
}
}
36 changes: 24 additions & 12 deletions packages/grpc/src/internal/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ import {
import { FORMAT_HTTP_HEADERS, Span, SpanContext, Tags, Tracer } from "opentracing";
import { performance } from "perf_hooks";
import { createGrpcConfiguration, createServiceDefinition } from ".";
import { IGrpcClientConfiguration, IGrpcConfiguration } from "..";
import { IGrpcClientConfiguration, IGrpcClientOptions, IGrpcConfiguration } from "..";

enum GrpcMetrics {
RequestSent = "cookie_cutter.grpc_client.request_sent",
Expand Down Expand Up @@ -75,23 +75,27 @@ class ClientBase implements IRequireInitialization, IDisposable {

export function createGrpcClient<T>(
config: IGrpcClientConfiguration & IGrpcConfiguration,
certPath?: string
options?: IGrpcClientOptions
): T & IDisposable & IRequireInitialization {
const serviceDef = createServiceDefinition(config.definition);
let client: Client;
const ClientType = makeGenericClientConstructor(serviceDef, undefined, undefined);
const certPath = options?.certPath;
const apiKey = options?.apiKey;
if (certPath) {
const rootCert = readFileSync(certPath);
const channelCreds = credentials.createSsl(rootCert);

const metaCallback = (_params: any, callback: (arg0: null, arg1: Metadata) => void) => {
const meta = new Metadata();
meta.add("custom-auth-header", "token");
callback(null, meta);
};

const callCreds = credentials.createFromMetadataGenerator(metaCallback);
const combCreds = credentials.combineChannelCredentials(channelCreds, callCreds);
let combCreds = channelCreds;
if (apiKey) {
const metaCallback = (_params: any, callback: (arg0: null, arg1: Metadata) => void) => {
const meta = new Metadata();
meta.add("authorization", apiKey);
callback(null, meta);
};
const callCreds = credentials.createFromMetadataGenerator(metaCallback);
combCreds = credentials.combineChannelCredentials(channelCreds, callCreds);
}
client = new ClientType(config.endpoint, combCreds, createGrpcConfiguration(config));
} else {
client = new ClientType(
Expand Down Expand Up @@ -165,12 +169,16 @@ export function createGrpcClient<T>(

const stream = await retrier.retry((bail) => {
try {
const meta = createTracingMetadata(wrapper.tracer, span);
if (!certPath && apiKey) {
meta.set("authorization", apiKey);
}
return client.makeServerStreamRequest(
method.path,
method.requestSerialize,
method.responseDeserialize,
request,
createTracingMetadata(wrapper.tracer, span),
meta,
callOptions()
);
} catch (e) {
Expand Down Expand Up @@ -239,12 +247,16 @@ export function createGrpcClient<T>(
return await retrier.retry(async (bail) => {
try {
return await new Promise((resolve, reject) => {
const meta = createTracingMetadata(wrapper.tracer, span);
if (!certPath && apiKey) {
meta.set("authorization", apiKey);
}
client.makeUnaryRequest(
method.path,
method.requestSerialize,
method.responseDeserialize,
request,
createTracingMetadata(wrapper.tracer, span),
meta,
callOptions(),
(error, value) => {
this.metrics.increment(GrpcMetrics.RequestProcessed, {
Expand Down
Loading