|            |                                                |
| ---------- | ---------------------------------------------- |
| Authors    | Hameer Abbasi, Edward Z. Yang and Ralf Gommers |
| Status     | Accepted                                       |
| Type       | Proposal                                       |
| Created    | 2020-01-24                                     |
| Resolution | TBD                                            |

# Improving subclassing Tensor by propagating subclass instances

This RFC describes the changes necessary to allow `__torch_function__` to be used by methods of `torch.Tensor`, in an attempt to make subclassing more accessible to users of the class. This entails making an API for subclass views public, and a change in the signature of `__torch_function__`.

## Motivation and Scope

Quoting [1], [2] and [3], the goals of this proposal are:

1. Support subclassing `torch.Tensor` in Python.
2. Preserve `torch.Tensor` subclasses when calling `torch` functions on them.
3. Use the PyTorch API with `torch.Tensor`-like objects that are not `torch.Tensor` subclasses.
4. Preserve `torch.Tensor` subclasses when calling `torch.Tensor` methods.
5. Propagate subclass instances correctly with operators as well, including when using views/slices/indexing/etc.
6. Preserve subclass attributes when using methods or views/slices/indexing.
7. Provide a way to insert code that operates on both functions and methods uniformly (so we can write a single function that overrides all operators).
8. Give external libraries a way to also define functions/methods that follow the `__torch_function__` protocol.

Goals 1‒6 are explicitly about subclassing, goal 7 is already partially achieved via the `__torch_function__` protocol (which we're proposing to extend to methods), and goal 8 is a by-product required to make overridden `torch.Tensor` subclass methods behave similarly to `torch.Tensor` methods.

Achieving interoperability with NumPy and adopting its array protocols is out of scope here; we defer it to a later proposal.

We propose to solve this problem with the following changes to PyTorch:

1. Make methods, operators and properties of `torch.Tensor` go through the `__torch_function__` machinery.
2. Add a `types` argument to `__torch_function__`, to make it match NumPy's `__array_function__`.
3. Add a new method to `torch.Tensor`, `as_subclass`, which creates a subtype view into the original object.
4. Give `torch.Tensor` a generic implementation of `__torch_function__`.

## Usage and Impact

Once this proposal is merged, users of subclasses of `torch.Tensor` will have a much more streamlined experience. Namely, the following code example will work as-is, without the need for any further modification:

```python
class SubTensor(torch.Tensor):
    a = 1

t = SubTensor([1])
s = t.sum()
isinstance(s, SubTensor)  # True
s.a  # 1
i = t[0]
isinstance(i, SubTensor)  # True
i.a  # 1

s2 = t + torch.Tensor(1)
isinstance(s2, SubTensor)  # True
s2.a  # 1

s3 = torch.Tensor(1) + t
isinstance(s3, SubTensor)  # True
s3.a  # 1
```

Additionally, it will give subclass authors the ability to intercept methods, operators and properties in `__torch_function__`, along with regular function calls, and to adapt the result to their specific use case, perform logging, or otherwise change the result or the action of the method. For example:

```python
import logging

import torch

class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        logging.info(f"func: {func.__name__}, args: {args!r}, kwargs: {kwargs!r}")
        return super().__torch_function__(func, types, args, kwargs)
```

Assuming the minimum logging level is set to logging.INFO, the following shows the code being run, with the logging output in the comments.

```python
t = LoggingTensor([1])

t.sum()  # Tensor.sum, (LoggingTensor([1]),), {}
t[0]     # Tensor.__getitem__, (LoggingTensor([1]), 0), {}

# This is already possible
torch.sum(t)  # sum, (LoggingTensor([1]),), {}
```

To make the protocol operate only on functions rather than methods, one can check that `func` does not appear in `klass.__dict__.values()` for any `klass` in the class's MRO (since `__torch_function__` is a classmethod, `cls.__mro__` is available for this). To check for operators and/or indexing, one can check `func.__name__.endswith("__")`.
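
For instance, a minimal sketch of a subclass whose hook reacts only to free functions; the class name and the exact membership check are illustrative, not part of the proposal:

```python
import torch

class FunctionOnlyTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        # Methods and operators live in the class dict of some class
        # in the MRO; free functions such as torch.sum do not.
        is_method = any(func in klass.__dict__.values()
                        for klass in cls.__mro__)
        if not is_method:
            print(f"function call: {func.__name__}")
        return super().__torch_function__(func, types, args, kwargs)
```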

## Performance

There are a few requirements on the performance of this proposal, once implemented:

1. No deterioration of function/method calls on `torch.Tensor` objects.
2. No deterioration of the current `__torch_function__` overhead.
3. Sub-µs impact on the performance of subclasses not implementing `__torch_function__`.

Requirement 1 seems unachievable given the structure of the code at this point, because:

1. In methods defined in C++, `self` is excluded from the argument processing that gathers Tensor-likes in C++.
2. Similarly, C++ methods that take only `self` as a Tensor-like don't pass through this processing, and they will be required to.
3. For methods defined in Python, the processing for handling `__torch_function__` will need to be added, similar to the original `__torch_function__` PR [5].

We think an overhead of sub-100 ns per method call is feasible.

## Backwards Compatibility

### With PyTorch master as of writing

PyTorch master pointed to commit hash 957a07ffbd13d8a805f4d718e0282efc5d2bff85 at the time of writing. Any classes implementing `__torch_function__` based on the usage in this commit will break completely, due to the differing signature of the protocol. However, as no release has yet been made with `__torch_function__` in it, this is a minor-impact issue. The change brings the design of `__torch_function__` more in line with NumPy's `__array_function__`, and anyone familiar with NumPy's protocol could transition to PyTorch's take on it without too many surprises, with the caveat that it can also receive methods rather than functions. The first release to include `__torch_function__` is expected to be 1.5.0.

### With NumPy

The implementation of this proposal will have no effect on how things interact with NumPy.

## Detailed Description

### Introduction

Subclasses are an important way to override the functionality of classes. Given the popularity of PyTorch, a number of subclasses have sprung up, both within and outside PyTorch. It is important that functions operating on `torch.Tensor`, as well as its methods, support passing the appropriate subclasses through; otherwise, information about which type was passed into the function is lost. The same applies equally, if not more so, to operators and indexing.

In addition, there has been interest in adding a "universal hook" that operates on both functions and methods, perhaps modifying the control flow before returning the result. Such a hook already exists today in the form of `__torch_function__`; however, it only operates on functions and not on methods, and support for subclassed `torch.Tensor` objects in this protocol is limited.

### Proposal

We propose the following signature change to `__torch_function__`, making it match NumPy's `__array_function__` apart from the `@classmethod` decorator [4]:

```python
class SubTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        # Implementation here
```

Adding `types` to the signature is necessary so we can check that the types of the incoming Tensor-likes are supported, and so that we do not mix unrelated class trees.

### Process followed during a function/method call

The process followed during a function/method call would be equivalent to the following (a Python sketch is given at the end of this subsection):

1. The dispatcher is called to extract the Tensor-likes.
2. All Tensor-likes are checked for `__torch_function__`. If none exists, the internal implementation is called, and the final result is returned.
3. A collection of `types` that implement `__torch_function__` is created, with no guaranteed order other than that subclasses come before superclasses.
4. For one instance of each type in `types`, `__torch_function__` is called. The first such call to return something other than `NotImplemented` provides the final result. All exceptions are propagated upward.
5. If all `__torch_function__` implementations return `NotImplemented`, a `TypeError` is raised with an appropriate error message.

In practice, for most PyTorch functions the list of Tensor-likes is already available and the dispatcher doesn't need to be called. Additionally, while equivalent to the process above, if the Tensor-likes are all `Tensor` or don't have a `__torch_function__` implementation, the internal implementation is called immediately. This is done as a performance optimisation, to avoid overhead for concrete `Tensor` objects.

It will be the job of the dispatcher to extract Tensor-like objects from the argument list; arguments of type `Optional[Tensor]` will also be considered Tensor-like. For a compound or dependent type such as `List[Tensor]`, `Tuple[Tensor, ...]` or `Tuple[Tensor, int]`, the dispatcher will have the job of extracting an iterable of objects that could be Tensor-like.
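
As an illustration, here is a minimal Python sketch of the process above. The helper name `dispatch` and the way overloaded types are collected are assumptions made for this sketch; the actual implementation lives in C++ and the generated Python bindings, and `func._implementation` is the internal-implementation notation used later in this document.

```python
import torch

def dispatch(func, relevant_args, args, kwargs):
    # Steps 2-3: collect the types that implement __torch_function__,
    # keeping subclasses before superclasses.
    overloaded = []
    for arg in relevant_args:
        t = type(arg)
        if t is torch.Tensor or not hasattr(t, "__torch_function__"):
            continue
        if t in overloaded:
            continue
        index = len(overloaded)
        for i, other in enumerate(overloaded):
            if issubclass(t, other):
                index = i
                break
        overloaded.insert(index, t)

    if not overloaded:
        # Step 2: no overrides, call the internal implementation.
        return func._implementation(*args, **kwargs)

    types = tuple(overloaded)
    for t in types:
        # Step 4: the first result that isn't NotImplemented wins.
        result = t.__torch_function__(func, types, args, kwargs)
        if result is not NotImplemented:
            return result

    # Step 5: nothing handled the call.
    raise TypeError(f"no implementation found for {func.__name__}")
```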

### Generic implementation of `__torch_function__`

`torch.Tensor` will gain a generic `__torch_function__` of the following form:

```python
class Tensor:
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        if not all(issubclass(cls, t) for t in types):
            return NotImplemented

        # Defer to internal implementation
        ret = func._implementation(*args, **kwargs)
        if cls is not Tensor and isinstance(ret, Tensor):
            ret = ret.as_subclass(cls)
        return ret
```

This method has the effect of passing subclasses through all functions/methods as intended.

This corresponds roughly to the implementation `numpy.ndarray` gains in [4], except that there subclasses are passed through via another internal mechanism (namely the `__array_finalize__` protocol), and that we check subclassing against `cls` instead of `Tensor`. This has the side effect of ensuring that unrelated class trees are not merged, avoiding an inconsistency present in NumPy's own design. Specifically, consider the example of two unrelated direct subclasses of `torch.Tensor`: both will return `NotImplemented`, and therefore the check will fail and a `TypeError` will be raised.

Since subclasses are checked before superclasses in `__torch_function__`, it is guaranteed that the subclass implementation will be called first. In that case, since `cls` is a subclass of all the types in `types`, the code continues. And since `cls` is not `torch.Tensor`, a view into the original data is created and returned.
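
A minimal illustration of this behaviour, assuming the generic implementation above (the class names `A` and `B` are hypothetical):

```python
class A(torch.Tensor):
    pass

class B(torch.Tensor):
    pass

a = A([1.0])
b = B([1.0])
# types == (A, B): neither class is a subclass of the other, so the
# generic __torch_function__ of each returns NotImplemented, and a
# TypeError is raised.
a + b
```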

This also works for all operators (`__add__`, `__getitem__` and so on), since in Python these operators are just dunder methods of the corresponding class.

### Checking for compatibility

One can check for compatibility with supported classes in the following manner:

```python
class MyTensor:
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        if not all(issubclass(t, HANDLED_CLASSES) for t in types):
            return NotImplemented
        # Do further processing here.

# MyTensor itself can only be referenced after the class body has run.
HANDLED_CLASSES = (MyTensor, torch.Tensor)  # plus any other supported classes
```

### Implementing a subset of the API

One can implement a subset of the API by using a mapping from torch functions to one's own implementations:

```python
_TORCH_IMPLEMENTATIONS = {}

def implements(torch_function):
    def inner(f):
        _TORCH_IMPLEMENTATIONS[torch_function] = f
        return f
    return inner

@implements(torch.add)
def my_add(self, other):
    # Implementation here
    ...

class MyTensor:
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        compatible = ...
        if not compatible:
            return NotImplemented

        if func not in _TORCH_IMPLEMENTATIONS:
            return NotImplemented

        return _TORCH_IMPLEMENTATIONS[func](*args, **kwargs)
```
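
With such a class, only the registered functions are supported. For example, assuming the `compatible` check passes for these arguments:

```python
m = MyTensor()
torch.add(m, m)  # dispatches to my_add via __torch_function__
torch.mul(m, m)  # raises TypeError: __torch_function__ returns NotImplemented
```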

### The need for `super().__torch_function__`

To call the superclass implementation, one would do the following:

```python
class SubTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        # Pre-processing here
        val = super().__torch_function__(func, types, args, kwargs)
        # Post-processing here
        return val
```
        # Post processing here

To make the need for `super()` concrete, let's consider the following scenario:

```python
class SubTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        # Pre-processing
        ret = super().__torch_function__(func, types, args, kwargs)
        # Post-processing
        return ret

class SubSubTensor(SubTensor):
    def __add__(self, other):
        # Pre-processing
        ret = super().__add__(other)
        # Post-processing
        return ret
```

In this instance, processing would follow the `__torch_function__` protocol. This means that control would end up in `SubSubTensor.__add__`, go from there to `Tensor.__add__`, then to `SubTensor.__torch_function__`, and then to `Tensor.__torch_function__`, from where it would go to the internal implementation of `Tensor.__add__`, and then back up the stack in the reverse order. This means that great care needs to be taken when writing `SubTensor.__torch_function__`, to take into account the fact that it has to handle subclass methods.

In general, control flow will follow this pattern:

*[Control flow diagram]*

The reason we use `super().__torch_function__` instead of calling `func` directly is twofold:

1. We do not know if there are other Tensor-likes that may need to be handled.
2. Calling `func` directly would dispatch back to `__torch_function__`, leading to infinite recursion.

### Protocol support for external libraries

We will also recommend that all `Tensor` subclasses make their own methods, i.e. those that do not exist on `torch.Tensor`, go through `__torch_function__` via a decorator, `@torch_function_dispatch`. This decorator was added and then removed for performance reasons; however, it will be added back to allow external libraries to interface with the protocol. It will take a single argument: a dispatcher, i.e. a callable that returns an iterable of all the "duck-Tensors", the possible candidates for classes that may implement `__torch_function__`.
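
As a hedged sketch, a library-defined method might look as follows; the dispatcher and method names here are illustrative assumptions, and the decorator's final API may differ:

```python
def _my_method_dispatcher(self, other):
    # Return the arguments that may implement __torch_function__.
    return (self, other)

class LibraryTensor(torch.Tensor):
    @torch_function_dispatch(_my_method_dispatcher)
    def my_method(self, other):
        # Calls to my_method now pass through __torch_function__
        # before this body runs.
        ...
```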

If a library forgets to add the aforementioned decorator, then the method will not dispatch to any form of `__torch_function__` at all. In other words, it will not support the protocol. This can lead to confusion, as some methods of the subclass will pass through `__torch_function__` (the ones inherited from `torch.Tensor`) and some won't.

Note that subclasses will still be passed through due to the default implementation of `__torch_function__`, but any `__torch_function__` defined on the class itself (or any of its subclasses) won't have an effect on its methods.

This is a design choice the subclass author will have to make: whether they prefer their own functions/methods to pass through `__torch_function__` like PyTorch's implementations, or whether they would ultimately rather not support the protocol and accept having a mix of overridable and non-overridable methods.

We do not propose automatically marking functions with this decorator, due to the potential backwards-compatibility break it could cause, and because of the parameters needed to make it work (namely the dispatcher, which isn't under our control).

### Getting the method from its `__name__` and `__module__`

To construct the function given its `__name__` and `__module__`, one can do the following, for example:

```python
import importlib

def get_function(name, module):
    # importlib handles dotted module paths, unlike a bare __import__.
    func = importlib.import_module(module)
    for n in name.split('.'):
        func = getattr(func, n)
    return func
```
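
For example, assuming the attributes exist under the given module path (for methods, one would pass the qualified name):

```python
add = get_function("add", "torch")          # torch.add
tsum = get_function("Tensor.sum", "torch")  # torch.Tensor.sum
```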

### Adding the `torch.Tensor.as_subclass` method

The `torch.Tensor.as_subclass` method will be added, taking a single non-`self` argument, `cls`: the class for which an instance will be created with a view into the data of the original `Tensor`. It will become public API. This method will create an object that has the same data pointer as the original object, which means that modifications to it will be reflected in the original object. More or less, it has the same effect as modifying an object's `__class__` attribute in Python.

This method is already used in external libraries, and they may need it as a way to e.g. bypass the processing of `torch.Tensor.__torch_function__` entirely, while still creating `torch.Tensor` subclasses in their own code.
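
For illustration, reusing the `SubTensor` class from the earlier examples:

```python
t = torch.tensor([1.0, 2.0])
s = t.as_subclass(SubTensor)
isinstance(s, SubTensor)  # True
s[0] = 5.0                # shares storage: t[0] is now 5.0 as well
```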

## Implementation

Implementing this proposal requires three main steps:

1. Add a `types` argument to `__torch_function__` and make sure that only arguments that are instances of a type in `types` are processed.
2. Make sure that all `Tensor` methods except `__new__` and `__init__` go through `__torch_function__`.
3. Add `Tensor.as_subclass` and `@torch_function_dispatch` as public API.

### Implementing only some methods but not others

One can use the dictionary idiom to implement only some methods but not others. A code example follows:

```python
HANDLED_FUNCTIONS = {}

def implements(func):
    def inner(implementation):
        HANDLED_FUNCTIONS[func] = implementation
        return implementation
    return inner

@implements(torch.add)
def my_add(self, other):
    ...

class TensorLike:
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        implementation = HANDLED_FUNCTIONS.get(func, None)
        if implementation is None:
            return NotImplemented

        return implementation(*args, **kwargs)
```

For subclasses, one can also choose to fall back to the default implementation when a specialised implementation isn't available, using `super`, as shown below.

```python
class SubTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        implementation = HANDLED_FUNCTIONS.get(func, None)
        if implementation is None:
            return super().__torch_function__(func, types, args, kwargs)

        return implementation(*args, **kwargs)
```

A call to `super().__torch_function__` can also be used to call the fallback implementation from within any other function.

The examples we have seen here illustrate what we anticipate will be two common patterns of using `__torch_function__`: `LoggingTensor` is an example of a global hook, and the two examples above show a way to provide specialised implementations of particular functions.

### Wrapping `torch.Tensor`

Sometimes it is useful to wrap `torch.Tensor` rather than subclass it. The following class shows how this is possible in practice:

```python
import functools

import torch

def wrap(f):
    @functools.wraps(f)
    def inner(*args, **kwargs):
        # Call `f` with all arguments unwrapped
        args = tuple(a._wrapped if isinstance(a, WrappedTensor) else a
                     for a in args)
        kwargs = {k: v._wrapped if isinstance(v, WrappedTensor) else v
                  for k, v in kwargs.items()}
        ret = f(*args, **kwargs)
        # Wrap the result back before returning
        return WrappedTensor(ret) if isinstance(ret, torch.Tensor) else ret
    return inner

class WrappedTensor:
    def __init__(self, towrap: torch.Tensor):
        self._wrapped = towrap

    def __getattr__(self, name):
        base = getattr(torch.Tensor, name)
        if not callable(base):
            # Properties: evaluate the descriptor on the wrapped tensor
            return wrap(base.__get__)(self)

        # Methods: bind `self` so the result behaves like a bound method
        return functools.partial(wrap(base), self)

    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        return wrap(func)(*args, **kwargs)
```
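
Assuming the dispatch machinery described in this proposal, usage would look roughly like this:

```python
w = WrappedTensor(torch.tensor([1.0, 2.0]))
s = torch.sum(w)  # goes through WrappedTensor.__torch_function__
isinstance(s, WrappedTensor)  # True
```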

## Proposed alternatives

One alternative that has been proposed is to automatically pass through subclasses à la NumPy and provide a `__torch_finalize__` method that allows for post-processing of the result. While this would achieve most goals, it would miss the one about providing a hook for methods and operators.

## Appendix: Special handling for `torch.Tensor` properties/methods

Both functions and methods/properties on `torch.Tensor` will be possible arguments to `__torch_function__`. These are different in subtle but important ways, and in some cases they must be handled differently. For instance, `torch.Tensor` methods/properties have the following characteristics:

1. They can only accept `torch.Tensor` instances as the first argument.
2. They may or may not have a `__module__` defined.

Even classes implementing `__torch_function__` that aren't `torch.Tensor` subclasses can have methods passed in, so this case must be handled with care. Consider the following code:

```python
class TensorLike:
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        print(func.__name__)

torch.tensor([5]) + TensorLike()  # prints "add"
```

If, in this case, we were to use the default implementation of `func` when no `torch.Tensor` instance is passed in, an error would be raised. To handle this case, we have provided a utility function, `torch.overrides.is_tensor_method_or_property`, to determine whether something is a `torch.Tensor` method/property.

For properties, their `__get__` method is passed in. For example, for `torch.Tensor.grad`, `torch.Tensor.grad.__get__` is passed in as `func`.
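
A minimal sketch of how a Tensor-like might use this utility to guard against such calls; the fallback behaviour shown is an illustrative assumption:

```python
from torch.overrides import is_tensor_method_or_property

class GuardedTensorLike:
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        if is_tensor_method_or_property(func):
            # `func` is e.g. Tensor.add or Tensor.grad.__get__: it can
            # only accept real torch.Tensor instances as its first
            # argument, so handle (or reject) it explicitly here.
            return NotImplemented
        # Handle free functions such as torch.add here.
        ...
```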