#!/usr/bin/env python3

from __future__ import division
from __future__ import print_function

import unittest
from copy import deepcopy

from add import add


class AddTests(unittest.TestCase):
    """
    Tests for the matrix-addition function ``add``.

    Matrices are plain lists of rows (each row a list of numbers).  Test
    names follow the ``<rows>_by_<columns>`` convention, e.g.
    ``test_two_by_three`` uses 2 rows of 3 columns each.
    """

    def test_one_by_one(self):
        """Single-element matrices add element-wise, including negatives."""
        self.assertEqual(add([[1]], [[2]]), [[3]])
        self.assertEqual(add([[10]], [[-2]]), [[8]])

    def test_two_by_two(self):
        """A square matrix plus its negation is the zero matrix."""
        m1 = [[1, 2], [3, 4]]
        m2 = [[-1, -2], [-3, -4]]
        m3 = [[0, 0], [0, 0]]
        self.assertEqual(add(m1, m2), m3)

    def test_two_by_three(self):
        """Non-square matrices (2 rows, 3 columns) add element-wise."""
        m1 = [[1, 2, 5], [3, 4, 10]]
        m2 = [[-1, -2, 2], [-2, -3, -3]]
        m3 = [[0, 0, 7], [1, 1, 7]]
        self.assertEqual(add(m1, m2), m3)

    def test_one_by_ten(self):
        """A single wide row (1 row, 10 columns) adds element-wise."""
        m1 = [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]
        m2 = [[9, 8, 7, 6, 5, 4, 3, 2, 1, 0]]
        m3 = [[9, 9, 9, 9, 9, 9, 9, 9, 9, 9]]
        self.assertEqual(add(m1, m2), m3)

    def test_ten_by_one(self):
        """A single tall column (10 rows, 1 column) adds element-wise."""
        m1 = [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]]
        m2 = [[9], [8], [7], [6], [5], [4], [3], [2], [1], [0]]
        m3 = [[9], [9], [9], [9], [9], [9], [9], [9], [9], [9]]
        self.assertEqual(add(m1, m2), m3)

    def test_input_unchanged(self):
        """``add`` must not mutate either of its input matrices."""
        m1 = [[7, 4, 7], [6, 6, 6]]
        m2 = [[1, 2, 3], [4, 5, 9]]
        # Deep copies so that mutation of nested rows is detected too.
        m1_original = deepcopy(m1)
        m2_original = deepcopy(m2)
        add(m1, m2)
        self.assertEqual(m1, m1_original)
        self.assertEqual(m2, m2_original)

    # Comment out the following decorator to require the bonus feature
    # (adding any number of matrices) to pass.
    @unittest.expectedFailure
    def test_any_number_of_matrixes(self):
        """Bonus: ``add`` accepts more than two matrices and sums them all."""
        m1 = [[2, 4], [6, 8]]
        m2 = [[1, 1], [1, 10]]
        m3 = [[2, 1], [8, 3]]
        m4 = [[5, 6], [15, 21]]
        m5 = [[15, 21], [43, 68]]
        self.assertEqual(add(m1, m2, m3), m4)
        self.assertEqual(add(m2, m3, m1, m1, m2, m4, m1), m5)

    # Comment out the following decorator to require the bonus feature
    # (rejecting mismatched matrix shapes) to pass.
    @unittest.expectedFailure
    def test_different_matrix_size(self):
        """Bonus: matrices whose shapes differ must raise ``ValueError``."""
        m1 = [[3], [3]]
        m2 = [[1, 2], [3, 4]]
        # Ragged: its rows have different lengths.
        m3 = [[5, 6], [7, 8, 9, 10]]
        m4 = [[1], [2], [3]]
        # Same row count, different column count.
        with self.assertRaises(ValueError):
            add(m1, m2)
        with self.assertRaises(ValueError):
            add(m1, m3)
        # The mismatch may appear anywhere among many arguments.
        with self.assertRaises(ValueError):
            add(m1, m1, m1, m3, m1, m1)
        with self.assertRaises(ValueError):
            add(m1, m1, m1, m2, m1, m1)
        # Same column count, different row count.
        with self.assertRaises(ValueError):
            add(m1, m4)


if __name__ == "__main__":
    # Executed only when this file is run directly, not when it is imported.
    # verbosity=2 prints each test's name alongside its outcome.
    unittest.main(module="__main__", verbosity=2)
