# -*- coding: utf-8 -*-
"""Functions for handling PyTorch Geometric Datasets"""
from pathlib import Path
from typing import Any
import numpy as np
from torch_geometric.data import InMemoryDataset # type: ignore
from farmnet.data.datasets.kelmarsh import KelmarshDataset
# from farmgnn.configuration import settings
[docs]
def dataset_sample(
    dataset: InMemoryDataset, sample_size: int, *, seed: int | None = None
) -> InMemoryDataset:
    """
    Randomly samples a subset from the given dataset without replacement.

    :param dataset: The input dataset to sample from, an instance of InMemoryDataset.
    :param sample_size: The number of samples to select.
    :param seed: Optional seed for a private random generator. When ``None``
        (the default) the global NumPy RNG is used, as in earlier versions;
        when given, the draw is reproducible and the global NumPy RNG state
        is left untouched.
    :return: A new InMemoryDataset containing the randomly selected samples,
        obtained from ``dataset.copy`` on the sampled indices.
    :raises IndexError: If the sample size is greater than the dataset length.
    :raises ValueError: If the sample size is negative.

    .. rubric:: Example

    .. code-block:: python

        >>> import numpy as np
        >>> from torch_geometric.data import InMemoryDataset
        >>> class DummyDataset(InMemoryDataset):
        ...     def __init__(self, length):
        ...         super().__init__()
        ...         self.data_list = [i for i in range(length)]
        ...     def __len__(self):
        ...         return len(self.data_list)
        ...     def copy(self, idx):
        ...         new_ds = DummyDataset(0)
        ...         new_ds.data_list = [self.data_list[i] for i in idx]
        ...         return new_ds
        >>> np.random.seed(42)
        >>> dataset = DummyDataset(10)
        >>> sampled_dataset = dataset_sample(dataset, 5)
        >>> len(sampled_dataset)
        5
        >>> all(item in dataset.data_list for item in sampled_dataset.data_list)
        True
        >>> first = dataset_sample(dataset, 3, seed=0)
        >>> second = dataset_sample(dataset, 3, seed=0)
        >>> first.data_list == second.data_list
        True

    .. warning:: TODO replace DummyDataset with Kelmarsh
    """
    len_dataset = len(dataset)
    if sample_size < 0:
        raise ValueError(f"The sample size {sample_size} must be non-negative")
    if sample_size > len_dataset:
        raise IndexError(
            f"The sample size {sample_size} is greater than the dataset length {len_dataset}"
        )
    # int64 explicitly: PyG's ``index_select`` (used by ``InMemoryDataset.copy``)
    # only accepts int64/bool NumPy arrays, while NumPy's default integer dtype
    # is int32 on some platforms.
    idx = np.arange(len_dataset, dtype=np.int64)
    # Unseeded calls keep drawing from the global NumPy RNG (original behaviour).
    # Seeded calls use a private RandomState so the caller's global RNG is not reset.
    rng = np.random if seed is None else np.random.RandomState(seed)
    sample_idx = rng.choice(idx, size=sample_size, replace=False)
    return dataset.copy(sample_idx)
[docs]
def load_dataset(path: str | Path) -> InMemoryDataset:
    """
    Build a ``KelmarshDataset`` rooted at ``path`` with a fixed configuration.

    The dataset is constructed with:

    * features ``u_g``, ``v_g`` and ``nacelle_direction``;
    * target ``wind_speed``;
    * turbine-id column ``wt_id``;
    * ``data_path`` and ``windfarm_static_path`` both set to ``None``.

    :param path: The dataset root directory, as a string or Path object.
    :return: The constructed KelmarshDataset (an InMemoryDataset).

    .. rubric:: Example

    .. code-block:: python

        >>> from pathlib import Path
        >>> dataset = load_dataset(Path("examples"))  # doctest: +SKIP
        >>> isinstance(dataset, InMemoryDataset)  # doctest: +SKIP
        True
    """
    return KelmarshDataset(
        path,
        target="wind_speed",
        features=["u_g", "v_g", "nacelle_direction"],
        wt_col="wt_id",
        data_path=None,
        windfarm_static_path=None,
    )
    # Previous settings-driven loader, kept commented out for reference:
    # data_path = Path(settings.dataset.data_path).expanduser().absolute()
    # dataset_dir = Path(settings.dataset.root_dir).expanduser().absolute()
    #
    # config = {"graph": settings.dataset.graph, "windfarm": settings.windfarm}
    # if settings.dataset.name == "WinJiDataset":
    #     return WinJiDataset(dataset_dir, data_path, config=config)
    # elif settings.dataset.name == "PyWakeDataset":
    #     return PyWakeDataset(dataset_dir, data_path, config=config)
    # else:
    #     raise ValueError(f"Dataset {settings.dataset.name} does not exist!")
[docs]
def train_test_split(
    dataset: InMemoryDataset, test_size: float = 0.2, seed: int = 0
) -> tuple[Any, Any]:
    """
    Splits a dataset into training and testing subsets.

    The training indices are drawn without replacement from a private
    ``numpy.random.RandomState(seed)``. This yields the same split as seeding
    the global NumPy RNG with ``seed``, but does not modify the caller's global
    RNG state. The test subset is every remaining index, in ascending order.

    :param dataset: The input dataset to split, an instance of InMemoryDataset.
    :param test_size: The proportion of the dataset to include in the test split,
        within [0, 1] (default is 0.2).
    :param seed: Random seed for reproducibility (default is 0).
    :return: A tuple containing two datasets (train_dataset, test_dataset),
        obtained from ``dataset.index_select``.
    :raises ValueError: If ``test_size`` is not within [0, 1].

    .. rubric:: Example

    .. code-block:: python

        >>> import numpy as np
        >>> from torch_geometric.data import InMemoryDataset
        >>> class DummyDataset(InMemoryDataset):
        ...     def __init__(self, length):
        ...         super().__init__()
        ...         self.data_list = [i for i in range(length)]
        ...     def __len__(self):
        ...         return len(self.data_list)
        ...     def index_select(self, idx):
        ...         new_ds = DummyDataset(0)
        ...         new_ds.data_list = [self.data_list[i] for i in idx]
        ...         return new_ds
        >>> dataset = DummyDataset(10)
        >>> train_ds, test_ds = train_test_split(dataset, test_size=0.3, seed=42)
        >>> len(train_ds)
        7
        >>> len(test_ds)
        3

    .. warning:: TODO replace DummyDataset with Kelmarsh
    """
    # The negated chained comparison also rejects NaN.
    if not 0.0 <= test_size <= 1.0:
        raise ValueError(f"test_size must be within [0, 1], got {test_size}")
    len_dataset = len(dataset)
    # int64 explicitly: PyG's ``index_select`` only accepts int64/bool NumPy
    # arrays, while NumPy's default integer dtype is int32 on some platforms.
    idx = np.arange(len_dataset, dtype=np.int64)
    train_len = int(np.round(len_dataset * (1.0 - test_size)))
    # A private RandomState seeded with ``seed`` draws the same indices as
    # ``np.random.seed(seed); np.random.choice(...)`` did, without resetting
    # the global RNG that other code may rely on.
    rng = np.random.RandomState(seed)
    train_idx = rng.choice(idx, size=train_len, replace=False)
    # ``idx`` is sorted and unique, so the complement comes back in ascending
    # order.
    test_idx = np.setdiff1d(idx, train_idx, assume_unique=True)
    return dataset.index_select(train_idx), dataset.index_select(test_idx)