From d83120a45f32dcf6a8e9261e62578f4711955325 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jaras=C5=82a=C5=AD=20Viktor=C4=8Dyk?= Date: Sun, 26 Jan 2025 09:54:10 +0100 Subject: [PATCH] feat: switch DataFrame generic to its schema --- __tests__/type.test.ts | 78 ++++++++++++ package.json | 1 + polars/dataframe.ts | 256 +++++++++++++++++++++++---------------- polars/lazy/dataframe.ts | 150 ++++++++++++++--------- polars/series/index.ts | 27 ++++- yarn.lock | 8 ++ 6 files changed, 350 insertions(+), 170 deletions(-) create mode 100644 __tests__/type.test.ts diff --git a/__tests__/type.test.ts b/__tests__/type.test.ts new file mode 100644 index 00000000..22d04217 --- /dev/null +++ b/__tests__/type.test.ts @@ -0,0 +1,78 @@ +import { expectType } from "ts-expect"; + +import { + DataFrame, + type Float64, + type Int64, + type String as PlString, + type Series, +} from "../polars"; + +describe("type tests", () => { + it("is cleaned up later", () => { + const a = null as unknown as DataFrame<{ + id: Int64; + age: Int64; + name: PlString; + }>; + const b = null as unknown as DataFrame<{ + id: Int64; + age: Int64; + fl: Float64; + }>; + expectType>(a.getColumn("age")); + expectType< + (Series | Series | Series)[] + >(a.getColumns()); + expectType>(a.drop("name")); + expectType>(a.drop(["name", "age"])); + expectType>(a.drop("name", "age")); + // expectType>(a.age); + expectType< + DataFrame<{ + id: Int64; + age: Int64; + age_right: Int64; + name: PlString; + fl: Float64; + }> + >(a.join(b, { on: ["id"] })); + expectType< + DataFrame<{ + id: Int64; + age: Int64; + ageRight: Int64; + name: PlString; + fl: Float64; + }> + >(a.join(b, { on: ["id"], suffix: "Right" })); + expectType< + DataFrame<{ + id: Int64; + id_right: Int64; + age: Int64; + name: PlString; + fl: Float64; + }> + >(a.join(b, { leftOn: "id", rightOn: ["age"] })); + expectType< + DataFrame<{ + id: Int64; + id_right: Int64; + age: Int64; + age_right: Int64; + name: PlString; + fl: Float64; + }> + >(a.join(b, { how: "cross" })); + }); + + it("folds", () => { + const df = DataFrame({ + a: [2, 1, 3], + b: [1, 2, 3], + c: [1.0, 2.0, 3.0], + }); + expectType>(df.fold((s1, s2) => s1.plus(s2))); + }); +}); diff --git a/package.json b/package.json index 3cdfbf71..b3d39c73 100644 --- a/package.json +++ b/package.json @@ -62,6 +62,7 @@ "chance": "^1.1.12", "jest": "^29.7.0", "source-map-support": "^0.5.21", + "ts-expect": "^1.3.0", "ts-jest": "^29.2.5", "ts-node": "^10.9.2", "typedoc": "^0.27.3", diff --git a/polars/dataframe.ts b/polars/dataframe.ts index b62bc7a4..4d417fe9 100644 --- a/polars/dataframe.ts +++ b/polars/dataframe.ts @@ -184,6 +184,41 @@ interface WriteMethods { writeAvro(destination: string | Writable, options?: WriteAvroOptions): void; } +export type Schema = Record; +type SchemaToSeriesRecord> = { + [K in keyof T]: K extends string ? Series : never; +}; + +type ExtractJoinKeys = T extends string[] ? T[number] : T; +type ExtractSuffix = T extends { suffix: infer Suffix } + ? Suffix + : "_right"; +export type JoinSchemas< + S1 extends Schema, + S2 extends Schema, + Opt extends JoinOptions, +> = Simplify< + { + [K1 in keyof S1]: S1[K1]; + } & { + [K2 in Exclude]: K2 extends keyof S1 ? never : S2[K2]; + } & { + [K_SUFFIXED in keyof S1 & + Exclude< + keyof S2, + Opt extends { how: "cross" } + ? never + : Opt extends Pick + ? ExtractJoinKeys + : Opt extends Pick + ? ExtractJoinKeys + : never + > as `${K_SUFFIXED extends string ? K_SUFFIXED : never}${ExtractSuffix}`]: K_SUFFIXED extends string + ? S2[K_SUFFIXED] + : never; + } +>; + /** * A DataFrame is a two-dimensional data structure that represents data as a table * with rows and columns. @@ -255,10 +290,9 @@ interface WriteMethods { * ╰─────┴─────┴─────╯ * ``` */ -export interface DataFrame = any> - extends Arithmetic>, - Sample>, - Arithmetic>, +export interface DataFrame + extends Arithmetic>, + Sample>, WriteMethods, Serialize, GroupByOps { @@ -275,7 +309,7 @@ export interface DataFrame = any> /** * Very cheap deep clone. */ - clone(): DataFrame; + clone(): DataFrame; /** * __Summary statistics for a DataFrame.__ * @@ -346,14 +380,14 @@ export interface DataFrame = any> * ╰─────┴─────╯ * ``` */ - drop(name: U): DataFrame>>; + drop(name: U): DataFrame>>; drop( names: U, - ): DataFrame>>; + ): DataFrame>>; drop( name: U, ...names: V - ): DataFrame>>; + ): DataFrame>>; /** * __Return a new DataFrame where the null values are dropped.__ * @@ -379,9 +413,9 @@ export interface DataFrame = any> * └─────┴─────┴─────┘ * ``` */ - dropNulls(column: keyof T): DataFrame; - dropNulls(columns: (keyof T)[]): DataFrame; - dropNulls(...columns: (keyof T)[]): DataFrame; + dropNulls(column: keyof S): DataFrame; + dropNulls(columns: (keyof S)[]): DataFrame; + dropNulls(...columns: (keyof S)[]): DataFrame; /** * __Explode `DataFrame` to long format by exploding a column with Lists.__ * ___ @@ -464,7 +498,7 @@ export interface DataFrame = any> * @param other DataFrame to vertically add. */ - extend(other: DataFrame): DataFrame; + extend(other: DataFrame): DataFrame; /** * Fill null/missing values by a filling strategy * @@ -478,7 +512,7 @@ export interface DataFrame = any> * - "one" * @returns DataFrame with None replaced with the filling strategy. */ - fillNull(strategy: FillNullStrategy): DataFrame; + fillNull(strategy: FillNullStrategy): DataFrame; /** * Filter the rows in the DataFrame based on a predicate expression. * ___ @@ -517,7 +551,7 @@ export interface DataFrame = any> * └─────┴─────┴─────┘ * ``` */ - filter(predicate: any): DataFrame; + filter(predicate: any): DataFrame; /** * Find the index of a column by name. * ___ @@ -533,7 +567,7 @@ export interface DataFrame = any> * 2 * ``` */ - findIdxByName(name: keyof T): number; + findIdxByName(name: keyof S): number; /** * __Apply a horizontal reduction on a DataFrame.__ * @@ -590,7 +624,13 @@ export interface DataFrame = any> * ] * ``` */ - fold(operation: (s1: Series, s2: Series) => Series): Series; + fold< + D extends DataType, + F extends ( + s1: SchemaToSeriesRecord[keyof S] | Series, + s2: SchemaToSeriesRecord[keyof S], + ) => Series, + >(operation: F): Series; /** * Check if DataFrame is equal to other. * ___ @@ -637,7 +677,7 @@ export interface DataFrame = any> * // column: pl.Series * ``` */ - getColumn(name: U): T[U]; + getColumn(name: U): SchemaToSeriesRecord[U]; getColumn(name: string): Series; /** * Get the DataFrame as an Array of Series. @@ -658,7 +698,7 @@ export interface DataFrame = any> * // columns: (pl.Series | pl.Series | pl.Series)[] * ``` */ - getColumns(): T[keyof T][]; + getColumns(): SchemaToSeriesRecord[keyof S][]; /** * Start a groupby operation. * ___ @@ -705,7 +745,7 @@ export interface DataFrame = any> * ╰─────┴─────┴─────╯ * ``` */ - head(length?: number): DataFrame; + head(length?: number): DataFrame; /** * Return a new DataFrame grown horizontally by stacking multiple Series to it. * @param columns - array of Series or DataFrame to stack @@ -745,13 +785,12 @@ export interface DataFrame = any> * ╰─────┴─────┴─────┴───────╯ * ``` */ - hstack = any>( - columns: DataFrame, - ): DataFrame>; + hstack( + columns: DataFrame, + ): DataFrame>; hstack( columns: U, - ): DataFrame>; - hstack(columns: Array | DataFrame): DataFrame; + ): DataFrame>; hstack(columns: Array | DataFrame, inPlace?: boolean): void; /** * Insert a Series at a certain column index. This operation is in place. @@ -762,7 +801,7 @@ export interface DataFrame = any> /** * Interpolate intermediate values. The interpolation method is linear. */ - interpolate(): DataFrame; + interpolate(): DataFrame; /** * Get a mask of all duplicated rows in this DataFrame. */ @@ -809,21 +848,24 @@ export interface DataFrame = any> * ╰─────┴─────┴─────┴───────╯ * ``` */ - join( - other: DataFrame, - options: { on: ValueOrArray } & Omit< + join< + S2 extends Schema, + const Opts extends { on: ValueOrArray } & Omit< JoinOptions, "leftOn" | "rightOn" >, - ): DataFrame; - join( - other: DataFrame, - options: { - leftOn: ValueOrArray; - rightOn: ValueOrArray; + >(other: DataFrame, options: Opts): DataFrame>; + join< + S2 extends Schema, + const Opts extends { + leftOn: ValueOrArray; + rightOn: ValueOrArray; } & Omit, - ): DataFrame; - join(other: DataFrame, options: { how: "cross"; suffix?: string }): DataFrame; + >(other: DataFrame, options: Opts): DataFrame>; + join( + other: DataFrame, + options: Opts, + ): DataFrame>; /** * Perform an asof join. This is similar to a left-join except that we @@ -930,12 +972,12 @@ export interface DataFrame = any> forceParallel?: boolean; }, ): DataFrame; - lazy(): LazyDataFrame; + lazy(): LazyDataFrame; /** * Get first N rows as DataFrame. * @see {@link head} */ - limit(length?: number): DataFrame; + limit(length?: number): DataFrame; map( // TODO: strong types for the mapping function func: (row: any[], i: number, arr: any[][]) => ReturnT, @@ -963,8 +1005,8 @@ export interface DataFrame = any> * ╰─────┴─────┴──────╯ * ``` */ - max(): DataFrame; - max(axis: 0): DataFrame; + max(): DataFrame; + max(axis: 0): DataFrame; max(axis: 1): Series; /** * Aggregate the columns of this DataFrame to their mean value. @@ -973,8 +1015,8 @@ export interface DataFrame = any> * @param axis - either 0 or 1 * @param nullStrategy - this argument is only used if axis == 1 */ - mean(): DataFrame; - mean(axis: 0): DataFrame; + mean(): DataFrame; + mean(axis: 0): DataFrame; mean(axis: 1): Series; mean(axis: 1, nullStrategy?: "ignore" | "propagate"): Series; /** @@ -998,7 +1040,7 @@ export interface DataFrame = any> * ╰─────┴─────┴──────╯ * ``` */ - median(): DataFrame; + median(): DataFrame; /** * Unpivot a DataFrame from wide to long format. * ___ @@ -1055,8 +1097,8 @@ export interface DataFrame = any> * ╰─────┴─────┴──────╯ * ``` */ - min(): DataFrame; - min(axis: 0): DataFrame; + min(): DataFrame; + min(axis: 0): DataFrame; min(axis: 1): Series; /** * Get number of chunks used by the ChunkedArrays of this DataFrame. @@ -1084,13 +1126,13 @@ export interface DataFrame = any> * ``` */ nullCount(): DataFrame<{ - [K in keyof T]: Series, K & string>; + [K in keyof S]: JsToDtype; }>; partitionBy( cols: string | string[], stable?: boolean, includeKey?: boolean, - ): DataFrame[]; + ): DataFrame[]; partitionBy( cols: string | string[], stable: boolean, @@ -1208,13 +1250,13 @@ export interface DataFrame = any> * ╰─────┴─────┴──────╯ * ``` */ - quantile(quantile: number): DataFrame; + quantile(quantile: number): DataFrame; /** * __Rechunk the data in this DataFrame to a contiguous allocation.__ * * This will make sure all subsequent operations have optimal and predictable performance. */ - rechunk(): DataFrame; + rechunk(): DataFrame; /** * __Rename column names.__ * ___ @@ -1246,9 +1288,9 @@ export interface DataFrame = any> * ╰───────┴─────┴─────╯ * ``` */ - rename>>( + rename>>( mapping: U, - ): DataFrame<{ [K in keyof T as U[K] extends string ? U[K] : K]: T[K] }>; + ): DataFrame<{ [K in keyof S as U[K] extends string ? U[K] : K]: S[K] }>; rename(mapping: Record): DataFrame; /** * Replace a column at an index location. @@ -1333,7 +1375,7 @@ export interface DataFrame = any> * // } * ``` */ - get schema(): { [K in keyof T]: T[K]["dtype"] }; + get schema(): S; /** * Select columns from this DataFrame. * ___ @@ -1368,8 +1410,8 @@ export interface DataFrame = any> * └─────┘ * ``` */ - select(...columns: U[]): DataFrame<{ [P in U]: T[P] }>; - select(...columns: ExprOrString[]): DataFrame; + select(...columns: U[]): DataFrame<{ [P in U]: S[P] }>; + select(...columns: ExprOrString[]): DataFrame; /** * Shift the values by a given period and fill the parts that will be empty due to this operation * with `Nones`. @@ -1410,8 +1452,8 @@ export interface DataFrame = any> * └──────┴──────┴──────┘ * ``` */ - shift(periods: number): DataFrame; - shift({ periods }: { periods: number }): DataFrame; + shift(periods: number): DataFrame; + shift({ periods }: { periods: number }): DataFrame; /** * Shift the values by a given period and fill the parts that will be empty due to this operation * with the result of the `fill_value` expression. @@ -1441,15 +1483,18 @@ export interface DataFrame = any> * └─────┴─────┴─────┘ * ``` */ - shiftAndFill(n: number, fillValue: number): DataFrame; + shiftAndFill(n: number, fillValue: number): DataFrame; shiftAndFill({ n, fillValue, - }: { n: number; fillValue: number }): DataFrame; + }: { + n: number; + fillValue: number; + }): DataFrame; /** * Shrink memory usage of this DataFrame to fit the exact capacity needed to hold the data. */ - shrinkToFit(): DataFrame; + shrinkToFit(): DataFrame; shrinkToFit(inPlace: true): void; shrinkToFit({ inPlace }: { inPlace: true }): void; /** @@ -1478,8 +1523,8 @@ export interface DataFrame = any> * └─────┴─────┴─────┘ * ``` */ - slice({ offset, length }: { offset: number; length: number }): DataFrame; - slice(offset: number, length: number): DataFrame; + slice({ offset, length }: { offset: number; length: number }): DataFrame; + slice(offset: number, length: number): DataFrame; /** * Sort the DataFrame by column. * ___ @@ -1494,7 +1539,7 @@ export interface DataFrame = any> descending?: boolean, nullsLast?: boolean, maintainOrder?: boolean, - ): DataFrame; + ): DataFrame; sort({ by, reverse, // deprecated @@ -1505,7 +1550,7 @@ export interface DataFrame = any> reverse?: boolean; // deprecated nullsLast?: boolean; maintainOrder?: boolean; - }): DataFrame; + }): DataFrame; sort({ by, descending, @@ -1515,7 +1560,7 @@ export interface DataFrame = any> descending?: boolean; nullsLast?: boolean; maintainOrder?: boolean; - }): DataFrame; + }): DataFrame; /** * Aggregate the columns of this DataFrame to their standard deviation value. * ___ @@ -1537,7 +1582,7 @@ export interface DataFrame = any> * ╰─────┴─────┴──────╯ * ``` */ - std(): DataFrame; + std(): DataFrame; /** * Aggregate the columns of this DataFrame to their mean value. * ___ @@ -1545,8 +1590,8 @@ export interface DataFrame = any> * @param axis - either 0 or 1 * @param nullStrategy - this argument is only used if axis == 1 */ - sum(): DataFrame; - sum(axis: 0): DataFrame; + sum(): DataFrame; + sum(axis: 0): DataFrame; sum(axis: 1): Series; sum(axis: 1, nullStrategy?: "ignore" | "propagate"): Series; /** @@ -1596,7 +1641,7 @@ export interface DataFrame = any> * ╰─────────┴─────╯ * ``` */ - tail(length?: number): DataFrame; + tail(length?: number): DataFrame; /** * Converts dataframe object into row oriented javascript objects * @example @@ -1610,7 +1655,7 @@ export interface DataFrame = any> * ``` * @category IO */ - toRecords(): { [K in keyof T]: DTypeToJs | null }[]; + toRecords(): { [K in keyof S]: DTypeToJs | null }[]; /** * compat with `JSON.stringify` @@ -1640,8 +1685,8 @@ export interface DataFrame = any> * ``` * @category IO */ - toObject(): { [K in keyof T]: DTypeToJs[] }; - toSeries(index?: number): T[keyof T]; + toObject(): { [K in keyof S]: DTypeToJs[] }; + toSeries(index?: number): SchemaToSeriesRecord[keyof S]; toString(): string; /** * Convert a ``DataFrame`` to a ``Series`` of type ``Struct`` @@ -1753,12 +1798,12 @@ export interface DataFrame = any> maintainOrder?: boolean, subset?: ColumnSelection, keep?: "first" | "last", - ): DataFrame; + ): DataFrame; unique(opts: { maintainOrder?: boolean; subset?: ColumnSelection; keep?: "first" | "last"; - }): DataFrame; + }): DataFrame; /** Decompose a struct into its fields. The fields will be inserted in to the `DataFrame` on the location of the `struct` type. @@ -1818,7 +1863,7 @@ export interface DataFrame = any> * ╰─────┴─────┴──────╯ * ``` */ - var(): DataFrame; + var(): DataFrame; /** * Grow this DataFrame vertically by stacking a DataFrame to it. * @param df - DataFrame to stack. @@ -1851,16 +1896,14 @@ export interface DataFrame = any> * ╰─────┴─────┴─────╯ * ``` */ - vstack(df: DataFrame): DataFrame; + vstack(df: DataFrame): DataFrame; /** * Return a new DataFrame with the column added or replaced. * @param column - Series, where the name of the Series refers to the column in the DataFrame. */ withColumn( column: Series, - ): DataFrame< - Simplify }> - >; + ): DataFrame>; withColumn(column: Series | Expr): DataFrame; withColumns(...columns: (Expr | Series)[]): DataFrame; /** @@ -1868,16 +1911,16 @@ export interface DataFrame = any> * @param existingName * @param newName */ - withColumnRenamed( + withColumnRenamed( existingName: Existing, replacement: New, - ): DataFrame<{ [K in keyof T as K extends Existing ? New : K]: T[K] }>; + ): DataFrame<{ [K in keyof S as K extends Existing ? New : K]: S[K] }>; withColumnRenamed(existing: string, replacement: string): DataFrame; - withColumnRenamed(opts: { + withColumnRenamed(opts: { existingName: Existing; replacement: New; - }): DataFrame<{ [K in keyof T as K extends Existing ? New : K]: T[K] }>; + }): DataFrame<{ [K in keyof S as K extends Existing ? New : K]: S[K] }>; withColumnRenamed(opts: { existing: string; replacement: string }): DataFrame; /** * Add a column at index 0 that counts the rows. @@ -1885,7 +1928,7 @@ export interface DataFrame = any> */ withRowCount(name?: string): DataFrame; /** @see {@link filter} */ - where(predicate: any): DataFrame; + where(predicate: any): DataFrame; /** Upsample a DataFrame at a regular frequency. @@ -1961,13 +2004,13 @@ shape: (7, 3) every: string, by?: string | string[], maintainOrder?: boolean, - ): DataFrame; + ): DataFrame; upsample(opts: { timeColumn: string; every: string; by?: string | string[]; maintainOrder?: boolean; - }): DataFrame; + }): DataFrame; } function prepareOtherArg(anyValue: any): Series { @@ -2025,11 +2068,11 @@ function mapPolarsTypeToJSONSchema(colType: DataType): string { /** * @ignore */ -export const _DataFrame = (_df: any): DataFrame => { +export const _DataFrame = (_df: any): DataFrame => { const unwrap = (method: string, ...args: any[]) => { return _df[method as any](...args); }; - const wrap = (method, ...args): DataFrame => { + const wrap = (method, ...args): DataFrame => { return _DataFrame(unwrap(method, ...args)); }; @@ -2089,7 +2132,7 @@ export const _DataFrame = (_df: any): DataFrame => { "text/html": limited.toHTML(), }; }, - get schema() { + get schema(): any { return this.getColumns().reduce((acc, curr) => { acc[curr.name] = curr.dtype; @@ -2100,7 +2143,7 @@ export const _DataFrame = (_df: any): DataFrame => { return wrap("clone"); }, describe() { - const describeCast = (df: DataFrame) => { + const describeCast = (df: DataFrame) => { return DataFrame( df.getColumns().map((s) => { if (s.isNumeric() || s.isBoolean()) { @@ -2116,7 +2159,7 @@ export const _DataFrame = (_df: any): DataFrame => { describeCast(this.min()), describeCast(this.max()), describeCast(this.median()), - ]); + ] as any); summary.insertAtIdx( 0, Series("describe", ["mean", "std", "min", "max", "median"]), @@ -2178,7 +2221,7 @@ export const _DataFrame = (_df: any): DataFrame => { findIdxByName(name) { return unwrap("findIdxByName", name); }, - fold(fn: (s1, s2) => Series) { + fold(fn: (s1, s2) => any) { if (this.width === 1) { return this.toSeries(0); } @@ -2228,7 +2271,7 @@ export const _DataFrame = (_df: any): DataFrame => { by, ); }, - upsample(opts, every?, by?, maintainOrder?) { + upsample(opts, every?, by?, maintainOrder?): any { let timeColumn; if (typeof opts === "string") { timeColumn = opts; @@ -2282,7 +2325,7 @@ export const _DataFrame = (_df: any): DataFrame => { isDuplicated: () => _Series(_df.isDuplicated()) as any, isEmpty: () => _df.height === 0, isUnique: () => _Series(_df.isUnique()) as any, - join(other: DataFrame, options): DataFrame { + join(other, options): any { options = { how: "inner", ...options }; const on = columnOrColumns(options.on); const how = options.how; @@ -2309,7 +2352,7 @@ export const _DataFrame = (_df: any): DataFrame => { .joinAsof(other.lazy(), options as any) .collectSync(); }, - lazy: () => _LazyDataFrame(_df.lazy()), + lazy: () => _LazyDataFrame(_df.lazy()) as unknown as LazyDataFrame, limit: (length = 5) => wrap("head", length), max(axis = 0) { if (axis === 1) { @@ -2401,7 +2444,7 @@ export const _DataFrame = (_df: any): DataFrame => { rechunk() { return wrap("rechunk"); }, - rename(mapping) { + rename(mapping): any { const df = this.clone(); for (const [column, new_col] of Object.entries(mapping)) { (df as any).inner().rename(column, new_col); @@ -2462,7 +2505,7 @@ export const _DataFrame = (_df: any): DataFrame => { return wrap("select", columnOrColumnsStrict(selection as any)); }, shift: (opt) => wrap("shift", opt?.periods ?? opt), - shiftAndFill(n: any, fillValue?: number | undefined) { + shiftAndFill(n: any, fillValue?: number | undefined): any { if (typeof n === "number" && fillValue) { return _DataFrame(_df).lazy().shiftAndFill(n, fillValue).collectSync(); } @@ -2576,7 +2619,7 @@ export const _DataFrame = (_df: any): DataFrame => { return { data, schema: { fields } }; }, - toObject() { + toObject(): any { return this.getColumns().reduce((acc, curr) => { acc[curr.name] = curr.toArray(); @@ -2733,7 +2776,7 @@ export const _DataFrame = (_df: any): DataFrame => { .withColumns(columns) .collectSync({ noOptimization: true }); }, - withColumnRenamed(opt, replacement?) { + withColumnRenamed(opt, replacement?): any { if (typeof opt === "string") { return this.rename({ [opt]: replacement }); } @@ -2756,10 +2799,10 @@ export const _DataFrame = (_df: any): DataFrame => { divideBy: (other) => wrap("div", prepareOtherArg(other).inner()), multiplyBy: (other) => wrap("mul", prepareOtherArg(other).inner()), modulo: (other) => wrap("rem", prepareOtherArg(other).inner()), - } as DataFrame; + } as DataFrame; return new Proxy(df, { - get(target: DataFrame, prop, receiver) { + get(target: DataFrame, prop, receiver) { if (typeof prop === "string" && target.columns.includes(prop)) { return target.getColumn(prop); } @@ -2768,7 +2811,7 @@ export const _DataFrame = (_df: any): DataFrame => { } return Reflect.get(target, prop, receiver); }, - set(target: DataFrame, prop, receiver) { + set(target: DataFrame, prop, receiver) { if (Series.isSeries(receiver)) { if (typeof prop === "string" && target.columns.includes(prop)) { const idx = target.columns.indexOf(prop); @@ -2862,21 +2905,22 @@ export interface DataFrameConstructor extends Deserialize { data: T1, options?: DataFrameOptions, ): DataFrame<{ - [K in T1[number] as K["name"]]: K; + [K in T1[number] as K["name"]]: K["dtype"]; }>; >>( data: T2, options?: DataFrameOptions, ): DataFrame<{ - [K in keyof T2]: K extends string - ? Series, K> - : never; + [K in keyof T2]: K extends string ? JsToDtype : never; }>; (data: any, options?: DataFrameOptions): DataFrame; isDataFrame(arg: any): arg is DataFrame; } -function DataFrameConstructor(data?, options?): DataFrame { +function DataFrameConstructor( + data?, + options?, +): DataFrame { if (!data) { return _DataFrame(objToDF({})); } diff --git a/polars/lazy/dataframe.ts b/polars/lazy/dataframe.ts index 4b9d3d8a..7f1d4388 100644 --- a/polars/lazy/dataframe.ts +++ b/polars/lazy/dataframe.ts @@ -1,4 +1,9 @@ -import { type DataFrame, _DataFrame } from "../dataframe"; +import { + type DataFrame, + type JoinSchemas, + type Schema, + _DataFrame, +} from "../dataframe"; import pli from "../internals/polars_internal"; import type { Series } from "../series"; import type { Deserialize, GroupByOps, Serialize } from "../shared_traits"; @@ -12,6 +17,7 @@ import { type ColumnSelection, type ColumnsOrExpr, type ExprOrString, + type Simplify, type ValueOrArray, columnOrColumnsStrict, selectionToExprList, @@ -24,7 +30,9 @@ const inspect = Symbol.for("nodejs.util.inspect.custom"); /** * Representation of a Lazy computation graph / query. */ -export interface LazyDataFrame extends Serialize, GroupByOps { +export interface LazyDataFrame + extends Serialize, + GroupByOps { /** @ignore */ _ldf: any; [inspect](): string; @@ -33,8 +41,8 @@ export interface LazyDataFrame extends Serialize, GroupByOps { /** * Cache the result once the execution of the physical plan hits this node. */ - cache(): LazyDataFrame; - clone(): LazyDataFrame; + cache(): LazyDataFrame; + clone(): LazyDataFrame; /** * * Collect into a DataFrame. @@ -57,8 +65,8 @@ export interface LazyDataFrame extends Serialize, GroupByOps { * @return DataFrame * */ - collect(opts?: LazyOptions): Promise; - collectSync(opts?: LazyOptions): DataFrame; + collect(opts?: LazyOptions): Promise>; + collectSync(opts?: LazyOptions): DataFrame; /** * A string representation of the optimized query plan. */ @@ -71,16 +79,21 @@ export interface LazyDataFrame extends Serialize, GroupByOps { * Remove one or multiple columns from a DataFrame. * @param columns - column or list of columns to be removed */ - drop(name: string): LazyDataFrame; - drop(names: string[]): LazyDataFrame; - drop(name: string, ...names: string[]): LazyDataFrame; + drop(name: U): LazyDataFrame>>; + drop( + names: U, + ): LazyDataFrame>>; + drop( + name: U, + ...names: V + ): LazyDataFrame>>; /** * Drop rows with null values from this DataFrame. * This method only drops nulls row-wise if any single value of the row is null. */ - dropNulls(column: string): LazyDataFrame; - dropNulls(columns: string[]): LazyDataFrame; - dropNulls(...columns: string[]): LazyDataFrame; + dropNulls(column: string): LazyDataFrame; + dropNulls(columns: string[]): LazyDataFrame; + dropNulls(...columns: string[]): LazyDataFrame; /** * Explode lists to long format. */ @@ -110,16 +123,16 @@ export interface LazyDataFrame extends Serialize, GroupByOps { at any point without it being considered a breaking change. * */ - fetch(numRows?: number): Promise; - fetch(numRows: number, opts: LazyOptions): Promise; + fetch(numRows?: number): Promise>; + fetch(numRows: number, opts: LazyOptions): Promise>; /** Behaves the same as fetch, but will perform the actions synchronously */ - fetchSync(numRows?: number): DataFrame; - fetchSync(numRows: number, opts: LazyOptions): DataFrame; + fetchSync(numRows?: number): DataFrame; + fetchSync(numRows: number, opts: LazyOptions): DataFrame; /** * Fill missing values * @param fillValue value to fill the missing values with */ - fillNull(fillValue: string | number | Expr): LazyDataFrame; + fillNull(fillValue: string | number | Expr): LazyDataFrame; /** * Filter the rows in the DataFrame based on a predicate expression. * @param predicate - Expression that evaluates to a boolean Series. @@ -144,11 +157,11 @@ export interface LazyDataFrame extends Serialize, GroupByOps { * └─────┴─────┴─────┘ * ``` */ - filter(predicate: Expr | string): LazyDataFrame; + filter(predicate: Expr | string): LazyDataFrame; /** * Get the first row of the DataFrame. */ - first(): DataFrame; + first(): DataFrame; /** * Start a groupby operation. */ @@ -161,7 +174,7 @@ export interface LazyDataFrame extends Serialize, GroupByOps { * Consider using the `fetch` operation. * The `fetch` operation will truly load the first `n`rows lazily. */ - head(length?: number): LazyDataFrame; + head(length?: number): LazyDataFrame; inner(): any; /** * __SQL like joins.__ @@ -200,26 +213,38 @@ export interface LazyDataFrame extends Serialize, GroupByOps { * ╰─────┴─────┴─────┴───────╯ * ``` */ - join( - other: LazyDataFrame, - joinOptions: { on: ValueOrArray } & LazyJoinOptions, - ): LazyDataFrame; - join( - other: LazyDataFrame, - joinOptions: { - leftOn: ValueOrArray; - rightOn: ValueOrArray; - } & LazyJoinOptions, - ): LazyDataFrame; - join( - other: LazyDataFrame, - options: { + join< + S2 extends Schema, + const Opts extends { on: ValueOrArray } & Omit< + LazyJoinOptions, + "leftOn" | "rightOn" + >, + >( + other: LazyDataFrame, + joinOptions: Opts, + ): LazyDataFrame>; + join< + S2 extends Schema, + const Opts extends { + leftOn: ValueOrArray; + rightOn: ValueOrArray; + } & Omit, + >( + other: LazyDataFrame, + joinOptions: Opts, + ): LazyDataFrame>; + join< + S2 extends Schema, + const Opts extends { how: "cross"; suffix?: string; allowParallel?: boolean; forceParallel?: boolean; }, - ): LazyDataFrame; + >( + other: LazyDataFrame, + options: Opts, + ): LazyDataFrame>; /** * Perform an asof join. This is similar to a left-join except that we @@ -333,23 +358,23 @@ export interface LazyDataFrame extends Serialize, GroupByOps { /** * Get the last row of the DataFrame. */ - last(): LazyDataFrame; + last(): LazyDataFrame; /** * @see {@link head} */ - limit(n?: number): LazyDataFrame; + limit(n?: number): LazyDataFrame; /** * @see {@link DataFrame.max} */ - max(): LazyDataFrame; + max(): LazyDataFrame; /** * @see {@link DataFrame.mean} */ - mean(): LazyDataFrame; + mean(): LazyDataFrame; /** * @see {@link DataFrame.median} */ - median(): LazyDataFrame; + median(): LazyDataFrame; /** * @see {@link DataFrame.unpivot} */ @@ -357,43 +382,44 @@ export interface LazyDataFrame extends Serialize, GroupByOps { /** * @see {@link DataFrame.min} */ - min(): LazyDataFrame; + min(): LazyDataFrame; /** * @see {@link DataFrame.quantile} */ - quantile(quantile: number): LazyDataFrame; + quantile(quantile: number): LazyDataFrame; /** * @see {@link DataFrame.rename} */ + rename>>( + mapping: U, + ): LazyDataFrame<{ [K in keyof S as U[K] extends string ? U[K] : K]: S[K] }>; rename(mapping: Record): LazyDataFrame; /** * Reverse the DataFrame. */ - reverse(): LazyDataFrame; + reverse(): LazyDataFrame; /** * @see {@link DataFrame.select} */ + select(...columns: U[]): LazyDataFrame<{ [P in U]: S[P] }>; select(column: ExprOrString): LazyDataFrame; select(columns: ExprOrString[]): LazyDataFrame; select(...columns: ExprOrString[]): LazyDataFrame; /** * @see {@link DataFrame.shift} */ - shift(periods: number): LazyDataFrame; - shift(opts: { periods: number }): LazyDataFrame; + shift(periods: number): LazyDataFrame; + shift(opts: { periods: number }): LazyDataFrame; /** * @see {@link DataFrame.shiftAndFill} */ - shiftAndFill(n: number, fillValue: number): LazyDataFrame; - shiftAndFill(opts: { - n: number; - fillValue: number; - }): LazyDataFrame; + shiftAndFill(n: number, fillValue: number): LazyDataFrame; + shiftAndFill(opts: { n: number; fillValue: number }): LazyDataFrame; /** * @see {@link DataFrame.slice} */ - slice(offset: number, length: number): LazyDataFrame; - slice(opts: { offset: number; length: number }): LazyDataFrame; + slice(offset: number, length: number): LazyDataFrame; + slice(opts: { offset: number; length: number }): LazyDataFrame; /** * @see {@link DataFrame.sort} */ @@ -402,26 +428,26 @@ export interface LazyDataFrame extends Serialize, GroupByOps { descending?: ValueOrArray, nullsLast?: boolean, maintainOrder?: boolean, - ): LazyDataFrame; + ): LazyDataFrame; sort(opts: { by: ColumnsOrExpr; descending?: ValueOrArray; nullsLast?: boolean; maintainOrder?: boolean; - }): LazyDataFrame; + }): LazyDataFrame; /** * @see {@link DataFrame.std} */ - std(): LazyDataFrame; + std(): LazyDataFrame; /** * Aggregate the columns in the DataFrame to their sum value. */ - sum(): LazyDataFrame; + sum(): LazyDataFrame; /** * Get the last `n` rows of the DataFrame. * @see {@link DataFrame.tail} */ - tail(length?: number): LazyDataFrame; + tail(length?: number): LazyDataFrame; /** * compatibility with `JSON.stringify` */ @@ -437,16 +463,16 @@ export interface LazyDataFrame extends Serialize, GroupByOps { maintainOrder?: boolean, subset?: ColumnSelection, keep?: "first" | "last", - ): LazyDataFrame; + ): LazyDataFrame; unique(opts: { maintainOrder?: boolean; subset?: ColumnSelection; keep?: "first" | "last"; - }): LazyDataFrame; + }): LazyDataFrame; /** * Aggregate the columns in the DataFrame to their variance value. */ - var(): LazyDataFrame; + var(): LazyDataFrame; /** * Add or overwrite column in a DataFrame. * @param expr - Expression that evaluates to column. @@ -459,6 +485,10 @@ export interface LazyDataFrame extends Serialize, GroupByOps { */ withColumns(exprs: (Expr | Series)[]): LazyDataFrame; withColumns(...exprs: (Expr | Series)[]): LazyDataFrame; + withColumnRenamed( + existing: Existing, + replacement: New, + ): LazyDataFrame<{ [K in keyof S as K extends Existing ? New : K]: S[K] }>; withColumnRenamed(existing: string, replacement: string): LazyDataFrame; /** * Add a column at index 0 that counts the rows. diff --git a/polars/series/index.ts b/polars/series/index.ts index a1eb106f..ea5daca0 100644 --- a/polars/series/index.ts +++ b/polars/series/index.ts @@ -1,6 +1,7 @@ import { DataFrame, _DataFrame } from "../dataframe"; import { DTYPE_TO_FFINAME, DataType, type Optional } from "../datatypes"; import type { + Bool, DTypeToJs, DTypeToJsLoose, DtypeToJsName, @@ -139,11 +140,17 @@ export interface Series argSort({ descending, nullsLast, - }: { descending?: boolean; nullsLast?: boolean }): Series; + }: { + descending?: boolean; + nullsLast?: boolean; + }): Series; argSort({ reverse, // deprecated nullsLast, - }: { reverse?: boolean; nullsLast?: boolean }): Series; + }: { + reverse?: boolean; + nullsLast?: boolean; + }): Series; /** * __Rename this Series.__ * @@ -405,7 +412,7 @@ export interface Series /** * Check if this Series is a Boolean. */ - isBoolean(): boolean; + isBoolean(): this is Series; /** * Check if this Series is a DataTime. */ @@ -507,7 +514,19 @@ export interface Series /** * Check if this Series datatype is numeric. */ - isNumeric(): boolean; + isNumeric(): this is Series< + | DataType.Int8 + | DataType.Int16 + | DataType.Int32 + | DataType.Int64 + | DataType.UInt8 + | DataType.UInt16 + | DataType.UInt32 + | DataType.UInt64 + | DataType.Float32 + | DataType.Float64 + | DataType.Decimal + >; /** * __Get mask of unique values.__ * ___ diff --git a/yarn.lock b/yarn.lock index cce09f53..01932fd0 100644 --- a/yarn.lock +++ b/yarn.lock @@ -3138,6 +3138,7 @@ __metadata: chance: "npm:^1.1.12" jest: "npm:^29.7.0" source-map-support: "npm:^0.5.21" + ts-expect: "npm:^1.3.0" ts-jest: "npm:^29.2.5" ts-node: "npm:^10.9.2" typedoc: "npm:^0.27.3" @@ -3767,6 +3768,13 @@ __metadata: languageName: node linkType: hard +"ts-expect@npm:^1.3.0": + version: 1.3.0 + resolution: "ts-expect@npm:1.3.0" + checksum: 10/3f33ba17f9cdad550e9b12a23d78b19836d4fd5741da590c25426ea31ceb8f1765feeeeef2c529bab3e2350b5f9ce5fbbb207012db2c7c4b4b116d1b7fed7ec5 + languageName: node + linkType: hard + "ts-jest@npm:^29.2.5": version: 29.2.5 resolution: "ts-jest@npm:29.2.5"