Add a little accelerant to your torch
About
stoke is a lightweight wrapper for PyTorch that provides a simple declarative API for context switching between devices (e.g. CPU, GPU), distributed modes, mixed-precision, and PyTorch extensions. This allows you to switch from local full-precision CPU to mixed-precision distributed multi-GPU with extensions (like optimizer state sharding) by simply changing a few declarative flags. Additionally, stoke exposes configuration settings for every underlying backend for those that want configurability and raw access to the underlying libraries.

In short, stoke is the best of PyTorch Lightning Accelerators, disconnected from the rest of PyTorch Lightning. Write whatever PyTorch code you want, but leave device and backend context switching to stoke.
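As a rough sketch of the declarative style, the setup might look like the following. The keyword flags shown here (e.g. `gpu`, `fp16`, `distributed`, `fairscale_oss`, `batch_size_per_device`) are illustrative of the pattern described above rather than an exact API reference; consult the stoke documentation for precise signatures.

```python
import torch
from stoke import DistributedOptions, FP16Options, Stoke, StokeOptimizer

# Any plain PyTorch model and loss -- stoke does not require changes to them
model = torch.nn.Linear(128, 2)
loss = torch.nn.CrossEntropyLoss()

# The optimizer is passed as a class plus kwargs so stoke can construct it
# against whichever backend(s) end up being declared
opt = StokeOptimizer(optimizer=torch.optim.SGD, optimizer_kwargs={"lr": 0.01})

# Declare the desired state: GPU, AMP mixed precision, DDP, and fairscale
# optimizer state sharding (flag names are assumptions, not a verified API)
stoke_obj = Stoke(
    model=model,
    optimizer=opt,
    loss=loss,
    batch_size_per_device=32,
    gpu=True,
    fp16=FP16Options.amp,
    distributed=DistributedOptions.ddp,
    fairscale_oss=True,
)
```

Switching back to, say, local full-precision CPU would mean flipping those same flags rather than rewriting any training code.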
Supports
- Devices: CPU, GPU, multi-GPU
- Distributed: DDP, Horovod, deepspeed (via DDP)
- Mixed-Precision: AMP, Nvidia Apex, deepspeed (custom APEX-like backend)
- Extensions: fairscale (Optimizer State Sharding, Sharded DDP, Fully Sharded DDP), deepspeed (ZeRO Stage 0-3, etc.)
Benefits/Capabilities
- Declarative style API -- allows you to declare or specify the desired state and let stoke handle the rest
- Mirrors base PyTorch style model, loss, backward, and step calls (see the sketch after this list)
- Automatic device placement of model(s) and data
- Universal interface for saving and loading regardless of backend(s) or device
- Automatic handling of gradient accumulation and clipping
- Common attrs interface for all backend configuration parameters (with docstrings)
- Helper methods for printing synced losses, device specific print, number of model parameters
- Extra(s) -- Custom torch.utils.data.distributed.Sampler: BucketedDistributedSampler, which buckets data by a sorted idx and then randomly samples from specific bucket(s) to prevent situations like grossly mismatched sequence lengths leading to wasted computational overhead (i.e. excess padding)
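A minimal sketch of how the mirrored calls fit into a training loop, continuing from the assumed `stoke_obj` above (`data_loader` stands in for any iterable of batches):

```python
# Standard PyTorch-style loop: stoke handles device placement, mixed
# precision, gradient accumulation, and clipping behind these calls
for x, y in data_loader:
    outputs = stoke_obj.model(x)             # forward through the wrapped model
    train_loss = stoke_obj.loss(outputs, y)  # loss via the wrapped loss handler
    stoke_obj.backward(train_loss)           # backward for the declared backend(s)
    stoke_obj.step()                         # optimizer step (accumulation-aware)
```

The same object also exposes the universal save/load interface and the helper print methods listed above; see the API docs for their exact signatures.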