Module: runner

A general-purpose runner for TF-GNN.

Classes

class ContextLabelFn: Reads out a tfgnn.Field from the GraphTensor context.

class DatasetProvider: Abstract base class for objects that provide a tf.data.Dataset.

class DotProductLinkPrediction: Implements edge score as dot product of features of endpoint nodes.

class FitOrSkipPadding: Calculates fit or skip SizeConstraints for GraphTensor padding.

class GraphBinaryClassification: Graph binary (or multi-label) classification from pooled node states.

class GraphMeanAbsoluteError: Regression from pooled node states with mean absolute error.

class GraphMeanAbsolutePercentageError: Regression from pooled node states with mean absolute percentage error.

class GraphMeanSquaredError: Regression from pooled node states with mean squared error.

class GraphMeanSquaredLogScaledError: Regression from pooled node states with mean squared log scaled error.

class GraphMeanSquaredLogarithmicError: Regression from pooled node states with mean squared logarithmic error.

class GraphMulticlassClassification: Graph multiclass classification from pooled node states.

class GraphTensorPadding: Collects GraphTensor padding helpers.

class GraphTensorProcessorFn: A class for GraphTensor processing.

class HadamardProductLinkPrediction: Implements edge score as Hadamard product of features of endpoint nodes.

class IntegratedGradientsExporter: Exports a Keras model with an additional integrated gradients signature.

class KerasModelExporter: Exports a Keras model (with Keras API) via tf.keras.models.save_model.

class KerasTrainer: Trains using the tf.keras.Model.fit training loop.

class KerasTrainerCheckpointOptions: Provides Keras Checkpointing related configuration options.

class KerasTrainerOptions: Provides Keras training related options.

class ModelExporter: Saves a Keras model.

class NodeBinaryClassification: Node binary (or multi-label) classification via structured readout.

class NodeMeanAbsoluteError: Node regression with mean absolute error via structured readout.

class NodeMeanAbsolutePercentageError: Node regression with mean absolute percentage error via structured readout.

class NodeMeanSquaredError: Node regression with mean squared error via structured readout.

class NodeMeanSquaredLogScaledError: Node regression with mean squared log scaled error via structured readout.

class NodeMeanSquaredLogarithmicError: Node regression with mean squared logarithmic error via structured readout.

class NodeMulticlassClassification: Node multiclass classification via structured readout.

class ParameterServerStrategy: A ParameterServerStrategy convenience wrapper.

class PassthruDatasetProvider: Builds a tf.data.Dataset from a pass-through dataset.

class PassthruSampleDatasetsProvider: Builds a sampled tf.data.Dataset from multiple pass-through datasets.

class RootNodeBinaryClassification: Root node binary (or multi-label) classification.

class RootNodeLabelFn: Reads out a tfgnn.Field from the GraphTensor root (i.e. first) node.

class RootNodeMeanAbsoluteError: Root node regression with mean absolute error.

class RootNodeMeanAbsolutePercentageError: Root node regression with mean absolute percentage error.

class RootNodeMeanSquaredError: Root node regression with mean squared error.

class RootNodeMeanSquaredLogScaledError: Root node regression with mean squared log scaled error.

class RootNodeMeanSquaredLogarithmicError: Root node regression with mean squared logarithmic error.

class RootNodeMulticlassClassification: Root node multiclass classification.

class RunResult: Holds the return values of run(...).

class SampleTFRecordDatasetsProvider: Builds a sampling tf.data.Dataset from multiple filenames.

class SimpleDatasetProvider: Builds a tf.data.Dataset from a list of files.

class SimpleSampleDatasetsProvider: Builds a sampling tf.data.Dataset from multiple filenames.

class SubmoduleExporter: Exports a Keras submodule.

class TFDataServiceConfig: Provides tf.data service related configuration options.

class TFRecordDatasetProvider: Builds a tf.data.Dataset from a list of files.

class TPUStrategy: A TPUStrategy convenience wrapper.

class Task: Defines a learning objective for a GNN.

class TightPadding: Calculates tight SizeConstraints for GraphTensor padding.

class Trainer: A class for training and validation of a Keras model.
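The DotProductLinkPrediction and HadamardProductLinkPrediction entries in the class list above both score an edge from the features of its two endpoint nodes. The following is a minimal sketch of the underlying arithmetic in plain Python, not the TF-GNN implementation; the function names `hadamard_product` and `dot_product_score` are hypothetical and chosen only for illustration.

```python
from typing import List, Sequence


def hadamard_product(source: Sequence[float], target: Sequence[float]) -> List[float]:
    """Element-wise product of two endpoint feature vectors (hypothetical helper)."""
    if len(source) != len(target):
        raise ValueError("endpoint feature vectors must have the same length")
    return [s * t for s, t in zip(source, target)]


def dot_product_score(source: Sequence[float], target: Sequence[float]) -> float:
    """Scalar edge score: the sum of the element-wise products (hypothetical helper)."""
    return sum(hadamard_product(source, target))


source_features = [1.0, 2.0, 3.0]
target_features = [4.0, 5.0, 6.0]

elementwise = hadamard_product(source_features, target_features)
score = dot_product_score(source_features, target_features)
```

The dot product yields a scalar directly, whereas the Hadamard product yields a vector of the same length as the inputs, so a further reduction is needed before it can serve as a single edge score. How HadamardProductLinkPrediction performs that reduction is not stated here; consult the class documentation.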

Functions

export_model(...): Exports a Keras model without traces so that it is loadable without TF-GNN.

incrementing_model_dir(...): Creates an incrementing model directory under the given dirname.

integrated_gradients(...): Integrated gradients.

one_node_per_component(...): Returns a Mapping node_set_name: 1 for every node set in gtspec.

run(...): Runs training (and validation) of a model on task(s) with the given data.
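The integrated_gradients(...) entry in the function list above names a general attribution technique. As a rough illustration of that technique only, and not of the TF-GNN function or its signature, here is a plain Python sketch that approximates the path integral with a midpoint Riemann sum. The names `integrated_gradients_sketch`, `grad_fn`, `baseline` and `steps` are assumptions introduced for this example.

```python
from typing import Callable, List, Sequence


def integrated_gradients_sketch(
    grad_fn: Callable[[Sequence[float]], Sequence[float]],
    x: Sequence[float],
    baseline: Sequence[float],
    steps: int = 50,
) -> List[float]:
    """Approximates integrated gradients along the straight line baseline -> x.

    `grad_fn` is a hypothetical callable returning the gradient of the model
    output with respect to its input at a given point.
    """
    totals = [0.0] * len(x)
    for k in range(steps):
        # Midpoint of the k-th sub-interval of [0, 1].
        alpha = (k + 0.5) / steps
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        grads = grad_fn(point)
        totals = [t + g for t, g in zip(totals, grads)]
    # Average gradient along the path, scaled by the input difference.
    return [(xi - b) * t / steps for xi, b, t in zip(x, baseline, totals)]


def toy_model(point: Sequence[float]) -> float:
    """Toy model f(x) = sum of squares, used only for this example."""
    return sum(p * p for p in point)


def toy_grad(point: Sequence[float]) -> List[float]:
    """Analytic gradient of the toy model: 2 * x."""
    return [2.0 * p for p in point]


inputs = [1.0, 2.0, 3.0]
zeros = [0.0, 0.0, 0.0]
attributions = integrated_gradients_sketch(toy_grad, inputs, zeros)
```

A useful sanity check is the completeness property: the attributions should sum approximately to the difference between the model output at the input and at the baseline, here `toy_model(inputs) - toy_model(zeros)`.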

Type Aliases

Loss

Losses

Metric

Metrics

Predictions