Failure to trace due to functionalization error#

Observed Error#

RuntimeError: false INTERNAL ASSERT FAILED at "aten/src/ATen/RegisterFunctionalization_1.cpp":11608, please report a bug to PyTorch. mutating a non-functional tensor with a functional tensor is not allowed. Please ensure that all of your inputs are wrapped inside of a functionalize() call.

Explanation#

In-place ops are not allowed in the compute graph. To work around this, any in-place ops that are encountered are “functionalized”, i.e., automatically converted into their out-of-place counterparts. However, for functionalization to work, all tensors must reside on the same torch device, specifically the torch device belonging to the backend object constructed by cerebras.pytorch.backend. Otherwise, the functionalization algorithm encounters a mix of tensors, some of which are functionalized and some of which are not. This is not supported behaviour, and the above error is raised.
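
To illustrate, the following minimal sketch shows the kind of code that can trigger this error. The tensor names, shapes, and the in-place op are hypothetical and chosen only for illustration:

@cstorch.trace
def training_step(inputs, targets):
    ...
    # Hypothetical example: created without a device argument, this tensor
    # lands on the CPU rather than on the backend's torch device, so it is
    # never functionalized
    mask = torch.zeros(inputs.shape[0])
    # The in-place op mutates the non-functional CPU tensor with a
    # functional (traced) tensor, raising the INTERNAL ASSERT error above
    mask.add_(inputs.sum(dim=-1))
    ...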

Workaround#

Ensure that all tensors are on the same device. To do this, define or move tensors to the torch device belonging to the Cerebras backend, e.g.

import torch
import cerebras.pytorch as cstorch

backend = cstorch.backend("CSX")
...
@cstorch.trace
def training_step(inputs, targets):
    ...
    # Construct the tensor directly on the backend's torch device
    new_tensor_1 = torch.tensor([1, 2, 3], device=backend.torch_device)
    # Or move an existing tensor onto the backend's torch device
    new_tensor_2 = torch.tensor([1, 2, 3]).to(backend.torch_device)
    ...

Alternatively, if you know you have tensors that are already on the torch device belonging to the backend, then you can use the device attached to those tensors to move other tensors onto the same device.

@cstorch.trace
def training_step(inputs, targets):
    ...
    new_tensor_1 = torch.tensor([1, 2, 3], device=inputs.device)
    new_tensor_2 = torch.tensor([1, 2, 3]).to(inputs.device)
    ...