[jira] Commented: (MAHOUT-6) Need a matrix implementation

classic Classic list List threaded Threaded
1 message Options
Reply | Threaded
Open this post in threaded view
|

[jira] Commented: (MAHOUT-6) Need a matrix implementation

Tim Allison (Jira)

    [ https://issues.apache.org/jira/browse/MAHOUT-6?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=12576985#action_12576985 ]

Jason Rennie commented on MAHOUT-6:
-----------------------------------

Ted, I am a bit surprised how speedy the HashMap impl. is! :-)  2-3x slower than the CRS impl, but much better than I thought.  Would be good to test a primitive HashMap (int, double).  Might be as fast as the CRS version and much more flexible.  I'm gonna just drop the code in here.  Besides the two vector classes and test class, you'll also find StopWatch, the code I used to do the timing.  One thing I didn't do was check how much time was wasted by the StopWatch code...

import java.util.Random;

import junit.framework.TestCase;

import org.apache.log4j.Logger;

public class SparseVectorPerformanceTests extends TestCase {

        private static final Logger log = Logger.getLogger(SparseVectorPerformanceTests.class);

        Random rand = new Random();

        /**
         * <ul>
         * <li> Finished HashMap dot product in 4.161 seconds.
         * <li> Finished CRS dot product in 2.340 seconds.
         * <li> numTrials=1000000 vectorSize=1000 nnz1=50 nnz2=200
         * </ul>
         * <ul>
         * <li> Finished HashMap dot product in 6.482 seconds.
         * <li> Finished CRS dot product in 2.663 seconds.
         * <li> numTrials=1000000 vectorSize=1000 nnz1=100 nnz2=100
         * </ul>
         */
        public void testSparseVectorPerformance() throws Exception {
                StopWatch hmvSW = new StopWatch("HashMap dot product", log, false);
                StopWatch crsSW = new StopWatch("CRS dot product", log, false);
                final int numTrials = 1000000;
                final int vectorSize = 1000;
                final int nnz1 = 100;
                final int nnz2 = 100;
                for (int i = 0; i < numTrials; ++i) {
                        SparseVectorHashMap hmv1 = new SparseVectorHashMap();
                        SparseVectorHashMap hmv2 = new SparseVectorHashMap();
                        for (int j = 0; j < nnz1; ++j) {
                                hmv1.set(this.rand.nextInt(vectorSize) + 1, this.rand.nextDouble());
                        }
                        for (int j = 0; j < nnz2; ++j) {
                                hmv2.set(this.rand.nextInt(vectorSize) + 1, this.rand.nextDouble());
                        }
                        SparseVectorCRS crsv1 = hmv1.buildSparseVector();
                        SparseVectorCRS crsv2 = hmv2.buildSparseVector();
                        hmvSW.start();
                        hmv1.dot(hmv2);
                        hmvSW.stop();
                        crsSW.start();
                        crsv1.dot(crsv2);
                        crsSW.stop();
                }
                hmvSW.logEndMessage();
                crsSW.logEndMessage();
                log.debug("numTrials=" + numTrials + " vectorSize=" + vectorSize + " nnz1=" + nnz1 + " nnz2=" + nnz2);
        }

        public void testSparseVectorCorrectness() throws Exception {
                final int vectorSize = 100;
                final int nnz = 10;
                SparseVectorHashMap hmv1 = new SparseVectorHashMap();
                SparseVectorHashMap hmv2 = new SparseVectorHashMap();
                for (int j = 0; j < nnz; ++j) {
                        hmv1.set(this.rand.nextInt(vectorSize) + 1, this.rand.nextDouble());
                        hmv2.set(this.rand.nextInt(vectorSize) + 1, this.rand.nextDouble());
                }
                SparseVectorCRS crsv1 = hmv1.buildSparseVector();
                SparseVectorCRS crsv2 = hmv2.buildSparseVector();
                double hmvDot = hmv1.dot(hmv2);
                double vDot = crsv1.dot(crsv2);
                assertTrue(hmvDot == vDot);
                log.debug(hmvDot);
                log.debug(vDot);
        }

}

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class SparseVectorHashMap {

        Map<Integer, Double> data;

        public SparseVectorHashMap() {
                this.data = new HashMap<Integer, Double>();
        }

        public SparseVectorCRS buildSparseVector() {
                int size = this.data.size();
                int[] index = new int[size];
                double[] value = new double[size];
                List<Integer> keyList = new ArrayList<Integer>(this.data.keySet());
                Collections.sort(keyList);
                for (int i = 0; i < size; ++i) {
                        Integer indexInteger = keyList.get(i);
                        index[i] = indexInteger.intValue();
                        value[i] = this.data.get(indexInteger).doubleValue();
                }
                return new SparseVectorCRS(index, value);
        }

        public void set(Integer _index, Double _value) {
                this.data.put(_index, _value);
        }

        /**
         * Assumption: _smaller.size() < _larger.size()
         *
         * @param _smaller
         *            data entry of SparseVectorHashMap
         * @param _larger
         *            data entry of SparseVectorHashMap
         * @return dot-product of corresponding vectors
         */
        static private double dot(Map<Integer, Double> _smaller, Map<Integer, Double> _larger) {
                double retval = 0.0;
                for (Map.Entry<Integer, Double> smallEntry : _smaller.entrySet()) {
                        Double largeValue = _larger.get(smallEntry.getKey());
                        if (largeValue != null) {
                                retval += largeValue.doubleValue() * smallEntry.getValue().doubleValue();
                        }
                }
                return retval;
        }

        /**
         * @param _v
         * @return dot-product of this vector with _v
         */
        public double dot(SparseVectorHashMap _v) {
                if (this.data.size() < _v.data.size()) {
                        return dot(this.data, _v.data);
                }
                return dot(_v.data, this.data);
        }
}

import java.util.Random;

import junit.framework.TestCase;

import org.apache.log4j.Logger;

public class SparseVectorPerformanceTests extends TestCase {

        private static final Logger log = Logger.getLogger(SparseVectorPerformanceTests.class);

        Random rand = new Random();

        /**
         * <ul>
         * <li> Finished HashMap dot product in 4.161 seconds.
         * <li> Finished CRS dot product in 2.340 seconds.
         * <li> numTrials=1000000 vectorSize=1000 nnz1=50 nnz2=200
         * </ul>
         * <ul>
         * <li> Finished HashMap dot product in 6.482 seconds.
         * <li> Finished CRS dot product in 2.663 seconds.
         * <li> numTrials=1000000 vectorSize=1000 nnz1=100 nnz2=100
         * </ul>
         */
        public void testSparseVectorPerformance() throws Exception {
                StopWatch hmvSW = new StopWatch("HashMap dot product", log, false);
                StopWatch crsSW = new StopWatch("CRS dot product", log, false);
                final int numTrials = 1000000;
                final int vectorSize = 1000;
                final int nnz1 = 100;
                final int nnz2 = 100;
                for (int i = 0; i < numTrials; ++i) {
                        SparseVectorHashMap hmv1 = new SparseVectorHashMap();
                        SparseVectorHashMap hmv2 = new SparseVectorHashMap();
                        for (int j = 0; j < nnz1; ++j) {
                                hmv1.set(this.rand.nextInt(vectorSize) + 1, this.rand.nextDouble());
                        }
                        for (int j = 0; j < nnz2; ++j) {
                                hmv2.set(this.rand.nextInt(vectorSize) + 1, this.rand.nextDouble());
                        }
                        SparseVectorCRS crsv1 = hmv1.buildSparseVector();
                        SparseVectorCRS crsv2 = hmv2.buildSparseVector();
                        hmvSW.start();
                        hmv1.dot(hmv2);
                        hmvSW.stop();
                        crsSW.start();
                        crsv1.dot(crsv2);
                        crsSW.stop();
                }
                hmvSW.logEndMessage();
                crsSW.logEndMessage();
                log.debug("numTrials=" + numTrials + " vectorSize=" + vectorSize + " nnz1=" + nnz1 + " nnz2=" + nnz2);
        }

        public void testSparseVectorCorrectness() throws Exception {
                final int vectorSize = 100;
                final int nnz = 10;
                SparseVectorHashMap hmv1 = new SparseVectorHashMap();
                SparseVectorHashMap hmv2 = new SparseVectorHashMap();
                for (int j = 0; j < nnz; ++j) {
                        hmv1.set(this.rand.nextInt(vectorSize) + 1, this.rand.nextDouble());
                        hmv2.set(this.rand.nextInt(vectorSize) + 1, this.rand.nextDouble());
                }
                SparseVectorCRS crsv1 = hmv1.buildSparseVector();
                SparseVectorCRS crsv2 = hmv2.buildSparseVector();
                double hmvDot = hmv1.dot(hmv2);
                double vDot = crsv1.dot(crsv2);
                assertTrue(hmvDot == vDot);
                log.debug(hmvDot);
                log.debug(vDot);
        }

}

import org.apache.log4j.Logger;

public class StopWatch {
       
        enum StopWatchState { RUNNING, STOPPED }

        static final double meg = 1024.0 * 1024.0;

        long t0;

        String action;

        Logger log;
       
        StopWatchState state;

        double elapsedSeconds = 0.0;
       
        /**
         * Initializes and starts the clock.
         *
         * @param _action
         * @param _log
         */
        public StopWatch(final String _action, final Logger _log) {
                this(_action, _log, true);
        }

        public StopWatch(final String _action, final Logger _log, final boolean _startRunning) {
                this.action = _action;
                this.log = _log;
                this.log.info(startMessage());
                if (_startRunning) {
                        this.state = StopWatchState.RUNNING;
                        this.t0 = System.currentTimeMillis();
                } else
                        this.state = StopWatchState.STOPPED;
        }

        String startMessage() {
                Runtime r = Runtime.getRuntime();
                double totMemMeg = r.totalMemory() / meg;
                double freeMemMeg = r.freeMemory() / meg;
                return String.format("mem=%.2fm  Started %s...", Double.valueOf(totMemMeg - freeMemMeg), this.action);
        }
       
        public void start() {
                if (this.state.equals(StopWatchState.RUNNING)) {
                        this.log.warn("start(): StopWatch is already running.  Not doing anything.");
                        return;
                }
                this.state = StopWatchState.RUNNING;
                this.t0 = System.currentTimeMillis();
        }

        public void stop() {
                double curElapsedSeconds = ((System.currentTimeMillis() - this.t0) / 1000.0);
                if (this.state.equals(StopWatchState.STOPPED)) {
                        this.log.warn("stop(): StopWatch is already stopped.  Not doing anything.");
                        return;
                }
                this.elapsedSeconds += curElapsedSeconds;
                this.state = StopWatchState.STOPPED;
        }
       
        public void logEndMessage() {
                this.log.info(endMessage());
        }

        public void logEndMessage(final String _s) {
                this.log.info(endMessage() + " " + _s);
        }

        String endMessage() {
                double totalElapsedSeconds = this.elapsedSeconds;
                if (this.state.equals(StopWatchState.RUNNING)) {
                        totalElapsedSeconds += ((System.currentTimeMillis() - this.t0) / 1000.0);
                }
                Runtime r = Runtime.getRuntime();
                double totMemMeg = r.totalMemory() / meg;
                double freeMemMeg = r.freeMemory() / meg;
                return String.format("mem=%.2fm  Finished %s in %.3f seconds.", Double.valueOf(totMemMeg - freeMemMeg), this.action, Double.valueOf(totalElapsedSeconds));
        }

}


> Need a matrix implementation
> ----------------------------
>
>                 Key: MAHOUT-6
>                 URL: https://issues.apache.org/jira/browse/MAHOUT-6
>             Project: Mahout
>          Issue Type: New Feature
>            Reporter: Ted Dunning
>            Assignee: Grant Ingersoll
>         Attachments: MAHOUT-6a.diff, MAHOUT-6b.diff, MAHOUT-6c.diff, MAHOUT-6d.diff, MAHOUT-6e.diff, MAHOUT-6f.diff, MAHOUT-6g.diff, MAHOUT-6h.patch, MAHOUT-6i.diff, MAHOUT-6j.diff, MAHOUT-6k.diff, MAHOUT-6l.patch
>
>
> We need matrices for Mahout.
> An initial set of basic requirements includes:
> a) sparse and dense support are required
> b) row and column labels are important
> c) serialization for hadoop use is required
> d) reasonable floating point performance is required, but awesome FP is not
> e) the API should be simple enough to understand
> f) it should be easy to carve out sub-matrices for sending to different reducers
> g) a reasonable set of matrix operations should be supported; these should eventually include:
>     simple matrix-matrix and matrix-vector and matrix-scalar linear algebra operations, A B, A + B, A v, A + x, v + x, u + v, dot(u, v)
>     row and column sums  
>     generalized level 2 and 3 BLAS primitives, alpha A B + beta C and A u + beta v
> h) easy and efficient iteration constructs, especially for sparse matrices
> i) easy to extend with new implementations

--
This message is automatically generated by JIRA.
-
You can reply to this email to add a comment to the issue online.