Random Forest
A scikit-learn random forest ensemble classifier for tabular pitch data.
Overview
- Type: Tabular
- Library: scikit-learn
- Registry name: `random_forest`
- Class: `RandomForestModel`
Configuration
# configs/models/random_forest.yaml
model_type: random_forest
n_estimators: 200
max_depth: 15
random_state: 42
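| Parameter | Default | Description |
|---|---|---|
| `n_estimators` | 200 | Number of trees in the forest |
| `max_depth` | 15 | Maximum tree depth |
| `random_state` | 42 | Random seed for reproducibility |
| `min_samples_split` | 5 | Minimum number of samples required to split an internal node |
| `class_weight` | `"balanced"` | Class weighting strategy passed to the classifier |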
Usage
# Look up the model by its registry name and pass hyperparameters as a config dict;
# keys that are omitted fall back to the defaults listed under Configuration.
from pitch_sequencing import get_model
model = get_model("random_forest", {"n_estimators": 200, "max_depth": 15})
# X_train, y_train and X_test are placeholders for your own tabular features/labels;
# they are not defined in this snippet.
model.fit(X_train, y_train)
# Both calls below require fit() to have been called first.
predictions = model.predict(X_test)
probabilities = model.predict_proba(X_test)
API Reference
pitch_sequencing.models.baselines.RandomForestModel
Bases: BaseModel
Random Forest baseline for tabular pitch data.
Source code in src/pitch_sequencing/models/baselines.py
class RandomForestModel(BaseModel):
    """Random Forest baseline for tabular pitch data.

    Thin wrapper around scikit-learn's ``RandomForestClassifier``. All
    hyperparameters are read from an optional config mapping; any key that
    is missing falls back to the default listed below.

    Config keys:
        n_estimators: Number of trees in the forest (default 200).
        max_depth: Maximum tree depth (default 15).
        min_samples_split: Minimum samples required to split a node
            (default 5).
        class_weight: Class weighting passed to the classifier
            (default ``"balanced"``).
        random_state: Seed passed to the classifier (default 42).
    """

    def __init__(self, config=None):
        """Store hyperparameters from *config*; no estimator is built yet.

        Args:
            config: Mapping of hyperparameter names to values, or ``None``
                (or any falsy value) to use every default.
        """
        config = config or {}
        self.n_estimators = config.get("n_estimators", 200)
        self.max_depth = config.get("max_depth", 15)
        self.min_samples_split = config.get("min_samples_split", 5)
        self.class_weight = config.get("class_weight", "balanced")
        # Previously hard-coded to 42 inside fit(), which silently ignored a
        # configured seed. The default is unchanged, so existing configs
        # behave exactly as before.
        self.random_state = config.get("random_state", 42)
        # The underlying estimator is created lazily in fit().
        self._model = None

    @property
    def name(self) -> str:
        """Human-readable model name."""
        return "Random Forest"

    @property
    def model_type(self) -> str:
        """Model family identifier; this model is ``"tabular"``."""
        return "tabular"

    def fit(self, X_train, y_train, X_val=None, y_val=None, **kwargs):
        """Build a fresh ``RandomForestClassifier`` and fit it.

        Each call replaces any previously fitted estimator.

        Args:
            X_train: Training features, forwarded to the estimator's ``fit``.
            y_train: Training labels, forwarded to the estimator's ``fit``.
            X_val: Accepted but not used by this model.
            y_val: Accepted but not used by this model.
            **kwargs: Accepted but not used by this model.
        """
        self._model = RandomForestClassifier(
            n_estimators=self.n_estimators,
            max_depth=self.max_depth,
            min_samples_split=self.min_samples_split,
            class_weight=self.class_weight,
            random_state=self.random_state,
            n_jobs=-1,  # scikit-learn: -1 means use all available processors
        )
        self._model.fit(X_train, y_train)

    def _require_fitted(self):
        """Return the fitted estimator, or raise if ``fit`` was never called.

        Raises:
            sklearn.exceptions.NotFittedError: If no estimator exists yet.
                It subclasses both ``ValueError`` and ``AttributeError``, so
                callers that caught the old ``AttributeError`` (attribute
                access on ``None``) still catch it.
        """
        if self._model is None:
            # Imported locally so this class does not depend on an extra
            # module-level import.
            from sklearn.exceptions import NotFittedError

            raise NotFittedError(
                "RandomForestModel is not fitted yet; call fit() before "
                "predict() or predict_proba()."
            )
        return self._model

    def predict(self, X) -> np.ndarray:
        """Return the fitted estimator's ``predict`` output for *X*.

        Raises:
            sklearn.exceptions.NotFittedError: If called before ``fit``.
        """
        return self._require_fitted().predict(X)

    def predict_proba(self, X) -> np.ndarray:
        """Return the fitted estimator's ``predict_proba`` output for *X*.

        Raises:
            sklearn.exceptions.NotFittedError: If called before ``fit``.
        """
        return self._require_fitted().predict_proba(X)

    def get_params(self) -> dict:
        """Return the hyperparameters used to construct the estimator.

        Includes ``class_weight`` and ``random_state`` in addition to the
        tree-shape parameters, so the result reflects every configurable
        value passed to ``RandomForestClassifier``.
        """
        return {
            "n_estimators": self.n_estimators,
            "max_depth": self.max_depth,
            "min_samples_split": self.min_samples_split,
            "class_weight": self.class_weight,
            "random_state": self.random_state,
        }