class CNN1DModel(BaseModel):
    """1D-CNN wrapper implementing the BaseModel interface.

    Wraps a PitchCNN1D network behind fit/predict/predict_proba so it can be
    used interchangeably with the project's other models.

    Config keys (all optional):
        filters:        per-conv-layer channel counts (default [64, 128, 64])
        kernel_size:    1D convolution kernel width (default 3)
        dropout:        dropout probability (default 0.3)
        epochs:         training epochs (default 30)
        learning_rate:  optimizer learning rate (default 0.001)
        batch_size:     DataLoader batch size (default 256)
    """

    def __init__(self, config: dict | None = None):
        config = config or {}
        self.filters = config.get("filters", [64, 128, 64])
        self.kernel_size = config.get("kernel_size", 3)
        self.dropout = config.get("dropout", 0.3)
        self.epochs = config.get("epochs", 30)
        self.lr = config.get("learning_rate", 0.001)
        self.batch_size = config.get("batch_size", 256)
        self._model = None  # set by fit()
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self._history = None  # training history dict returned by train_torch_model

    @property
    def name(self) -> str:
        return "1D-CNN"

    @property
    def model_type(self) -> str:
        return "sequence"

    def fit(self, X_train, y_train, X_val=None, y_val=None, **kwargs):
        """Train the network; returns self for chaining.

        X_train is expected to be (n_samples, seq_len, n_features) — the
        feature count is read from axis 2. If no validation data is given,
        the last 20% of the training data is HELD OUT for validation (it is
        excluded from the training loader, avoiding train/val leakage).
        """
        input_features = X_train.shape[2]
        # NOTE(review): assumes labels are contiguous 0..K-1; if a class is
        # absent from y_train the output width will be too small — confirm
        # against the upstream label encoder.
        num_classes = len(np.unique(y_train))
        self._model = PitchCNN1D(
            input_features=input_features,
            num_classes=num_classes,
            filters=self.filters,
            kernel_size=self.kernel_size,
            dropout=self.dropout,
        )
        if X_val is not None and y_val is not None:
            train_ds = PitchSequenceDataset(X_train, y_train)
            val_ds = PitchSequenceDataset(X_val, y_val)
        else:
            # Hold out the tail 20% for validation and train only on the
            # head 80% — the previous behavior validated on samples that
            # were also in the training loader (leakage).
            split = int(len(X_train) * 0.8)
            train_ds = PitchSequenceDataset(X_train[:split], y_train[:split])
            val_ds = PitchSequenceDataset(X_train[split:], y_train[split:])
        train_loader = DataLoader(train_ds, batch_size=self.batch_size, shuffle=True)
        val_loader = DataLoader(val_ds, batch_size=self.batch_size, shuffle=False)
        self._history = train_torch_model(
            self._model, train_loader, val_loader,
            epochs=self.epochs, lr=self.lr, device=self._device,
        )
        return self

    def _inference_loader(self, X) -> DataLoader:
        """Build a non-shuffling DataLoader over X with dummy labels."""
        if self._model is None:
            raise RuntimeError("CNN1DModel is not fitted; call fit() first.")
        # Labels are unused at inference time; zeros satisfy the dataset API.
        ds = PitchSequenceDataset(X, np.zeros(len(X), dtype=np.int64))
        return DataLoader(ds, batch_size=self.batch_size, shuffle=False)

    def predict(self, X) -> np.ndarray:
        """Return predicted class labels for X.

        Raises:
            RuntimeError: if called before fit().
        """
        preds, _ = predict_torch_model(self._model, self._inference_loader(X), self._device)
        return preds

    def predict_proba(self, X) -> np.ndarray:
        """Return class-probability estimates for X.

        Raises:
            RuntimeError: if called before fit().
        """
        _, probs = predict_torch_model(self._model, self._inference_loader(X), self._device)
        return probs

    def get_params(self) -> dict:
        """Return the hyperparameters this wrapper was configured with."""
        return {
            "filters": self.filters,
            "kernel_size": self.kernel_size,
            "dropout": self.dropout,
            "epochs": self.epochs,
            "learning_rate": self.lr,
        }