pymlfs.kmeans

Module Contents

Classes

Kmeans

Implements a simplified version of the K-means algorithm.

class pymlfs.kmeans.Kmeans(k: int, max_iters: int)

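Implements a simplified version of the K-means algorithm.

Parameters:
  • k (int) – The number of clusters.

  • max_iters (int) – The maximum number of training iterations.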
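For reference, standard K-means looks for cluster assignments $c^{(i)} \in \{1, \dots, k\}$ and centroids $\mu_1, \dots, \mu_k$ that minimize the within-cluster sum of squared Euclidean distances over the m training examples $x^{(i)}$. The source does not say how its "simplified version" departs from the standard algorithm, so the objective below is background notation introduced here, not a description of pymlfs internals.

```latex
J(c, \mu) = \sum_{i=1}^{m} \left\lVert x^{(i)} - \mu_{c^{(i)}} \right\rVert_2^2
```

The expectation step minimizes J over the assignments with the centroids held fixed, and a mean-based maximization step minimizes J over the centroids with the assignments held fixed, so each full iteration never increases J.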

fit(X: torch.Tensor) → None

Trains the K-means model on the given data, estimating the cluster centroids.

Parameters:

X (torch.Tensor) – The training data, of shape (m, n), where m is the number of examples and n is the number of features.
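The training procedure can be sketched as a standalone function that alternates the two steps documented for this class. This is a minimal sketch, not the pymlfs implementation: the function name `fit_kmeans`, the random-example initialization, the `tol` convergence threshold, the `seed` argument, and the empty-cluster policy are all assumptions, since the source documents only the method names `_e_step`, `_m_step`, and `_stopping_criterion`.

```python
import torch


def fit_kmeans(X: torch.Tensor, k: int, max_iters: int, tol: float = 1e-4,
               seed: int = 0) -> torch.Tensor:
    """Hypothetical K-means training loop (not part of pymlfs).

    X must be a floating-point tensor of shape (m, n); returns centroids (k, n).
    """
    generator = torch.Generator().manual_seed(seed)
    # Initialize centroids with k distinct training examples chosen at random.
    init_idx = torch.randperm(X.shape[0], generator=generator)[:k]
    centroids = X[init_idx].clone()

    for _ in range(max_iters):
        # E-step: assign each example to its nearest centroid.
        clusters = torch.cdist(X, centroids).argmin(dim=1)

        # M-step: move each centroid to the mean of its assigned examples;
        # an empty cluster keeps its previous centroid (assumed policy).
        old_centroids = centroids.clone()
        for j in range(k):
            members = X[clusters == j]
            if members.shape[0] > 0:
                centroids[j] = members.mean(dim=0)

        # Stop early once the centroids have (almost) stopped moving.
        if torch.norm(centroids - old_centroids) < tol:
            break
    return centroids
```

Seeding from randomly chosen examples guarantees that every centroid starts on a real data point; k-means++ seeding is a common, more robust alternative.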

predict(X: torch.Tensor) → torch.Tensor

Performs cluster prediction.

This method predicts the cluster of each example in a batch.

Parameters:

X (torch.Tensor) – The test data, of shape (m, n), where m is the number of examples and n is the number of features.

Returns:

The predicted cluster of each example.

Return type:

torch.Tensor
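Once the centroids are learned, prediction reduces to a nearest-centroid lookup. The sketch below assumes Euclidean distance and a centroid tensor of shape (k, n); `predict_clusters` is a hypothetical helper, not a pymlfs function.

```python
import torch


def predict_clusters(X: torch.Tensor, centroids: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: index of the nearest centroid for each row of X."""
    # Pairwise Euclidean distances between examples (m) and centroids (k): shape (m, k).
    distances = torch.cdist(X, centroids)
    # The predicted cluster is the column with the smallest distance.
    return distances.argmin(dim=1)


centroids = torch.tensor([[0.0, 0.0], [5.0, 5.0]])
X_test = torch.tensor([[0.5, -0.2], [4.0, 6.0], [1.0, 1.0]])
print(predict_clusters(X_test, centroids))  # tensor([0, 1, 0])
```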

_e_step(X: torch.Tensor) → torch.Tensor

Performs the expectation step.

This method performs the expectation step of the EM algorithm: given the current cluster centroids, it assigns a cluster to every training example.

Parameters:

X (torch.Tensor) – The training data, of shape (m, n), where m is the number of examples and n is the number of features.

Returns:

data_clusters – The cluster assigned to each input example.

Return type:

torch.Tensor
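Assuming squared Euclidean distance (the source does not name the metric), the hard assignment made by the expectation step for each training example $x^{(i)}$, given centroids $\mu_1, \dots, \mu_k$, can be written as follows; the symbols are notation introduced here, not names used by pymlfs.

```latex
c^{(i)} = \operatorname*{arg\,min}_{j \in \{1, \dots, k\}} \left\lVert x^{(i)} - \mu_j \right\rVert_2^2,
\qquad i = 1, \dots, m
```

Because each point receives exactly one cluster rather than a probability over clusters, this is the "hard" counterpart of the expectation step in a general EM algorithm.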

_m_step(X: torch.Tensor, data_clusters: torch.Tensor) → None

Performs the maximization step.

This method performs the maximization step of the EM algorithm: given the cluster assigned to each training example, it estimates the new centroids.

Parameters:
  • X (torch.Tensor) – The training data, of shape (m, n), where m is the number of examples and n is the number of features.

  • data_clusters (torch.Tensor) – The cluster assigned to each input example.
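A vectorized maximization step can be sketched with `Tensor.index_add_` and `torch.bincount`. This is a minimal sketch under assumptions: the standalone name `m_step`, the explicit `old_centroids` argument (the documented `_m_step` returns None and presumably updates centroids stored on the instance), the mean update, integer cluster indices in `data_clusters`, and the rule that an empty cluster keeps its previous centroid are not confirmed by the source.

```python
import torch


def m_step(X: torch.Tensor, data_clusters: torch.Tensor,
           old_centroids: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: recompute each centroid as the mean of its examples.

    data_clusters holds one int64 cluster index per row of X.
    """
    k, n = old_centroids.shape
    # Sum the examples of each cluster: row j accumulates every X[i] labelled j.
    sums = torch.zeros(k, n, dtype=X.dtype).index_add_(0, data_clusters, X)
    # Number of examples per cluster; minlength keeps empty clusters as zeros.
    counts = torch.bincount(data_clusters, minlength=k).to(X.dtype)
    new_centroids = old_centroids.clone()
    nonempty = counts > 0
    # Empty clusters keep their previous centroid to avoid dividing by zero.
    new_centroids[nonempty] = sums[nonempty] / counts[nonempty].unsqueeze(1)
    return new_centroids
```

Accumulating with `index_add_` replaces a Python loop over the k clusters with a single scatter-add, which matters when k is large.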

_stopping_criterion(iter: int, old_centroids: torch.Tensor) → bool

Checks the stopping criterion.

This method defines the stopping criterion for the K-means algorithm.

Parameters:
  • iter (int) – The current iteration.

  • old_centroids (torch.Tensor) – The cluster centroids from the previous iteration.

Returns:

True if the stopping criterion is satisfied, False otherwise.

Return type:

bool
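A common stopping rule combines an iteration budget with a convergence test on how far the centroids moved. In this sketch, the name `should_stop`, the explicit `max_iters` and `new_centroids` arguments, the `tol` threshold, and the convention that `iteration` counts completed iterations are all assumptions; the documented method receives only `iter` and `old_centroids`.

```python
import torch


def should_stop(iteration: int, max_iters: int, old_centroids: torch.Tensor,
                new_centroids: torch.Tensor, tol: float = 1e-4) -> bool:
    """Hypothetical stopping rule: budget exhausted or centroids converged."""
    # Budget check: `iteration` is assumed to count completed iterations.
    if iteration >= max_iters:
        return True
    # Largest Euclidean distance moved by any single centroid.
    shift = torch.linalg.vector_norm(new_centroids - old_centroids, dim=1).max()
    return bool(shift < tol)
```

The sketch names its counter `iteration` rather than `iter` to avoid shadowing Python's built-in `iter` function.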

_logger_config() → None

Configures the logger.