pymlfs.kmeans
Module Contents
Classes
Kmeans – Implements a simplified version of the K-means algorithm.
- class pymlfs.kmeans.Kmeans(k: int, max_iters: int)
Implements a simplified version of the K-means algorithm.
- fit(X: torch.Tensor) -> None
This method trains the K-means model, i.e. it learns the cluster centroids from the training data.
- Parameters:
X (torch.Tensor) – The training data (m x n) with m being the number of examples and n the number of features.
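The documentation does not say how fit initializes the centroids or exactly when it stops. A rough sketch of the training loop that the private methods documented on this page suggest (E-step, M-step, stopping check) is shown below as a standalone function. The name kmeans_fit, the random initialization from training examples, the seed parameter, and the "centroids stopped moving" test are all assumptions, not the pymlfs implementation.

```python
import torch


def kmeans_fit(X: torch.Tensor, k: int, max_iters: int, seed: int = 0) -> torch.Tensor:
    """Hypothetical stand-in for Kmeans.fit; returns a (k x n) tensor of centroids."""
    generator = torch.Generator().manual_seed(seed)
    # Assumption: initialize centroids with k distinct training examples picked at random.
    idx = torch.randperm(X.shape[0], generator=generator)[:k]
    centroids = X[idx].clone()
    for _ in range(max_iters):
        # E-step: assign every example to its nearest centroid (Euclidean distance).
        data_clusters = torch.cdist(X, centroids).argmin(dim=1)
        # M-step: move each centroid to the mean of the examples assigned to it.
        old_centroids = centroids.clone()
        for j in range(k):
            members = X[data_clusters == j]
            if members.shape[0] > 0:  # assumption: an empty cluster keeps its old centroid
                centroids[j] = members.mean(dim=0)
        # Assumed stopping criterion: the centroids no longer move.
        if torch.allclose(centroids, old_centroids):
            break
    return centroids
```

On two well-separated groups such as [[0, 0], [0, 1]] and [[10, 10], [10, 11]] with k=2, the loop should settle on centroids near [0, 0.5] and [10, 10.5], although their order depends on the random initialization.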
- predict(X: torch.Tensor) -> torch.Tensor
Performs cluster prediction.
This method predicts the cluster of every example in a batch.
- Parameters:
X (torch.Tensor) – The test data (m x n) with m being the number of examples and n the number of features.
- Returns:
The predicted cluster of each example, as a torch tensor.
- Return type:
torch.Tensor
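Once centroids are known, predicting a cluster amounts to picking the nearest centroid for each row of X. The following minimal sketch assumes Euclidean distance and that predictions are integer cluster indices; predict_clusters and its explicit centroids argument are hypothetical (the real method presumably reads the centroids learned by fit).

```python
import torch


def predict_clusters(X: torch.Tensor, centroids: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of Kmeans.predict: index of the nearest centroid per row of X."""
    # distances[i, j] is the Euclidean distance between example i and centroid j (m x k).
    distances = torch.cdist(X, centroids)
    # The predicted cluster is the column with the smallest distance.
    return distances.argmin(dim=1)


centroids = torch.tensor([[0.0, 0.0], [5.0, 5.0]])
X_test = torch.tensor([[0.5, -0.2], [4.0, 6.0], [1.0, 1.0]])
print(predict_clusters(X_test, centroids))  # tensor([0, 1, 0])
```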
- _e_step(X: torch.Tensor) -> torch.Tensor
Performs the expectation step.
This method performs the expectation step of the EM algorithm. In other words, given the current cluster centroids, it assigns every training point to the cluster whose centroid is nearest.
- Parameters:
X (torch.Tensor) – The training data (m x n) with m being the number of examples and n the number of features.
- Returns:
data_clusters – The cluster assigned to each input example.
- Return type:
torch.Tensor
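In standard K-means the expectation step is a hard assignment. Writing the training examples as x_1, ..., x_m and the current centroids as μ_1, ..., μ_k (notation introduced here, not taken from pymlfs), each example receives the index of the centroid at the smallest squared Euclidean distance. Whether pymlfs uses exactly this distance is an assumption. Expanding the squared norm shows why the assignment can be computed for all examples at once with matrix products:

```latex
c_i = \arg\min_{j \in \{1,\dots,k\}} \lVert x_i - \mu_j \rVert_2^2,
\qquad i = 1,\dots,m

\lVert x_i - \mu_j \rVert_2^2
  = \lVert x_i \rVert_2^2 - 2\, x_i^\top \mu_j + \lVert \mu_j \rVert_2^2
```

The first term does not depend on j, so it can be dropped from the argmin without changing c_i.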
- _m_step(X: torch.Tensor, data_clusters: torch.Tensor) -> None
Performs the maximization step.
This method performs the maximization step of the EM algorithm. In other words, given the cluster assigned to each training point, it recomputes each centroid as the mean of the points assigned to that cluster.
- Parameters:
X (torch.Tensor) – The training data (m x n) with m being the number of examples and n the number of features.
data_clusters (torch.Tensor) – The cluster assigned to each input example.
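The maximization step can be written without a Python loop over clusters by summing examples per cluster with index_add_ and counting them with bincount. This is a sketch under stated assumptions: m_step is a hypothetical standalone function that returns the new centroids instead of updating the object in place, data_clusters is assumed to hold int64 cluster indices, and the rule that an empty cluster keeps its previous centroid is a choice made here, not documented pymlfs behavior.

```python
import torch


def m_step(X: torch.Tensor, data_clusters: torch.Tensor,
           old_centroids: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch of Kmeans._m_step: new centroids as per-cluster means."""
    k, n = old_centroids.shape
    # sums[j] is the sum of the rows X[i] whose data_clusters[i] == j.
    sums = torch.zeros(k, n, dtype=X.dtype, device=X.device).index_add_(0, data_clusters, X)
    # counts[j] is the number of examples assigned to cluster j.
    counts = torch.bincount(data_clusters, minlength=k).to(X.dtype)
    # clamp avoids division by zero; empty clusters are overwritten just below.
    new_centroids = sums / counts.clamp(min=1).unsqueeze(1)
    # Assumption: an empty cluster keeps its previous centroid.
    empty = counts == 0
    new_centroids[empty] = old_centroids[empty]
    return new_centroids
```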
- _stopping_criterion(iter: int, old_centroids: torch.Tensor) -> bool
Stopping criterion definition.
This method defines the stopping criterion for the K-means algorithm.
- Parameters:
iter (int) – The current iteration.
old_centroids (torch.Tensor) – The cluster centroids from the previous iteration.
- Returns:
True if the stopping criterion is satisfied, False otherwise.
- Return type:
bool
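A common stopping rule for K-means combines an iteration budget with a convergence test on the centroids. The documented method only receives iter and old_centroids, so the current centroids and max_iters presumably come from the instance; the sketch below passes them explicitly instead. The function name should_stop, the tol parameter, and the "largest centroid shift is at most tol" test are all assumptions, and iteration is assumed to count completed iterations.

```python
import torch


def should_stop(iteration: int, max_iters: int, old_centroids: torch.Tensor,
                new_centroids: torch.Tensor, tol: float = 1e-4) -> bool:
    """Hypothetical sketch of Kmeans._stopping_criterion."""
    # Stop once the iteration budget is used up...
    if iteration >= max_iters:
        return True
    # ...or once no centroid has moved farther than `tol` (Euclidean distance).
    shift = torch.linalg.norm(new_centroids - old_centroids, dim=1).max()
    return bool(shift <= tol)
```

Checking the iteration budget first guarantees termination even if the centroids keep oscillating between two configurations.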