Torch Utilities
Shared PyTorch training loop, dataset class, and prediction utilities.
pitch_sequencing.models.torch_utils
Shared PyTorch training utilities for sequence models.
PitchSequenceDataset
Bases: Dataset
Wraps numpy arrays as a PyTorch Dataset.
Source code in src/pitch_sequencing/models/torch_utils.py
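Here is a minimal sketch of how a dataset like this can wrap numpy arrays. It assumes float features `X` and integer class labels `y` of equal length. The class name `ArraySequenceDataset` and the float32/int64 conversions are illustrative assumptions and do not come from the actual `PitchSequenceDataset` source.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class ArraySequenceDataset(Dataset):
    """Hypothetical stand-in for PitchSequenceDataset: wraps (X, y) numpy arrays."""

    def __init__(self, X: np.ndarray, y: np.ndarray) -> None:
        if len(X) != len(y):
            raise ValueError("X and y must have the same number of samples")
        # Convert once up front (assumed dtypes) so __getitem__ only indexes tensors.
        self.X = torch.as_tensor(X, dtype=torch.float32)
        self.y = torch.as_tensor(y, dtype=torch.long)

    def __len__(self) -> int:
        return len(self.X)

    def __getitem__(self, idx: int):
        return self.X[idx], self.y[idx]


# Usage: 8 sequences of length 5 with 3 features each, labels in {0, 1, 2}.
X = np.random.rand(8, 5, 3)
y = np.random.randint(0, 3, size=8)
loader = DataLoader(ArraySequenceDataset(X, y), batch_size=4, shuffle=True)
xb, yb = next(iter(loader))
print(xb.shape, yb.shape)  # torch.Size([4, 5, 3]) torch.Size([4])
```

Converting to tensors once in `__init__` reduces `__getitem__` to a cheap index lookup. The default `DataLoader` collate function then stacks the per-sample tensors into batches.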
predict_torch_model(model, data_loader, device=None)
Run inference and return (predictions, probabilities).
Returns:
| Type | Description |
|---|---|
| `ndarray` | `predictions`, with shape `(n,)`. |
| `ndarray` | `probabilities`, with shape `(n, num_classes)`. |
Source code in src/pitch_sequencing/models/torch_utils.py
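The following short sketch illustrates the documented `(predictions, probabilities)` return contract. It assumes two things: each batch is an `(inputs, labels)` pair, and the model outputs raw logits. The function name `predict_sketch` is hypothetical. Deriving the outputs with softmax followed by argmax is an assumption about how `predict_torch_model` works, not a confirmed detail.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def predict_sketch(model: nn.Module, data_loader: DataLoader, device=None):
    """Hypothetical re-creation of the documented return contract.

    Assumes each batch is an (inputs, labels) pair and that the model
    returns raw logits of shape (batch, num_classes).
    """
    if device is None:
        # Assumed default: use a GPU when one is available.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()  # disable dropout, use running batch-norm statistics
    all_probs = []
    with torch.no_grad():  # no autograd graph is needed for inference
        for inputs, _ in data_loader:
            logits = model(inputs.to(device))
            all_probs.append(torch.softmax(logits, dim=1).cpu())
    probabilities = torch.cat(all_probs).numpy()  # shape (n, num_classes)
    predictions = probabilities.argmax(axis=1)    # shape (n,)
    return predictions, probabilities


# Usage with a toy linear classifier: 10 samples, 4 features, 3 classes.
model = nn.Linear(4, 3)
dataset = TensorDataset(torch.randn(10, 4), torch.zeros(10, dtype=torch.long))
preds, probs = predict_sketch(model, DataLoader(dataset, batch_size=4))
print(preds.shape, probs.shape)  # (10,) (10, 3)
```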
train_torch_model(model, train_loader, val_loader, epochs=20, lr=0.001, patience=5, device=None)
Train a PyTorch model with Adam, LR scheduling, early stopping, and gradient clipping.
Returns:
| Type | Description |
|---|---|
| `Dict` | Dictionary with the keys `train_losses`, `val_losses`, and `val_accuracies`. |