A general purpose runner for TF-GNN.
class ContextLabelFn
: Reads out a tfgnn.Field
from the GraphTensor
context.
class DatasetProvider
: Abstract base class (ABC)
for dataset providers.
class DotProductLinkPrediction
:
Implements edge score as dot product of features of endpoint nodes.
class FitOrSkipPadding
: Calculates fit or skip
SizeConstraints
for GraphTensor
padding.
class GraphBinaryClassification
:
Graph binary (or multi-label) classification from pooled node states.
class GraphMeanAbsoluteError
: Regression
from pooled node states with mean absolute error.
class GraphMeanAbsolutePercentageError
:
Regression from pooled node states with mean absolute percentage error.
class GraphMeanSquaredError
: Regression
from pooled node states with mean squared error.
class GraphMeanSquaredLogScaledError
:
Regression from pooled node states with mean squared log scaled error.
class GraphMeanSquaredLogarithmicError
:
Regression from pooled node states with mean squared logarithmic error.
class GraphMulticlassClassification
:
Graph multiclass classification from pooled node states.
class GraphTensorPadding
: Collects
GraphTensor
padding helpers.
class GraphTensorProcessorFn
: A class
for GraphTensor
processing.
class HadamardProductLinkPrediction
:
Implements edge score as Hadamard product of features of endpoint nodes.
class IntegratedGradientsExporter
:
Exports a Keras model with an additional integrated gradients signature.
class KerasModelExporter
: Exports a Keras model (with Keras API) via tf.keras.models.save_model.
class KerasTrainer
: Trains using the
tf.keras.Model.fit
training loop.
class KerasTrainerCheckpointOptions
:
Provides Keras checkpointing-related configuration options.
class KerasTrainerOptions
: Provides Keras
training-related options.
class ModelExporter
: Saves a Keras model.
class NodeBinaryClassification
: Node
binary (or multi-label) classification via structured readout.
class NodeMeanAbsoluteError
: Node
regression with mean absolute error via structured readout.
class NodeMeanAbsolutePercentageError
:
Node regression with mean absolute percentage error via structured readout.
class NodeMeanSquaredError
: Node
regression with mean squared error via structured readout.
class NodeMeanSquaredLogScaledError
:
Node regression with mean squared log scaled error via structured readout.
class NodeMeanSquaredLogarithmicError
:
Node regression with mean squared logarithmic error via structured readout.
class NodeMulticlassClassification
:
Node multiclass classification via structured readout.
class ParameterServerStrategy
: A
ParameterServerStrategy
convenience wrapper.
class PassthruDatasetProvider
: Builds a
tf.data.Dataset
from a pass thru dataset.
class PassthruSampleDatasetsProvider
:
Builds a sampled tf.data.Dataset
from multiple pass thru datasets.
class RootNodeBinaryClassification
:
Root node binary (or multi-label) classification.
class RootNodeLabelFn
: Reads out a
tfgnn.Field
from the GraphTensor
root (i.e. first) node.
class RootNodeMeanAbsoluteError
: Root
node regression with mean absolute error.
class RootNodeMeanAbsolutePercentageError
:
Root node regression with mean absolute percentage error.
class RootNodeMeanSquaredError
: Root
node regression with mean squared error.
class RootNodeMeanSquaredLogScaledError
:
Root node regression with mean squared log scaled error.
class RootNodeMeanSquaredLogarithmicError
:
Root node regression with mean squared logarithmic error.
class RootNodeMulticlassClassification
:
Root node multiclass classification.
class RunResult
: Holds the return values of run(...).
class SampleTFRecordDatasetsProvider
:
Builds a sampling tf.data.Dataset
from multiple filenames.
class SimpleDatasetProvider
: Builds a
tf.data.Dataset
from a list of files.
class SimpleSampleDatasetsProvider
:
Builds a sampling tf.data.Dataset
from multiple filenames.
class SubmoduleExporter
: Exports a Keras
submodule.
class TFDataServiceConfig
: Provides tf.data
service-related configuration options.
class TFRecordDatasetProvider
: Builds a
tf.data.Dataset
from a list of files.
class TPUStrategy
: A TPUStrategy
convenience
wrapper.
class Task
: Defines a learning objective for a GNN.
class TightPadding
: Calculates tight
SizeConstraints
for GraphTensor
padding.
class Trainer
: A class for training and validation of a
Keras model.
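
The two link prediction entries above name their edge scores after standard vector operations. The following is a minimal sketch of those operations in plain Python, not the TF-GNN implementation: the function names and the feature values are hypothetical, and how the library turns the Hadamard product vector into a final score is not shown here.

```python
from typing import List, Sequence


def dot_product_score(source: Sequence[float], target: Sequence[float]) -> float:
    """Scalar edge score: sum of elementwise products of endpoint features."""
    if len(source) != len(target):
        raise ValueError("endpoint feature vectors must have the same length")
    return sum(s * t for s, t in zip(source, target))


def hadamard_product(source: Sequence[float], target: Sequence[float]) -> List[float]:
    """Elementwise (Hadamard) product of endpoint features; returns a vector."""
    if len(source) != len(target):
        raise ValueError("endpoint feature vectors must have the same length")
    return [s * t for s, t in zip(source, target)]


# Hypothetical hidden states of the two endpoint nodes of one edge.
source_features = [1.0, 2.0, 3.0]
target_features = [4.0, 5.0, 6.0]

edge_score = dot_product_score(source_features, target_features)
edge_vector = hadamard_product(source_features, target_features)
```

The dot product is the sum of the Hadamard product's entries, so the Hadamard variant keeps per-dimension information that a later layer can weight, while the dot product fixes all weights to one.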
export_model(...)
: Exports a Keras model without
traces so that it is loadable without TF-GNN.
incrementing_model_dir(...)
: Creates, given some dirname, an incrementing model directory.
integrated_gradients(...)
: Integrated
gradients.
one_node_per_component(...)
: Returns a Mapping node_set_name: 1 for every node set in gtspec.
run(...)
: Runs training (and validation) of a model on
task(s) with the given data.
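
The integrated_gradients(...) entry above refers to the integrated gradients attribution technique. As a rough illustration of the general method only (not of the runner's function, whose signature is not given here), the sketch below approximates the path integral with a midpoint Riemann sum in plain Python. The name integrated_gradients_sketch, the toy model, and its hand-written gradient are all assumptions made for the example.

```python
from typing import Callable, List, Sequence


def integrated_gradients_sketch(
    grad_fn: Callable[[Sequence[float]], Sequence[float]],
    inputs: Sequence[float],
    baseline: Sequence[float],
    steps: int = 50,
) -> List[float]:
    """Approximates (x - x') * integral of grad F along the line from x' to x."""
    deltas = [x - b for x, b in zip(inputs, baseline)]
    totals = [0.0] * len(inputs)
    for k in range(steps):
        # Midpoint of the k-th sub-interval of [0, 1].
        alpha = (k + 0.5) / steps
        point = [b + alpha * d for b, d in zip(baseline, deltas)]
        grads = grad_fn(point)
        totals = [t + g for t, g in zip(totals, grads)]
    # Average gradient along the path, scaled by the input difference.
    return [d * t / steps for d, t in zip(deltas, totals)]


def model(x: Sequence[float]) -> float:
    """Toy differentiable model, chosen only for illustration."""
    return x[0] ** 2 + 3.0 * x[1]


def model_grad(x: Sequence[float]) -> List[float]:
    """Analytic gradient of the toy model."""
    return [2.0 * x[0], 3.0]


inputs = [2.0, 1.0]
baseline = [0.0, 0.0]
attributions = integrated_gradients_sketch(model_grad, inputs, baseline)
```

A useful sanity check for any implementation is completeness: the attributions should sum to the difference between the model output at the input and at the baseline, up to the approximation error of the sum.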