UM2N.loader package
Submodules
UM2N.loader.cluster_utils module
- get_neighbors(source_mask, edge_idx)
Get the neighbors of the source nodes.
- Args:
source_mask: A mask of the source nodes.
edge_idx: The edge index.
- Returns:
nei_mask: A mask of the neighbors.
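A minimal sketch of what get_neighbors computes, assuming edge_idx is a PyTorch Geometric style [2, num_edges] tensor of (source, target) pairs and source_mask is a boolean tensor with one entry per node. The name get_neighbors_sketch is hypothetical; this is an illustration, not the UM2N source.

```python
import torch


def get_neighbors_sketch(source_mask: torch.Tensor, edge_idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: mark every node reached by an edge leaving a source node.

    Assumes edge_idx is a [2, num_edges] COO tensor and source_mask is a bool
    tensor of shape [num_nodes].
    """
    src, dst = edge_idx
    nei_mask = torch.zeros_like(source_mask)  # all False, same shape and dtype
    # Select edges whose source endpoint is masked, then mark their targets.
    nei_mask[dst[source_mask[src]]] = True
    return nei_mask


# Path graph 0-1-2-3, each undirected edge stored in both directions.
edge_idx = torch.tensor([[0, 1, 1, 2, 2, 3],
                         [1, 0, 2, 1, 3, 2]])
source_mask = torch.tensor([False, True, False, False])
print(get_neighbors_sketch(source_mask, edge_idx).tolist())  # [True, False, True, False]
```

In this sketch source nodes are not excluded from the result, so a source node adjacent to another source node is also marked; the real function may handle that case differently.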
- calc_dist(coords, node_idx, neighbors_mask)
Calculate the distance between the node and its neighbors.
- Args:
coords: The coordinates of the nodes.
node_idx: The index of the node.
neighbors_mask: A mask of the neighbors.
- Returns:
dist: The distance between the node and each of its neighbors.
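A minimal sketch of calc_dist, assuming coords is a [num_nodes, dim] float tensor, neighbors_mask is a boolean tensor over nodes, and the distance is Euclidean (the docstring does not name the metric). calc_dist_sketch is a hypothetical name.

```python
import torch


def calc_dist_sketch(coords: torch.Tensor, node_idx: int, neighbors_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: Euclidean distance from coords[node_idx] to each masked node."""
    diff = coords[neighbors_mask] - coords[node_idx]  # broadcast [k, dim] - [dim]
    return torch.linalg.norm(diff, dim=-1)  # one distance per neighbor, shape [k]


coords = torch.tensor([[0.0, 0.0], [3.0, 4.0], [1.0, 0.0]])
neighbors_mask = torch.tensor([False, True, True])
print(calc_dist_sketch(coords, 0, neighbors_mask).tolist())  # [5.0, 1.0]
```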
- sampler(num_nodes, coords, edge_idx, node_idx, r=0.25, N=100)
For a single node, sample N neighbors within radius r.
- Returns:
The indices of the sampled neighbors.
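One way to combine the two helpers above into a sampler, sketched under several assumptions that this page does not confirm: neighbors are discovered by walking edge_idx outward from node_idx one hop at a time, a node qualifies if its Euclidean distance to node_idx is at most r, and the returned indices are drawn uniformly without replacement (fewer than N when fewer qualify). The name sampler_sketch and its generator parameter are hypothetical.

```python
import torch


def sampler_sketch(num_nodes, coords, edge_idx, node_idx, r=0.25, N=100, generator=None):
    """Hypothetical sketch: sample up to N graph neighbors of node_idx within radius r.

    Grows the neighborhood hop by hop along edge_idx ([2, num_edges] COO),
    keeping only nodes whose Euclidean distance to node_idx is <= r.
    """
    src, dst = edge_idx
    dist = torch.linalg.norm(coords - coords[node_idx], dim=-1)  # [num_nodes]
    within_r = dist <= r
    visited = torch.zeros(num_nodes, dtype=torch.bool)
    visited[node_idx] = True
    frontier = visited.clone()
    while frontier.any():
        nxt = torch.zeros(num_nodes, dtype=torch.bool)
        nxt[dst[frontier[src]]] = True  # one-hop expansion of the frontier
        nxt &= within_r & ~visited      # keep only new nodes inside the radius
        visited |= nxt
        frontier = nxt
    candidates = torch.nonzero(visited, as_tuple=True)[0]
    candidates = candidates[candidates != node_idx]  # exclude the node itself
    # Uniform sample without replacement; returns fewer than N if fewer qualify.
    perm = torch.randperm(candidates.numel(), generator=generator)[:N]
    return candidates[perm]


# Five nodes on a line, 0.1 apart, connected as a path graph.
coords = torch.tensor([[0.0, 0.0], [0.1, 0.0], [0.2, 0.0], [0.3, 0.0], [0.4, 0.0]])
edge_idx = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4],
                         [1, 0, 2, 1, 3, 2, 4, 3]])
print(sorted(sampler_sketch(5, coords, edge_idx, node_idx=0).tolist()))  # [1, 2]
```

Walking the graph rather than searching all nodes by distance means a node inside the radius is only returned if it is reachable through other nodes that are also inside the radius; a pure distance search would drop the edge_idx argument entirely.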
UM2N.loader.data_transform module
UM2N.loader.dataset module
- class AggreateDataset(datasets)
Bases: Dataset
Aggregate multiple datasets into a single dataset.
- Attributes:
datasets (list): List of datasets.
datasets_len (list): Length of each dataset in datasets.
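The usual way to aggregate datasets is to map a global index onto (dataset, local index) using running totals of datasets_len. The class below is a hypothetical sketch of that idea (AggregateDatasetSketch and _cum_len are invented names); it is not the UM2N implementation.

```python
import bisect
from itertools import accumulate

from torch.utils.data import Dataset


class AggregateDatasetSketch(Dataset):
    """Hypothetical sketch of aggregating datasets; not the UM2N source."""

    def __init__(self, datasets):
        self.datasets = list(datasets)
        self.datasets_len = [len(d) for d in self.datasets]
        # Running totals, e.g. lengths [3, 2] -> [3, 5].
        self._cum_len = list(accumulate(self.datasets_len))

    def __len__(self):
        return self._cum_len[-1] if self._cum_len else 0

    def __getitem__(self, idx):
        if idx < 0:
            idx += len(self)
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        # First dataset whose running total exceeds idx.
        which = bisect.bisect_right(self._cum_len, idx)
        offset = self._cum_len[which - 1] if which > 0 else 0
        return self.datasets[which][idx - offset]


agg = AggregateDatasetSketch([["a0", "a1", "a2"], ["b0", "b1"]])
print(len(agg), agg[2], agg[3], agg[-1])  # 5 a2 b0 b1
```

PyTorch's built-in torch.utils.data.ConcatDataset performs the same lookup (it keeps the running totals in cumulative_sizes), so it is a ready-made alternative when no extra behavior is needed.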
- class MeshDataset(file_dir, transform=None, target_transform=None, x_feature=['coord', 'bd_mask', 'bd_left_mask', 'bd_right_mask', 'bd_down_mask', 'bd_up_mask'], mesh_feature=['coord', 'u'], conv_feature=['conv_uh'], conv_feature_fix=['conv_uh_fix'], load_analytical=False, load_jacobian=False, use_cluster=False, use_run_time_cluster=False, r=0.35, M=25, dist_weight=False, add_nei=True)
Bases: Dataset
Dataset for mesh-based data.
- Attributes:
x_feature (list): List of feature names for node features.
mesh_feature (list): List of feature names for mesh features.
conv_feature (list): List of feature names for convolution features.
file_names (list): List of filenames containing mesh data.
- get_x_feature(data)
Extracts and concatenates the x_features for each node from the data.
- Args:
data (dict): The data dictionary loaded from a .npy file.
- Returns:
tensor: The concatenated x_features for each node.
- get_mesh_feature(data)
Extracts and concatenates the mesh_features from the data.
- Args:
data (dict): The data dictionary loaded from a .npy file.
- Returns:
tensor: The concatenated mesh_features.
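get_x_feature and get_mesh_feature both concatenate named arrays from the loaded dictionary. A minimal sketch of that pattern follows, assuming every data[name] has the number of nodes as its first dimension and that 1-D arrays (such as a boundary mask) become a single column. concat_features_sketch is a hypothetical helper, and the dictionary values are made up; the keys are taken from the default x_feature and mesh_feature lists above. A dict saved with np.save would presumably be read back with np.load(path, allow_pickle=True).item().

```python
import numpy as np
import torch


def concat_features_sketch(data: dict, feature_names: list) -> torch.Tensor:
    """Hypothetical sketch: stack the named per-node arrays column-wise.

    Assumes every data[name] has num_nodes as its first dimension;
    1-D arrays become a single column.
    """
    cols = []
    for name in feature_names:
        arr = np.asarray(data[name], dtype=np.float32)
        cols.append(torch.from_numpy(arr.reshape(arr.shape[0], -1)))
    return torch.cat(cols, dim=-1)  # shape [num_nodes, total feature width]


# Stand-in for a dict loaded from a .npy file.
data = {
    "coord": np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]),
    "bd_mask": np.array([1, 0, 1]),
    "u": np.array([0.5, 0.25, 0.125]),
}
x = concat_features_sketch(data, ["coord", "bd_mask"])  # node features
mesh = concat_features_sketch(data, ["coord", "u"])     # mesh features
print(tuple(x.shape), tuple(mesh.shape))  # (3, 3) (3, 3)
```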
- class MeshData(x: Tensor | None = None, edge_index: Tensor | None = None, edge_attr: Tensor | None = None, y: Tensor | int | float | None = None, pos: Tensor | None = None, time: Tensor | None = None, **kwargs)
Bases: Data
Custom PyTorch Geometric Data object designed to handle mesh data features.
This class is intended to be used as the base class of the data samples returned by MeshDataset.