IBM and Microsoft have released libraries and conda packages that modify PyTorch to avoid out-of-memory issues when scaling up; some of these changes will hopefully land in future versions of PyTorch (https://github.com/pytorch/pytorch/issues/35633).
Have you explored using any of these alternatives?
https://github.com/IBM/pytorch-large-model-support
https://github.com/microsoft/DeepSpeed
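
For context on how DeepSpeed addresses OOM: it is typically driven by a JSON config that enables ZeRO optimizer-state partitioning and CPU offload. A minimal sketch of such a config might look like the following (the specific values here are illustrative assumptions, not a recommendation):

```json
{
  "train_batch_size": 32,
  "fp16": {
    "enabled": true
  },
  "zero_optimization": {
    "stage": 2,
    "offload_optimizer": {
      "device": "cpu"
    }
  }
}
```

With a config like this, ZeRO stage 2 partitions optimizer states and gradients across data-parallel workers, and `offload_optimizer` moves optimizer state to host memory, both of which reduce per-GPU memory pressure without changing the model code itself.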