Added some real accessors.
[pseudoku.git] / pseudoku / grid / __init__.py
index 2f77aaf..a62f18f 100644 (file)
@@ -1,10 +1,11 @@
 from __future__ import division
 
 from math import sqrt
+from operator import attrgetter
 import re
-from weakref import proxy
+from weakref import ref
 
-from cellgroup import Row, Column, Box
+from cellgroup import CellConstraint, Row, Column, Box
 
 symbols = [str(x + 1) for x in range(9)] + [chr(x + 97) for x in xrange(26)]
 
@@ -35,46 +36,28 @@ class Cell(object):
         return None
     value = property(_get_value)
 
-    def _get_row(self):
-        """Returns the Row object associated with this cell."""
-        return self._grid._rows[self._row]
-    row = property(_get_row)
-
-    def _get_column(self):
-        """Returns the Column object associated with this cell."""
-        return self._grid._columns[self._col]
-    column = property(_get_column)
-
-    def _get_box(self):
-        """Returns the Box object associated with this cell."""
-        # Some actual math required here!
-        # Row 0..2 -> box 0..2
-        # Col 0..2 -> box 0, 3, 6 (box col 0)
-        box_row = self._row // self._grid._box_height
-        box_col = self._col // self._grid._box_width
-        box_idx = box_row * self._grid._box_height + box_col
-        return self._grid._boxes[box_idx]
-    box = property(_get_box)
+    grid = property(lambda self: self._grid())
+    constraints = property(attrgetter('_constraints'))
 
     def __init__(self, grid, row, column):
-        self._grid = proxy(grid)
+        self._grid = ref(grid)
         self._row = row
         self._col = column
-        self._values = range(self._grid.size)
+        self._values = range(self.grid.size)
+        self._constraints = []
         self._normalized = False
 
-    def set_naively(self, value):
-        """Sets the value of this cell, WITHOUT eliminating the value from
-        every other cell in its row/column/box.
-        """
+    def add_constraint(self, constraint):
+        self._constraints.append(constraint)
 
+    def set(self, value, normalize=True):
+        """Sets the value of this cell.  If `normalize` is True or omitted, the
+        grid will be updated accordingly.
+        """
         self._values = [value]
-
-    def set(self, value):
-        """Sets the value of this cell and adjusts the grid accordingly."""
-        self.set_naively(value)
-        self._normalized = False
-        self.normalize()
+        if normalize:
+            self._normalized = False
+            self.normalize()
 
 
 
@@ -96,9 +79,8 @@ class Cell(object):
             return
 
         # Elimination time
-        for group_type in 'row', 'column', 'box':
-            group = getattr(self, group_type)
-            for cell in group.cells:
+        for constraint in self.constraints:
+            for cell in constraint.cells:
                 if cell == self:
                     continue
                 cell.eliminate(self.value)
@@ -117,14 +99,6 @@ class Cell(object):
             self.normalize()
 
 
-    def __str__(self):
-        """Stringification for pretty-printing."""
-        if self.value != None:
-            return symbols[self.value]
-
-        return '.'
-
-
 class Grid(object):
     """Represents a Sudoku grid."""
 
@@ -132,7 +106,7 @@ class Grid(object):
 
     def _cellidx(self, row, col):
         """Hashes a row and column into a flat array index."""
-        return row * self._size + col
+        return row * self.size + col
 
     @classmethod
     def _infer_box_size(cls, dimension):
@@ -170,21 +144,19 @@ class Grid(object):
 
     ### Accessors
 
-    def _get_box_height(self):
-        return self._box_height
-    box_height = property(_get_box_height)
-
-    def _get_box_width(self):
-        return self._box_width
-    box_width = property(_get_box_width)
+    def get_constraints(self, constraint_class=CellConstraint):
+        """Returns constraints of a certain type.  Returns all of them by
+        default.
+        """
 
-    def _get_size(self):
-        return self._size
-    size = property(_get_size)
+        condition = lambda constraint: isinstance(constraint, constraint_class)
+        return filter(condition, self._constraints)
 
-    def _get_cell_groups(self):
-        return self._rows + self._columns + self._boxes
-    cell_groups = property(_get_cell_groups)
+    rows = property(lambda self: self.get_constraints(Row))
+    box_height = property(attrgetter('_box_height'))
+    box_width = property(attrgetter('_box_width'))
+    size = property(attrgetter('_size'))
+    constraints = property(attrgetter('_constraints'))
 
     ### Constructors
 
@@ -196,16 +168,15 @@ class Grid(object):
         self._box_width = box_width
         self._size = box_height * box_width
 
-        self._rows = [Row(self, i) for i in xrange(self._size)]
-        self._columns = [Column(self, i) for i in xrange(self._size)]
-        self._boxes = [Box(self, i) for i in xrange(self._size)]
-
-        self._cells = range(self._size ** 2)
-        for row in xrange(self._size):
-            for col in xrange(self._size):
+        self._cells = range(self.size ** 2)
+        for row in xrange(self.size):
+            for col in xrange(self.size):
                 self._cells[self._cellidx(row, col)] \
                     = Cell(self, row, col)
 
+        self._constraints = []
+        self.add_default_constraints()
+
     @classmethod
     def from_matrix(cls, rows, box_height=None, box_width=None):
         """Creates and returns a grid read from a list of lists."""
@@ -217,12 +188,12 @@ class Grid(object):
 
         self = cls(box_width=box_width, box_height=box_height)
 
-        for row in xrange(self._size):
-            for col in xrange(self._size):
+        for row in xrange(self.size):
+            for col in xrange(self.size):
                 value = rows[row][col]
                 if not value:
                     continue
-                self.cell(row, col).set_naively(value - 1)
+                self.cell(row, col).set(value - 1, normalize=False)
 
         return self
 
@@ -256,15 +227,30 @@ class Grid(object):
 
         self = cls(box_width=box_width, box_height=box_height)
 
-        for row in xrange(self._size):
-            for col in xrange(self._size):
+        for row in xrange(self.size):
+            for col in xrange(self.size):
                 ch = grid[ self._cellidx(row, col) ]
                 if ch == '0':
                     continue
-                self.cell(row, col).set_naively(symbols.index(ch))
+                self.cell(row, col).set(symbols.index(ch), normalize=False)
 
         return self
 
+    ### Constraints
+
+    def add_constraint(self, constraint):
+        self._constraints.append(constraint)
+        for cell in constraint.cells:
+            cell.add_constraint(constraint)
+
+    def add_default_constraints(self):
+        for i in xrange(self.size):
+            self.add_constraint(Row(self, i))
+            self.add_constraint(Column(self, i))
+            self.add_constraint(Box(self, i))
+
+        return
+
     ### Inspectors
 
     def cell(self, row, column):
@@ -294,7 +280,7 @@ class Grid(object):
         self.normalize_cells()
 
         # Step 1: Find values that can only go in one cell in a group
-        for group in self.cell_groups:
+        for group in self.constraints:
             group.resolve_uniques()
 
 
@@ -302,31 +288,3 @@ class Grid(object):
         """Normalizes every cell in the grid."""
         for cell in self._cells:
             cell.normalize()
-
-
-    def __str__(self):
-        """Pretty-printing."""
-        divider = '+'
-        for box in xrange(self._box_height):
-            for col in xrange(self._box_width):
-                divider += '-'
-            divider += '+'
-
-        res = ''
-        for row in xrange(self._size):
-            if row % self._box_height == 0:
-                res += divider
-                res += "\n"
-
-            for col in xrange(self._size):
-                if col % self._box_width == 0:
-                    res += '|'
-                res += str(self.cell(row, col))
-
-            res += '|'
-            res += "\n"
-
-        res += divider
-        res += "\n"
-
-        return res