DataLoaders

An introduction to the functions in /TrainingMethod/DataLoaders.py

DglGraphLoader(
    data: Dict[Union[Literal['data', 'labels'], Literal['data', 'names']], Any],
    batch_size: int,
    device: str | torch.device = 'cpu',
    shuffle: bool = True,
    is_train: bool = True,
    data_names: Optional[Sequence[str]] = None,
)

An in-memory data loader that batches DGL graphs.

Args:
    data:
        when is_train == True:
            Dict: {
                'data': List[dgl.DGLGraph],
                'labels': Dict{'energy': Sequence[float], 'forces': Sequence[np.ndarray of shape (n_atom, 3)]}
            }, where 'forces' is optional.
        otherwise, see `data_names`.
    batch_size: the batch size.
    device: the device on which the data is placed.
    shuffle: whether to shuffle the data.
    is_train: if `is_train` is True, `data` must contain labels; otherwise `data` need not contain labels, and the labels returned [i.e., by next(iter(dataloader))] depend on `data_names`.
    data_names: only used when `is_train` is False.
            If `data_names` is not None, it should be a Sequence of data names in the same order as the data,
            and the returned `labels` [i.e., from next(iter(dataloader))] are the data names instead of "energy" or "forces";
            otherwise `labels` is None.

Yields:
        (dgl.DGLGraph, {'energy': energy, 'forces': forces})
     or (dgl.DGLGraph, {'energy': energy})
     or (dgl.DGLGraph, data_names | None) [when is_train == False]
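
The yield contract described above can be illustrated with a small pure-Python stand-in. This is only a sketch under stated assumptions, not the code in DataLoaders.py: `iter_batches` and its `seed` argument are hypothetical names, placeholder strings take the place of DGL graphs, and each batch is returned as a plain list rather than a batched `dgl.DGLGraph`. It also assumes that `data_names` is sliced per batch in the same order as the samples, which the docstring does not state explicitly.

```python
import random
from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple


def iter_batches(
    data: Dict[str, Any],
    batch_size: int,
    shuffle: bool = True,
    is_train: bool = True,
    data_names: Optional[Sequence[str]] = None,
    seed: Optional[int] = None,  # hypothetical, only for reproducible shuffling
) -> Iterator[Tuple[List[Any], Any]]:
    """Hypothetical stand-in that mimics the documented (batch, labels) contract."""
    samples = data["data"]
    order = list(range(len(samples)))
    if shuffle:
        random.Random(seed).shuffle(order)

    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        batch = [samples[i] for i in idx]
        if is_train:
            # Training mode: every label sequence is sliced in sample order.
            labels = {key: [values[i] for i in idx]
                      for key, values in data["labels"].items()}
        elif data_names is not None:
            # Prediction mode with names: names replace energy/forces labels.
            labels = [data_names[i] for i in idx]
        else:
            # Prediction mode without names: no labels at all.
            labels = None
        yield batch, labels


# Placeholder strings stand in for graphs; the energies are made-up numbers.
train_data = {
    "data": ["g0", "g1", "g2"],
    "labels": {"energy": [-1.0, -2.0, -3.0]},  # 'forces' is optional and omitted
}
train_batches = list(iter_batches(train_data, batch_size=2, shuffle=False))

predict_batches = list(
    iter_batches({"data": ["g0", "g1", "g2"]}, batch_size=2, shuffle=False,
                 is_train=False, data_names=["a", "b", "c"])
)
```

With `shuffle=False`, the training call produces two batches whose label dictionaries contain only `'energy'`, while the prediction call pairs each batch with the matching slice of names.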

shuffle()

PyGDataLoader(
    data: Dict[Union[Literal['data', 'labels'], Literal['data', 'names']], Any],
    batch_size: int,
    device: str | torch.device = 'cpu',
    shuffle: bool = True,
    is_train: bool = True,
    data_names: Optional[Sequence[str]] = None,
)

An in-memory data loader that batches pyg.Data objects.

Args:
    data: pyg.Data objects that contain the attributes `pos`, `cell`, `atomic_numbers`, `natoms`, `tags`, `fixed`, `pbc`, `idx`.
          `pos`: Tensor, atom coordinates.
          `cell`: Tensor, cell vectors.
          `atomic_numbers`: Tensor, atomic numbers, in one-to-one correspondence with `pos`.
          `natoms`: int, number of atoms.
          `tags`: Tensor, for compatibility with 'FAIR-CHEM' (https://fair-chem.github.io/),
                  in which fixed slab atoms are tagged 0, free slab atoms 1, and adsorbate atoms 2.
          `fixed`: Tensor, fixed-atom flags, in which fixed atoms are 0 and free atoms are 1.
          `pbc`: List[bool, bool, bool], whether the cell is periodic along the x, y, and z directions.
    batch_size: the batch size.
    device: the device on which the data is placed.
    shuffle: whether to shuffle the data.
    is_train: if `is_train` is True, `data` must contain labels; otherwise `data` need not contain labels, and the labels returned [i.e., by next(iter(dataloader))] depend on `data_names`.
    data_names: only used when `is_train` is False.
            If `data_names` is not None, it should be a Sequence of data names in the same order as the data,
            and the returned `labels` [i.e., from next(iter(dataloader))] are the data names instead of "energy" or "forces";
            otherwise `labels` is None.

Yields:
        (pyg.Data, {'energy': energy, 'forces': forces})
     or (pyg.Data, {'energy': energy})
     or (pyg.Data, data_names | None) [when is_train == False]
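
The attribute conventions listed above can be checked before samples are handed to the loader. The following is a minimal sketch, not part of DataLoaders.py: `check_sample` is a hypothetical helper, `SimpleNamespace` with nested lists stands in for a `pyg.Data` object holding tensors, and the geometry is made up. The meaning of `idx` is not described in the docstring, so an integer sample index is assumed here.

```python
from types import SimpleNamespace
from typing import Any, List

# Attribute names taken from the PyGDataLoader docstring.
REQUIRED_ATTRS = ("pos", "cell", "atomic_numbers", "natoms",
                  "tags", "fixed", "pbc", "idx")


def check_sample(sample: Any) -> List[str]:
    """Hypothetical helper: return a list of convention violations (empty if none)."""
    problems = [f"missing attribute: {name}"
                for name in REQUIRED_ATTRS if not hasattr(sample, name)]
    if problems:
        return problems

    # Per-atom attributes must have one entry per atom.
    for name in ("pos", "atomic_numbers", "tags", "fixed"):
        count = len(getattr(sample, name))
        if count != sample.natoms:
            problems.append(f"{name} has {count} entries, expected {sample.natoms}")

    # tags: 0 = fixed slab, 1 = free slab, 2 = adsorbate.
    if any(tag not in (0, 1, 2) for tag in sample.tags):
        problems.append("tags must be 0, 1, or 2")
    # fixed: 0 = fixed atom, 1 = free atom.
    if any(flag not in (0, 1) for flag in sample.fixed):
        problems.append("fixed must be 0 or 1")
    # pbc: one periodicity flag per Cartesian direction.
    if len(sample.pbc) != 3:
        problems.append("pbc needs one flag per axis")
    return problems


# Made-up three-atom example: one fixed slab atom, one free slab atom, one adsorbate atom.
sample = SimpleNamespace(
    pos=[[0.0, 0.0, 0.0], [1.8, 0.0, 1.8], [1.8, 0.0, 3.6]],
    cell=[[3.6, 0.0, 0.0], [0.0, 3.6, 0.0], [0.0, 0.0, 15.0]],
    atomic_numbers=[29, 29, 8],
    natoms=3,
    tags=[0, 1, 2],
    fixed=[0, 1, 1],
    pbc=[True, True, False],
    idx=0,  # assumed to be a sample index
)
problems = check_sample(sample)

incomplete = SimpleNamespace(pos=[[0.0, 0.0, 0.0]], natoms=1)
incomplete_problems = check_sample(incomplete)
```

For real data, the same checks would be applied to tensor attributes, where `len()` gives the size of the first dimension.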

shuffle()

This post is licensed under CC BY 4.0 by the author.