Fwd: svn commit: r793620 - /lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java

Grant Ingersoll
FWIW, it always seemed a bit strange to me that isConverged lives on KMeansDriver and not KMeansUtil or something like that.
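
Something like this, say (KMeansUtil here is hypothetical -- it's just the new isConverged() logic from the commit below, relocated, with a reader.close() added):

package org.apache.mahout.clustering.kmeans;

import java.io.IOException;

import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapred.JobConf;

public final class KMeansUtil {

  private KMeansUtil() {
  }

  /** Returns true only if every Cluster in every part file under filePath has converged. */
  public static boolean isConverged(String filePath, JobConf conf, FileSystem fs) throws IOException {
    FileStatus[] parts = fs.listStatus(new Path(filePath));
    for (FileStatus part : parts) {
      if (part.getPath().getName().endsWith(".crc")) {
        continue; // skip checksum files
      }
      SequenceFile.Reader reader = new SequenceFile.Reader(fs, part.getPath(), conf);
      try {
        Writable key = (Writable) reader.getKeyClass().newInstance();
        Cluster value = new Cluster();
        while (reader.next(key, value)) {
          if (!value.isConverged()) {
            return false; // one unconverged cluster is enough to keep iterating
          }
        }
      } catch (InstantiationException e) { // shouldn't happen
        throw new RuntimeException(e);
      } catch (IllegalAccessException e) {
        throw new RuntimeException(e);
      } finally {
        reader.close();
      }
    }
    return true;
  }
}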

Begin forwarded message:

> From: [hidden email]
> Date: July 13, 2009 12:38:52 PM EDT
> To: [hidden email]
> Subject: svn commit: r793620 - /lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
> Reply-To: [hidden email]
>
> Author: jeastman
> Date: Mon Jul 13 16:38:52 2009
> New Revision: 793620
>
> URL: http://svn.apache.org/viewvc?rev=793620&view=rev
> Log:
> - modified KMeansDriver.isConverged() to iterate over all part files in the clusters directories
> - removed '/part-0000' append from runIteration()
> - unit test no longer throws exceptions
> - example synthetic control job runs
> - still some formatting differences between Eclipse and JBuilder
>
> Modified:
>    lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
>
> Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java
> URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java?rev=793620&r1=793619&r2=793620&view=diff
> ==============================================================================
> --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java (original)
> +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java Mon Jul 13 16:38:52 2009
> @@ -16,6 +16,8 @@
>  */
> package org.apache.mahout.clustering.kmeans;
>
> +import java.io.IOException;
> +
> import org.apache.commons.cli2.CommandLine;
> import org.apache.commons.cli2.Group;
> import org.apache.commons.cli2.Option;
> @@ -24,6 +26,7 @@
> import org.apache.commons.cli2.builder.DefaultOptionBuilder;
> import org.apache.commons.cli2.builder.GroupBuilder;
> import org.apache.commons.cli2.commandline.Parser;
> +import org.apache.hadoop.fs.FileStatus;
> import org.apache.hadoop.fs.FileSystem;
> import org.apache.hadoop.fs.Path;
> import org.apache.hadoop.io.SequenceFile;
> @@ -43,8 +46,6 @@
> import org.slf4j.Logger;
> import org.slf4j.LoggerFactory;
>
> -import java.io.IOException;
> -
> public class KMeansDriver {
>
>   /** The name of the directory used to output final results. */
> @@ -56,58 +57,68 @@
>   }
>
>   /** @param args Expects 7 args and they all correspond to the order of the params in {@link #runJob} */
> -  public static void main(String[] args) throws ClassNotFoundException, IOException, IllegalAccessException, InstantiationException {
> +  public static void main(String[] args) throws ClassNotFoundException, IOException, IllegalAccessException,
> +      InstantiationException {
>
>     DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
>     ArgumentBuilder abuilder = new ArgumentBuilder();
>     GroupBuilder gbuilder = new GroupBuilder();
>
>     Option inputOpt = obuilder.withLongName("input").withRequired(true).withArgument(
> -        abuilder.withName("input").withMinimum(1).withMaximum(1).create()).
> -        withDescription("The Path for input Vectors. Must be a SequenceFile of Writable, Vector").withShortName("i").create();
> +        abuilder.withName("input").withMinimum(1).withMaximum(1).create()).withDescription(
> +        "The Path for input Vectors. Must be a SequenceFile of Writable, Vector").withShortName("i").create();
>
> -    Option clustersOpt = obuilder.withLongName("clusters").withRequired(true).withArgument(
> -        abuilder.withName("clusters").withMinimum(1).withMaximum(1).create()).
> -        withDescription("The input centroids, as Vectors.  Must be a SequenceFile of Writable, Cluster/Canopy.  " +
> -            "If k is also specified, then a random set of vectors will be selected and written out to this path first").withShortName("c").create();
> -
> -    Option kOpt = obuilder.withLongName("k").withRequired(false).withArgument(
> -        abuilder.withName("k").withMinimum(1).withMaximum(1).create()).
> -        withDescription("The k in k-Means.  If specified, then a random selection of k Vectors will be chosen as the Centroid and written to the clusters output path.").withShortName("k").create();
> +    Option clustersOpt = obuilder
> +        .withLongName("clusters")
> +        .withRequired(true)
> +        .withArgument(abuilder.withName("clusters").withMinimum(1).withMaximum(1).create())
> +        .withDescription(
> +            "The input centroids, as Vectors.  Must be a SequenceFile of Writable, Cluster/Canopy.  "
> +                + "If k is also specified, then a random set of vectors will be selected and written out to this path first")
> +        .withShortName("c").create();
> +
> +    Option kOpt = obuilder
> +        .withLongName("k")
> +        .withRequired(false)
> +        .withArgument(abuilder.withName("k").withMinimum(1).withMaximum(1).create())
> +        .withDescription(
> +            "The k in k-Means.  If specified, then a random selection of k Vectors will be chosen as the Centroid and written to the clusters output path.")
> +        .withShortName("k").create();
>
>     Option outputOpt = obuilder.withLongName("output").withRequired(true).withArgument(
> -        abuilder.withName("output").withMinimum(1).withMaximum(1).create()).
> -        withDescription("The Path to put the output in").withShortName("o").create();
> +        abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription(
> +        "The Path to put the output in").withShortName("o").create();
>
> -    Option overwriteOutput = obuilder.withLongName("overwrite").withRequired(false).
> -        withDescription("If set, overwrite the output directory").withShortName("w").create();
> +    Option overwriteOutput = obuilder.withLongName("overwrite").withRequired(false).withDescription(
> +        "If set, overwrite the output directory").withShortName("w").create();
>
>     Option measureClassOpt = obuilder.withLongName("distance").withRequired(false).withArgument(
> -        abuilder.withName("distance").withMinimum(1).withMaximum(1).create()).
> -        withDescription("The Distance Measure to use.  Default is SquaredEuclidean").withShortName("m").create();
> +        abuilder.withName("distance").withMinimum(1).withMaximum(1).create()).withDescription(
> +        "The Distance Measure to use.  Default is SquaredEuclidean").withShortName("m").create();
>
>     Option convergenceDeltaOpt = obuilder.withLongName("convergence").withRequired(false).withArgument(
> -        abuilder.withName("convergence").withMinimum(1).withMaximum(1).create()).
> -        withDescription("The threshold below which the clusters are considered to be converged.  Default is 0.5").withShortName("d").create();
> +        abuilder.withName("convergence").withMinimum(1).withMaximum(1).create()).withDescription(
> +        "The threshold below which the clusters are considered to be converged.  Default is 0.5").withShortName("d")
> +        .create();
>
>     Option maxIterationsOpt = obuilder.withLongName("max").withRequired(false).withArgument(
> -        abuilder.withName("max").withMinimum(1).withMaximum(1).create()).
> -        withDescription("The maximum number of iterations to perform.  Default is 20").withShortName("x").create();
> +        abuilder.withName("max").withMinimum(1).withMaximum(1).create()).withDescription(
> +        "The maximum number of iterations to perform.  Default is 20").withShortName("x").create();
>
>     Option vectorClassOpt = obuilder.withLongName("vectorClass").withRequired(false).withArgument(
> -        abuilder.withName("vectorClass").withMinimum(1).withMaximum(1).create()).
> -        withDescription("The Vector implementation class name.  Default is SparseVector.class").withShortName("v").create();
> +        abuilder.withName("vectorClass").withMinimum(1).withMaximum(1).create()).withDescription(
> +        "The Vector implementation class name.  Default is SparseVector.class").withShortName("v").create();
>
>     Option numReduceTasksOpt = obuilder.withLongName("numReduce").withRequired(false).withArgument(
> -        abuilder.withName("numReduce").withMinimum(1).withMaximum(1).create()).
> -        withDescription("The number of reduce tasks").withShortName("r").create();
> +        abuilder.withName("numReduce").withMinimum(1).withMaximum(1).create()).withDescription(
> +        "The number of reduce tasks").withShortName("r").create();
>
> -    Option helpOpt = obuilder.withLongName("help").
> -        withDescription("Print out help").withShortName("h").create();
> +    Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h").create();
>
> -    Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(clustersOpt).withOption(outputOpt).withOption(measureClassOpt)
> -        .withOption(convergenceDeltaOpt).withOption(maxIterationsOpt).withOption(numReduceTasksOpt).withOption(kOpt)
> -        .withOption(vectorClassOpt).withOption(overwriteOutput).withOption(helpOpt).create();
> +    Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(clustersOpt).withOption(outputOpt)
> +        .withOption(measureClassOpt).withOption(convergenceDeltaOpt).withOption(maxIterationsOpt).withOption(
> +            numReduceTasksOpt).withOption(kOpt).withOption(vectorClassOpt).withOption(overwriteOutput).withOption(
> +            helpOpt).create();
>     try {
>       Parser parser = new Parser();
>       parser.setGroup(group);
> @@ -129,11 +140,9 @@
>         convergenceDelta = Double.parseDouble(cmdLine.getValue(convergenceDeltaOpt).toString());
>       }
>
> -      Class<? extends Vector> vectorClass = cmdLine.hasOption(vectorClassOpt) == false ?
> -          SparseVector.class
> +      Class<? extends Vector> vectorClass = cmdLine.hasOption(vectorClassOpt) == false ? SparseVector.class
>           : (Class<? extends Vector>) Class.forName(cmdLine.getValue(vectorClassOpt).toString());
>
> -
>       int maxIterations = 20;
>       if (cmdLine.hasOption(maxIterationsOpt)) {
>         maxIterations = Integer.parseInt(cmdLine.getValue(maxIterationsOpt).toString());
> @@ -146,36 +155,35 @@
>         HadoopUtil.overwriteOutput(output);
>       }
>       if (cmdLine.hasOption(kOpt)) {
> -        clusters = RandomSeedGenerator.buildRandom(input, clusters, Integer.parseInt(cmdLine.getValue(kOpt).toString())).toString();
> +        clusters = RandomSeedGenerator
> +            .buildRandom(input, clusters, Integer.parseInt(cmdLine.getValue(kOpt).toString())).toString();
>       }
> -      runJob(input, clusters, output, measureClass, convergenceDelta,
> -          maxIterations, numReduceTasks, vectorClass);
> +      runJob(input, clusters, output, measureClass, convergenceDelta, maxIterations, numReduceTasks, vectorClass);
>     } catch (OptionException e) {
>       log.error("Exception", e);
>       CommandLineUtil.printHelp(group);
>     }
>   }
>
> -
>   /**
>    * Run the job using supplied arguments
> -   *
> -   * @param input            the directory pathname for input points
> -   * @param clustersIn       the directory pathname for initial & computed clusters
> -   * @param output           the directory pathname for output points
> -   * @param measureClass     the classname of the DistanceMeasure
> +   *
> +   * @param input the directory pathname for input points
> +   * @param clustersIn the directory pathname for initial & computed clusters
> +   * @param output the directory pathname for output points
> +   * @param measureClass the classname of the DistanceMeasure
>    * @param convergenceDelta the convergence delta value
> -   * @param maxIterations    the maximum number of iterations
> -   * @param numReduceTasks   the number of reducers
> +   * @param maxIterations the maximum number of iterations
> +   * @param numReduceTasks the number of reducers
>    */
> -  public static void runJob(String input, String clustersIn, String output,
> -                            String measureClass, double convergenceDelta, int maxIterations,
> -                            int numReduceTasks, Class<? extends Vector> vectorClass) {
> +  public static void runJob(String input, String clustersIn, String output, String measureClass,
> +      double convergenceDelta, int maxIterations, int numReduceTasks, Class<? extends Vector> vectorClass) {
>     // iterate until the clusters converge
>     String delta = Double.toString(convergenceDelta);
>     if (log.isInfoEnabled()) {
>       log.info("Input: " + input + " Clusters In: " + clustersIn + "  
> Out: " + output + " Distance: " + measureClass);
> -      log.info("convergence: " + convergenceDelta + " max  
> Iterations: " + maxIterations + " num Reduce Tasks: " +  
> numReduceTasks + " Input Vectors: " + vectorClass.getName());
> +      log.info("convergence: " + convergenceDelta + " max  
> Iterations: " + maxIterations + " num Reduce Tasks: "
> +          + numReduceTasks + " Input Vectors: " +  
> vectorClass.getName());
>     }
>     boolean converged = false;
>     int iteration = 0;
> @@ -183,8 +191,7 @@
>       log.info("Iteration {}", iteration);
>       // point the output to a new directory per iteration
>       String clustersOut = output + "/clusters-" + iteration;
> -      converged = runIteration(input, clustersIn, clustersOut, measureClass,
> -          delta, numReduceTasks, iteration);
> +      converged = runIteration(input, clustersIn, clustersOut, measureClass, delta, numReduceTasks, iteration);
>       // now point the input to the old output directory
>       clustersIn = output + "/clusters-" + iteration;
>       iteration++;
> @@ -196,19 +203,18 @@
>
>   /**
>    * Run the job using supplied arguments
> -   *
> -   * @param input            the directory pathname for input points
> -   * @param clustersIn       the directory pathname for input clusters
> -   * @param clustersOut      the directory pathname for output clusters
> -   * @param measureClass     the classname of the DistanceMeasure
> +   *
> +   * @param input the directory pathname for input points
> +   * @param clustersIn the directory pathname for input clusters
> +   * @param clustersOut the directory pathname for output clusters
> +   * @param measureClass the classname of the DistanceMeasure
>    * @param convergenceDelta the convergence delta value
> -   * @param numReduceTasks   the number of reducer tasks
> -   * @param iteration        The iteration number
> +   * @param numReduceTasks the number of reducer tasks
> +   * @param iteration The iteration number
>    * @return true if the iteration successfully runs
>    */
> -  private static boolean runIteration(String input, String clustersIn,
> -                                      String clustersOut, String measureClass, String convergenceDelta,
> -                                      int numReduceTasks, int iteration) {
> +  private static boolean runIteration(String input, String clustersIn, String clustersOut, String measureClass,
> +      String convergenceDelta, int numReduceTasks, int iteration) {
>     JobConf conf = new JobConf(KMeansDriver.class);
>     conf.setMapOutputKeyClass(Text.class);
>     conf.setMapOutputValueClass(KMeansInfo.class);
> @@ -229,11 +235,10 @@
>     conf.set(Cluster.CLUSTER_CONVERGENCE_KEY, convergenceDelta);
>     conf.setInt(Cluster.ITERATION_NUMBER, iteration);
>
> -
>     try {
>       JobClient.runJob(conf);
>       FileSystem fs = FileSystem.get(outPath.toUri(), conf);
> -      return isConverged(clustersOut + "/part-00000", conf, fs);
> +      return isConverged(clustersOut, conf, fs);
>     } catch (IOException e) {
>       log.warn(e.toString(), e);
>       return true;
> @@ -242,15 +247,15 @@
>
>   /**
>    * Run the job using supplied arguments
> -   *
> -   * @param input            the directory pathname for input points
> -   * @param clustersIn       the directory pathname for input clusters
> -   * @param output           the directory pathname for output points
> -   * @param measureClass     the classname of the DistanceMeasure
> +   *
> +   * @param input the directory pathname for input points
> +   * @param clustersIn the directory pathname for input clusters
> +   * @param output the directory pathname for output points
> +   * @param measureClass the classname of the DistanceMeasure
>    * @param convergenceDelta the convergence delta value
>    */
> -  private static void runClustering(String input, String clustersIn,
> -                                    String output, String measureClass, String convergenceDelta, Class<? extends Vector> vectorClass) {
> +  private static void runClustering(String input, String clustersIn, String output, String measureClass,
> +      String convergenceDelta, Class<? extends Vector> vectorClass) {
>     if (log.isInfoEnabled()) {
>       log.info("Running Clustering");
>       log.info("Input: " + input + " Clusters In: " + clustersIn + "  
> Out: " + output + " Distance: " + measureClass);
> @@ -263,7 +268,7 @@
>     conf.setMapOutputKeyClass(Text.class);
>     conf.setMapOutputValueClass(vectorClass);
>     conf.setOutputKeyClass(Text.class);
> -    //the output is the cluster id
> +    // the output is the cluster id
>     conf.setOutputValueClass(Text.class);
>
>     FileInputFormat.setInputPaths(conf, new Path(input));
> @@ -284,33 +289,36 @@
>   }
>
>   /**
> -   * Return if all of the Clusters in the filePath have converged or not
> -   *
> +   * Return if all of the Clusters in the parts in the filePath have converged or not
> +   *
>    * @param filePath the file path to the single file containing the clusters
> -   * @param conf     the JobConf
> -   * @param fs       the FileSystem
> +   * @param conf the JobConf
> +   * @param fs the FileSystem
>    * @return true if all Clusters are converged
>    * @throws IOException if there was an IO error
>    */
> -  private static boolean isConverged(String filePath, JobConf conf, FileSystem fs)
> -      throws IOException {
> -    Path outPart = new Path(filePath + "/*");
> -    SequenceFile.Reader reader = new SequenceFile.Reader(fs, outPart, conf);
> -    Writable key;
> -    try {
> -      key = (Writable) reader.getKeyClass().newInstance();
> -    } catch (InstantiationException e) {//shouldn't happen
> -      log.error("Exception", e);
> -      throw new RuntimeException(e);
> -    } catch (IllegalAccessException e) {
> -      log.error("Exception", e);
> -      throw new RuntimeException(e);
> -    }
> -    Cluster value = new Cluster();
> -    boolean converged = true;
> -    while (converged && reader.next(key, value)) {
> -      converged = value.isConverged();
> -    }
> -    return converged;
> +  private static boolean isConverged(String filePath, JobConf conf, FileSystem fs) throws IOException {
> +    FileStatus[] parts = fs.listStatus(new Path(filePath));
> +    for (FileStatus part : parts)
> +      if (!part.getPath().getName().endsWith(".crc")) {
> +        SequenceFile.Reader reader = new SequenceFile.Reader(fs, part.getPath(), conf);
> +        Writable key;
> +        try {
> +          key = (Writable) reader.getKeyClass().newInstance();
> +        } catch (InstantiationException e) {// shouldn't happen
> +          log.error("Exception", e);
> +          throw new RuntimeException(e);
> +        } catch (IllegalAccessException e) {
> +          log.error("Exception", e);
> +          throw new RuntimeException(e);
> +        }
> +        Cluster value = new Cluster();
> +        while (reader.next(key, value)) {
> +          if (value.isConverged() == false) {
> +            return false;
> +          }
> +        }
> +      }
> +    return true;
>   }
> }
>
>
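
For anyone wanting to kick the tires, the reworked runJob() can be driven directly. A minimal sketch -- the paths, the measure-class FQCN, and the SparseVector import location are my assumptions, not part of the commit (defaults per the option descriptions: convergence 0.5, max iterations 20, SparseVector):

import org.apache.mahout.clustering.kmeans.KMeansDriver;
import org.apache.mahout.matrix.SparseVector; // assumed package for SparseVector

public class KMeansExample {
  public static void main(String[] args) {
    String input = "testdata/points";        // SequenceFile of Writable, Vector (illustrative path)
    String clustersIn = "testdata/clusters"; // initial centroids, or pre-seed via RandomSeedGenerator
    String output = "output";
    // measureClass is passed as a String; substitute whatever DistanceMeasure
    // implementation your tree provides -- this FQCN is an assumption
    String measureClass = "org.apache.mahout.utils.SquaredEuclideanDistanceMeasure";
    KMeansDriver.runJob(input, clustersIn, output, measureClass, 0.5, 20, 1, SparseVector.class);
  }
}

Each iteration then writes to output/clusters-<n>, and the loop stops as soon as isConverged() sees every part file converged, or maxIterations is hit.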

--------------------------
Grant Ingersoll
http://www.lucidimagination.com/

Search the Lucene ecosystem (Lucene/Solr/Nutch/Mahout/Tika/Droids) using Solr/Lucene:
http://www.lucidimagination.com/search