import csv
# Class labels ("colors"). Dyadic._compute_error compares the third field of
# each data point against RED; any other value is counted as blue.
RED = 1.0
BLUE = 0.0
# NOTE(review): UNKNOWN is not referenced by the code visible in this file.
UNKNOWN = 2.0
class Rect:
    '''
    Dyadic rectangle described by four integers (j, k, l, m).

    The rectangle spans [j/2**k, (j+1)/2**k) along x and
    [l/2**m, (l+1)/2**m) along y; see the coords property.
    '''

    def __init__(self, j, k, l, m):
        # (j, k) describe the x extent, (l, m) the y extent.
        self._j = j
        self._k = k
        self._l = l
        self._m = m

    @property
    def coords(self):
        '''
        Convert from the four integer representation to the more common
        opposing points representation.

        Returns: ((x0, y0), (x1, y1)), the lower and upper corners.
        '''
        x_cells = 2 ** self._k
        y_cells = 2 ** self._m
        lower = (self._j / x_cells, self._l / y_cells)
        upper = ((self._j + 1) / x_cells, (self._l + 1) / y_cells)
        return (lower, upper)

    @property
    def left(self):
        '''
        Return a new rectangle that represents the left half of this
        rectangle.
        '''
        finer_k = self._k + 1
        return Rect(2 * self._j, finer_k, self._l, self._m)

    @property
    def right(self):
        '''
        Return a new rectangle that represents the right half of this
        rectangle.
        '''
        finer_k = self._k + 1
        return Rect(2 * self._j + 1, finer_k, self._l, self._m)

    @property
    def bottom(self):
        '''
        Return a new rectangle that represents the bottom half of this
        rectangle.
        '''
        finer_m = self._m + 1
        return Rect(self._j, self._k, 2 * self._l, finer_m)

    @property
    def top(self):
        '''
        Return a new rectangle that represents the top half of this
        rectangle.
        '''
        finer_m = self._m + 1
        return Rect(self._j, self._k, 2 * self._l + 1, finer_m)

    @property
    def x_index(self):
        '''
        Compute the row index for the memoization data structure.

        The value depends only on the x description (j, k).
        '''
        level_offset = 2 ** (self._k + 1)
        return level_offset + self._j - 1

    @property
    def y_index(self):
        '''
        Compute the column index for the memoization data structure.

        The value depends only on the y description (l, m).
        '''
        level_offset = 2 ** (self._m + 1)
        return level_offset + self._l - 1

    def __str__(self):
        # Shows the raw integer description, not the coordinates.
        return f"(({self._j}, {self._k}), ({self._l}, {self._m}))"
class Tree_Node:
    '''
    Representation for a node in the decision tree.

    A freshly built node is a leaf (split_on == "no_split") carrying a
    rectangle, a color and a cost. The tree builder turns it into an
    internal node by assigning left, right and split_on.
    '''

    def __init__(self, R, color, cost):
        '''
        R: rectangle covered by this node; it must expose a coords
           attribute shaped ((x0, y0), (x1, y1)).
        color: label returned by classify() when this node is a leaf.
        cost: cost of the subtree rooted at this node.
        '''
        self._R = R
        self._color = color
        self._cost = cost
        self._left = None
        self._right = None
        self._split_on = "no_split"

    @property
    def R(self):
        return self._R

    @property
    def color(self):
        return self._color

    @color.setter
    def color(self, c):
        self._color = c

    @property
    def cost(self):
        # Sanity check kept from the original: reading a negative cost
        # trips the assertion (unless Python runs with -O).
        assert self._cost >= 0
        return self._cost

    @cost.setter
    def cost(self, c):
        self._cost = c

    @property
    def left(self):
        return self._left

    @left.setter
    def left(self, left):
        self._left = left

    @property
    def right(self):
        return self._right

    @right.setter
    def right(self, right):
        self._right = right

    @property
    def split_on(self):
        return self._split_on

    @split_on.setter
    def split_on(self, s):
        self._split_on = s

    def print_partition(self):
        '''
        Print the coordinates of the rectangles that correspond to the
        partitioning represented by this tree.

        One line is printed per leaf, left subtree first.
        '''
        if self._split_on == "no_split":
            ((x0, y0), (x1, y1)) = self._R.coords
            print("(({}, {}), ({}, {})): {}".format(x0, y0, x1, y1, self._color))
            return
        self._left.print_partition()
        self._right.print_partition()

    def classify(self, x, y):
        '''
        Classify a point using this decision tree.

        Returns: the color of the leaf whose rectangle the point is
        routed to.
        '''
        if self._split_on == "no_split":
            return self._color
        ((x0, y0), (x1, y1)) = self._R.coords
        # Dyadic._compute_cost labels splits "split_x" / "split_y". The
        # previous check only recognised "x_split", which is never set, so
        # x-splits were routed as y-splits. "x_split" is still accepted for
        # backward compatibility.
        if self._split_on in ("split_x", "x_split"):
            go_left = x < (x0 + x1) / 2
        else:
            # For a y-split, left holds the bottom half and right the top.
            go_left = y < (y0 + y1) / 2
        # Strict '<': the midpoint belongs to the right/top child, matching
        # the half-open lo <= v < hi membership test used when the tree is
        # fitted (Dyadic._compute_error).
        child = self._left if go_left else self._right
        return child.classify(x, y)
class Dyadic:
    '''
    Wrapper class for computing and representing a dyadic partition.

    The constructor fits a decision tree over the unit square
    Rect(0, 0, 0, 0); classify() and print_partition() delegate to the
    root of that tree.
    '''

    def __init__(self, data, lmbda, max_level):
        '''
        data: sequence of (x, y, color) triples. It is scanned once per
              distinct rectangle, so it must be re-iterable (e.g. a list).
        lmbda: cost added for every leaf of the tree.
        max_level: maximum number of splits on any root-to-leaf path.

        Raises: ValueError if max_level is negative.
        '''
        if max_level < 0:
            raise ValueError("max_level must be non-negative")
        # Memo table keyed by (R.x_index, R.y_index). A dict stores only the
        # rectangles actually visited, instead of a dense
        # 2**(max_level+2) x 2**(max_level+2) table of None.
        self._memo = {}
        self._root = self._compute_cost(data, Rect(0, 0, 0, 0), lmbda, max_level)

    def _compute_error(self, data, R):
        '''
        Compute the major color and the number of points from data
        that are in R and have the minority color.

        Membership is half-open: x0 <= x < x1 and y0 <= y < y1. Any point
        whose color is not RED is counted as blue; a tie yields BLUE.

        Returns: (color, error)
        '''
        ((x0, y0), (x1, y1)) = R.coords
        red_cnt = 0
        blue_cnt = 0
        for (x, y, c) in data:
            if (x0 <= x < x1) and (y0 <= y < y1):
                if c == RED:
                    red_cnt += 1
                else:
                    blue_cnt += 1
        # return the color and the number of points that are
        # mis-classified
        if red_cnt > blue_cnt:
            return (RED, blue_cnt)
        return (BLUE, red_cnt)

    def _compute_cost(self, data, R, lmbda, level):
        '''
        Compute the optimal cost for this data and R with a leaf
        weighting factor of lmbda and the specified level (the number
        of further splits still allowed).

        Returns: a tree node.
        '''
        # check the memo table first. Keying on R alone is enough: each
        # split deepens R by one step and lowers level by one, so a given
        # rectangle is always reached with the same level.
        key = (R.x_index, R.y_index)
        cached = self._memo.get(key)
        if cached is not None:
            return cached

        # compute the cost of making R a leaf
        (color, error) = self._compute_error(data, R)
        tree_node = Tree_Node(R, color, error + lmbda)
        if (level == 0) or (error == 0):
            # No splits left, or splitting cannot lower the error.
            self._memo[key] = tree_node
            return tree_node

        # compute the cost of splitting left/right
        l = self._compute_cost(data, R.left, lmbda, level - 1)
        r = self._compute_cost(data, R.right, lmbda, level - 1)
        # update tree if Left/Right split is better
        cost = l.cost + r.cost
        if cost < tree_node.cost:
            tree_node.cost = cost
            tree_node.split_on = "split_x"
            tree_node.left = l
            tree_node.right = r

        # compute the cost of splitting top/bottom
        t = self._compute_cost(data, R.top, lmbda, level - 1)
        b = self._compute_cost(data, R.bottom, lmbda, level - 1)
        # update the tree if the Top/Bottom split is better
        cost = t.cost + b.cost
        if cost < tree_node.cost:
            tree_node.cost = cost
            tree_node.split_on = "split_y"
            # For a y-split, left holds the bottom half and right the top.
            tree_node.left = b
            tree_node.right = t

        # update the memo table
        self._memo[key] = tree_node
        return tree_node

    def classify(self, x, y):
        '''
        Classify point x, y.

        Returns: whatever the root node's classify() returns.
        '''
        # Previously referenced an undefined global `root` (NameError).
        return self._root.classify(x, y)

    def print_partition(self):
        '''
        Print the partitioning.
        '''
        self._root.print_partition()
def read_data(filename):
    '''
    Read the data from a csv file.

    The first row is treated as a header and discarded. Every remaining
    non-empty row is converted to a list of floats.

    Returns: a list of rows, each a list of floats.
    Raises: ValueError if a field cannot be converted by float().
    '''
    # newline='' lets the csv module do its own newline handling, as the
    # csv documentation requires for files passed to csv.reader.
    with open(filename, newline='') as f:
        reader = csv.reader(f)
        next(reader, None)  # drop the header row; harmless on an empty file
        # csv.reader yields [] for blank lines; skip them so that every
        # returned row has the fields the file actually supplied.
        return [[float(field) for field in row] for row in reader if row]
# Quick test: load the training set and fit a dyadic partition to it.
# NOTE(review): these statements run whenever the module is imported;
# consider moving them under `if __name__ == "__main__":`.
data = read_data("training.csv")
unit = Rect(0, 0, 0, 0)
dt = Dyadic(data, lmbda=14, max_level=10)