Models Registry
Model registry and factory function.
pitch_sequencing.models
Model registry for all pitch prediction models.
MODEL_REGISTRY = {'logistic_regression': LogisticRegressionModel, 'random_forest': RandomForestModel, 'hmm': HMMModel, 'autogluon': AutoGluonModel, 'lstm': LSTMModel, 'cnn1d': CNN1DModel, 'transformer': TransformerModel}
module-attribute
get_model(name, config=None)
Instantiate a model by registry name.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | | Key in MODEL_REGISTRY (e.g. 'lstm', 'random_forest'). | required |
| config | | Optional dict of hyperparameters. | None |
Returns:
| Type | Description |
|---|---|
| | Instance of the model class. |
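
The registry-plus-factory pattern described above can be sketched as follows. This is a minimal illustration, not the package's actual implementation: `BaseModel` and the two model classes are hypothetical stand-ins, the `hidden_size` hyperparameter is made up, and two behaviors are assumptions that the documentation does not confirm: that each model class accepts the `config` dict as its only constructor argument, and that an unknown name raises `ValueError`.

```python
from typing import Any, Dict, Optional, Type


class BaseModel:
    """Hypothetical stand-in base class; not part of the documented API."""

    def __init__(self, config: Optional[Dict[str, Any]] = None) -> None:
        # Copy the dict so later changes by the caller do not leak into the model.
        self.config = dict(config or {})


class LSTMModel(BaseModel):
    """Stand-in for the real LSTMModel."""


class RandomForestModel(BaseModel):
    """Stand-in for the real RandomForestModel."""


# Trimmed-down registry with the same shape as MODEL_REGISTRY:
# string keys mapped to model classes (not instances).
MODEL_REGISTRY: Dict[str, Type[BaseModel]] = {
    "lstm": LSTMModel,
    "random_forest": RandomForestModel,
}


def get_model(name: str, config: Optional[Dict[str, Any]] = None) -> BaseModel:
    """Look up `name` in the registry and instantiate the matching class."""
    try:
        model_cls = MODEL_REGISTRY[name]
    except KeyError:
        # Assumption: the real function may report unknown names differently.
        available = ", ".join(sorted(MODEL_REGISTRY))
        raise ValueError(f"Unknown model {name!r}. Available: {available}") from None
    return model_cls(config)


# Usage: pick a model by name and pass optional hyperparameters.
model = get_model("lstm", {"hidden_size": 64})
print(type(model).__name__, model.config)
```

Keeping classes rather than instances in the registry means each call returns a fresh model, and a new model type only needs one extra dictionary entry to become available by name.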