"""Provides some utilities widely used by other modules"""
from __future__ import annotations
import bisect
import collections
import collections.abc
import functools
import heapq
import os.path
import random
from itertools import chain, combinations
from statistics import mean
import numpy as np
# part1. General data structures and their functions
# ______________________________________________________________________________
# Queues: Stack, FIFOQueue, PriorityQueue
# Stack and FIFOQueue are implemented as list and collections.deque
# PriorityQueue is implemented here
[docs]
class PriorityQueue:
    """A queue that returns the item with minimum (or maximum) f(x) first.

    If order is 'min', the item with minimum f(x) is returned first; if
    order is 'max', it is the item with maximum f(x). Entries are kept in
    ``self.heap`` as ``(priority, item)`` tuples managed by ``heapq``.
    Also supports dict-like lookup and deletion by item.
    """

    def __init__(self, order='min', f=lambda x: x):
        """Create an empty queue.

        Raises ValueError if order is neither 'min' nor 'max'.
        """
        self.heap = []
        if order == 'min':
            self.f = f
        elif order == 'max':
            # heapq is a min-heap: negate the score so the max f(x) pops first.
            self.f = lambda x: -f(x)
        else:
            raise ValueError("Order must be either 'min' or 'max'.")

    def append(self, item):
        """Insert item at its correct position."""
        heapq.heappush(self.heap, (self.f(item), item))

    def extend(self, items):
        """Insert each item in items at its correct position."""
        for item in items:
            self.append(item)

    def pop(self):
        """Pop and return the item with min or max f(x), depending on the order.

        Raises IndexError (a subclass of Exception) if the queue is empty.
        """
        if not self.heap:
            raise IndexError('Trying to pop from empty PriorityQueue.')
        return heapq.heappop(self.heap)[1]

    def __len__(self):
        """Return the number of items currently in the PriorityQueue."""
        return len(self.heap)

    def __contains__(self, key):
        """Return True if the key is in PriorityQueue."""
        # Generator (not a list) so the scan stops at the first match.
        return any(item == key for _, item in self.heap)

    def __getitem__(self, key):
        """Return the stored priority of the first entry whose item equals key.

        For order='max' the stored priority is the negated f(x).
        Raises KeyError if key is not present.
        """
        for value, item in self.heap:
            if item == key:
                return value
        raise KeyError(str(key) + " is not in the priority queue")

    def __delitem__(self, key):
        """Delete the first occurrence of key; raise KeyError if it is absent."""
        for index, (_, item) in enumerate(self.heap):
            if item == key:
                break
        else:
            raise KeyError(str(key) + " is not in the priority queue")
        del self.heap[index]
        # Removing an arbitrary slot breaks the heap invariant; rebuild it.
        heapq.heapify(self.heap)
# ______________________________________________________________________________
# Functions on Sequences and Iterables
[docs]
def sequence(iterable):
    """Return the argument unchanged if it is a Sequence; otherwise wrap it in a 1-tuple."""
    if isinstance(iterable, collections.abc.Sequence):
        return iterable
    return (iterable,)
[docs]
def remove_all(item, seq):
    """Return a copy of seq (or string) with all occurrences of item removed.

    Strings and sets come back as the same kind of object; any other
    iterable comes back as a list. An item that does not occur is simply
    ignored, for every input type.
    """
    if isinstance(seq, str):
        return seq.replace(item, '')
    if isinstance(seq, set):
        rest = seq.copy()
        # discard (not remove): a missing item must not raise, matching
        # the str and list branches.
        rest.discard(item)
        return rest
    return [x for x in seq if x != item]
[docs]
def unique(seq):
    """Return a list of the distinct elements of seq (elements must be hashable).

    The order of the result is the set's iteration order, not the input order.
    """
    distinct = set(seq)
    return [*distinct]
[docs]
def count(seq):
    """Return how many items of seq are truthy."""
    return sum(1 for element in seq if element)
[docs]
def multimap(items):
    """Group (key, val) pairs into a plain dict {key: [val, ...], ...}.

    Values keep the order in which they appear in items.
    """
    grouped = {}
    for key, val in items:
        grouped.setdefault(key, []).append(val)
    return grouped
[docs]
def multimap_items(mmap):
    """Yield every (key, val) pair stored in the multimap, key by key."""
    for key, vals in mmap.items():
        yield from ((key, val) for val in vals)
[docs]
def product(numbers):
    """Multiply all the numbers together, e.g. product([2, 3, 10]) == 60.

    An empty input yields 1.
    """
    total = 1
    for factor in numbers:
        total *= factor
    return total
[docs]
def first(iterable, default=None):
    """Return the first element of iterable, or default when it is empty."""
    for element in iterable:
        return element
    return default
[docs]
def is_in(elt, seq):
    """Like (elt in seq), but membership is tested by identity ('is'), not equality."""
    for candidate in seq:
        if candidate is elt:
            return True
    return False
[docs]
def mode(data):
    """Return the most common data item. If there are ties, return any one of them.

    Raises ValueError if data is empty.
    """
    top = collections.Counter(data).most_common(1)
    if not top:
        # Explicit error instead of a cryptic tuple-unpacking failure.
        raise ValueError('mode() requires at least one data item')
    return top[0][0]
[docs]
def power_set(iterable):
    """power_set([1,2,3]) --> (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)

    Returns a list of all non-empty subsets (as tuples), smallest first.
    """
    pool = list(iterable)
    subsets = []
    for size in range(1, len(pool) + 1):
        subsets.extend(combinations(pool, size))
    return subsets
[docs]
def extend(s, var, val):
    """Return a new dict equal to s plus the binding var -> val; s is not modified."""
    extended = dict(s)
    extended[var] = val
    return extended
[docs]
def flatten(seqs):
    """Flatten a sequence of sequences into a single flat list (one level deep).

    Runs in time linear in the total number of elements; the inner
    sequences may be any iterables, not only lists.
    """
    # chain.from_iterable avoids the quadratic list copying of sum(seqs, []).
    return list(chain.from_iterable(seqs))
# ______________________________________________________________________________
# argmin and argmax
# Default key function for the argmin/argmax helpers below: returns its argument unchanged.
identity = lambda x: x
[docs]
def argmin_random_tie(seq, key=identity):
    """Return an element of seq with the lowest key(element); ties are broken at random."""
    # min() keeps the first minimum it meets, so shuffling first randomizes ties.
    candidates = shuffled(seq)
    return min(candidates, key=key)
[docs]
def argmax_random_tie(seq, key=identity):
    """Return an element of seq with the highest key(element); ties are broken at random."""
    # max() keeps the first maximum it meets, so shuffling first randomizes ties.
    candidates = shuffled(seq)
    return max(candidates, key=key)
[docs]
def shuffled(iterable):
    """Return a new, randomly ordered list of the items in iterable; the input is untouched."""
    deck = [*iterable]
    random.shuffle(deck)
    return deck
# part2. Mathematical and Statistical util functions
# ______________________________________________________________________________
[docs]
def histogram(values, mode=0, bin_function=None):
    """Summarize values as a list of (value, count) pairs.

    The pairs are sorted by increasing value, or, when mode is truthy,
    by decreasing count (ties by decreasing value). If bin_function is
    given, it is applied to every value before counting.
    """
    if bin_function:
        values = map(bin_function, values)
    tally = collections.Counter(values)
    if not mode:
        return sorted(tally.items())
    # Reversing each (value, count) pair gives the (count, value) sort key.
    return sorted(tally.items(), key=lambda pair: pair[::-1], reverse=True)
[docs]
def element_wise_product(x, y):
    """Multiply x and y element by element, recursing into nested iterables.

    Two non-iterables are multiplied directly. Two iterables must have
    equal length (checked with assert) and yield a list. Mixing an
    iterable with a non-iterable raises Exception.
    """
    x_is_iterable = hasattr(x, '__iter__')
    y_is_iterable = hasattr(y, '__iter__')
    if x_is_iterable != y_is_iterable:
        raise Exception('Inputs must be in the same size!')
    if not x_is_iterable:
        return x * y
    assert len(x) == len(y)
    return [element_wise_product(left, right) for left, right in zip(x, y)]
[docs]
def vector_add(a, b):
    """Add two vectors component by component, recursing into nested iterables.

    If either argument is falsy, the result is ``a or b``. Two iterables
    must have equal length (checked with assert) and yield a list.
    Otherwise ``a + b`` is returned; a TypeError from that addition is
    re-raised as Exception.
    """
    if not a or not b:
        return a or b
    both_iterable = hasattr(a, '__iter__') and hasattr(b, '__iter__')
    if both_iterable:
        assert len(a) == len(b)
        return [vector_add(left, right) for left, right in zip(a, b)]
    try:
        return a + b
    except TypeError:
        raise Exception('Inputs must be in the same size!')
[docs]
def scalar_vector_product(x, y):
    """Multiply every element of the (possibly nested) vector y by the scalar x.

    Iterables come back as lists; a non-iterable y yields x * y.
    """
    if not hasattr(y, '__iter__'):
        return x * y
    return [scalar_vector_product(x, component) for component in y]
[docs]
def map_vector(f, x):
    """Apply f to every leaf of the (possibly nested) iterable x.

    Iterables come back as lists; a non-iterable x yields f(x).
    """
    if not hasattr(x, '__iter__'):
        return f(x)
    return [map_vector(f, component) for component in x]
[docs]
def probability(p: float) -> bool:
    """Return True with probability p, by comparing p with one uniform draw from [0, 1]."""
    draw = random.uniform(0.0, 1.0)
    return draw < p
[docs]
def weighted_sample_with_replacement(n, seq, weights):
    """Draw n samples from seq at random, with replacement.

    Each element is picked with probability proportional to its
    corresponding entry in weights. Returns a list of length n.
    """
    draw = weighted_sampler(seq, weights)
    samples = []
    for _ in range(n):
        samples.append(draw())
    return samples
[docs]
def weighted_sampler(seq, weights):
    """Return a zero-argument function that picks from seq, weighted by weights.

    The cumulative weights are computed once; every call of the returned
    function draws one uniform number and locates it by bisection.
    """
    cumulative = []
    for weight in weights:
        if cumulative:
            weight = weight + cumulative[-1]
        cumulative.append(weight)

    def sample():
        threshold = random.uniform(0, cumulative[-1])
        return seq[bisect.bisect(cumulative, threshold)]

    return sample
[docs]
def weighted_choice(choices):
    """A weighted version of random.choice.

    choices is an iterable of (item, weight) pairs. Returns the chosen
    pair as a tuple (item, weight), or None when choices is empty.
    """
    # The pairs are traversed twice (once to total the weights, once to pick),
    # so materialize them: a one-shot iterator would otherwise be exhausted by
    # the sum and the function would silently return None.
    choices = list(choices)
    total = sum(w for _, w in choices)
    r = random.uniform(0, total)
    upto = 0
    for c, w in choices:
        if upto + w >= r:
            return c, w
        upto += w
    return None
[docs]
def rounder(numbers, d=4):
    """Round a single number, or every number in a (possibly nested) collection, to d decimals.

    A collection is rebuilt with its own type (list, set, tuple, ...).
    """
    if not isinstance(numbers, (int, float)):
        container = type(numbers)
        return container(rounder(entry, d) for entry in numbers)
    return round(numbers, d)
[docs]
def num_or_str(x: str) -> int | float | str:  # TODO: rename as `atom`
    """Convert the string x to an int if possible, else to a float, else return it stripped."""
    for convert in (int, float):
        try:
            return convert(x)
        except ValueError:
            continue
    return str(x).strip()
[docs]
def euclidean_distance(x, y) -> float:
    """Return the Euclidean (L2) distance between vectors x and y.

    Components are paired with zip, so extra components of the longer
    vector are ignored.
    """
    squared_total = sum((a - b) ** 2 for a, b in zip(x, y))
    return np.sqrt(squared_total)
[docs]
def manhattan_distance(x, y) -> float:
    """Return the Manhattan (L1) distance between vectors x and y.

    Components are paired with zip, so extra components of the longer
    vector are ignored.
    """
    gaps = (abs(a - b) for a, b in zip(x, y))
    return sum(gaps)
[docs]
def hamming_distance(x, y) -> int:
    """Return the number of paired positions at which vectors x and y differ."""
    mismatches = 0
    for a, b in zip(x, y):
        mismatches += (a != b)
    return mismatches
[docs]
def rms_error(x, y) -> float:
    """Return the root-mean-square error between vectors x and y (square root of ms_error)."""
    mean_square = ms_error(x, y)
    return np.sqrt(mean_square)
[docs]
def ms_error(x, y) -> float:
    """Return the mean of the squared differences between paired components of x and y."""
    squared_gaps = ((a - b) ** 2 for a, b in zip(x, y))
    return mean(squared_gaps)
[docs]
def mean_error(x, y) -> float:
    """Return the mean of the absolute differences between paired components of x and y."""
    absolute_gaps = (abs(a - b) for a, b in zip(x, y))
    return mean(absolute_gaps)
[docs]
def mean_boolean_error(x, y) -> float:
    """Return the fraction of paired positions at which vectors x and y differ."""
    mismatch_flags = (a != b for a, b in zip(x, y))
    return mean(mismatch_flags)
# part3. Neural network util functions
# ______________________________________________________________________________
[docs]
def cross_entropy_loss(x, y) -> float:
    """Cross entropy loss function. x and y are 1D iterable objects.

    For each pair the term x*log(y) + (1 - x)*log(1 - y) is summed,
    and the sum is scaled by -1 / len(x).
    """
    log_likelihood = sum(
        target * np.log(predicted) + (1 - target) * np.log(1 - predicted)
        for target, predicted in zip(x, y)
    )
    scale = -1.0 / len(x)
    return scale * log_likelihood
[docs]
def mean_squared_error_loss(x, y) -> float:
    """Mean squared error loss function. x and y are 1D iterable objects.

    The sum of squared paired differences is scaled by 1 / len(x).
    """
    squared_total = sum((target - predicted) ** 2 for target, predicted in zip(x, y))
    scale = 1.0 / len(x)
    return scale * squared_total
[docs]
def normalize(dist):
    """Scale the numbers so that they sum to 1.0.

    A dict is normalized in place (its values are overwritten) and the
    same dict is returned; every scaled value is asserted to lie in
    [0, 1]. Any other input yields a new list of scaled numbers.
    """
    if not isinstance(dist, dict):
        total = sum(dist)
        return [(number / total) for number in dist]
    total = sum(dist.values())
    for key, value in dist.items():
        dist[key] = value / total
        assert 0 <= dist[key] <= 1  # probabilities must be between 0 and 1
    return dist
[docs]
def random_weights(min_value: float, max_value: float, num_weights: int) -> list:
    """Return num_weights floats, each drawn with random.uniform(min_value, max_value)."""
    weights = []
    for _ in range(num_weights):
        weights.append(random.uniform(min_value, max_value))
    return weights
[docs]
def conv1D(x, k):
    """1D convolution of input vector x with kernel vector k.

    Delegates to np.convolve in 'same' mode.
    """
    convolved = np.convolve(x, k, 'same')
    return convolved
[docs]
def gaussian_kernel(size=3):
    """Return a list of `size` Gaussian values centred on the middle index (fixed st_dev 0.1)."""
    centre = (size - 1) / 2
    return [gaussian(centre, 0.1, position) for position in range(size)]
[docs]
def gaussian_kernel_1D(size=3, sigma=0.5):
    """Return a list of `size` Gaussian values centred on the middle index, with st_dev sigma."""
    centre = (size - 1) / 2
    return [gaussian(centre, sigma, position) for position in range(size)]
[docs]
def gaussian_kernel_2D(size=3, sigma=0.5):
    """Return a size x size 2D Gaussian kernel with st_dev sigma, normalized to sum to 1."""
    # Integer grid coordinates; for odd sizes they are centred on 0.
    low = -size // 2 + 1
    high = size // 2 + 1
    x, y = np.mgrid[low:high, low:high]
    weights = np.exp(-((x ** 2 + y ** 2) / (2.0 * sigma ** 2)))
    return weights / weights.sum()
[docs]
def step(x: float) -> int:
    """Step activation: return 1 when x is non-negative, otherwise 0."""
    if x >= 0:
        return 1
    return 0
[docs]
def gaussian(mean: float, st_dev: float, x: float) -> float:
    """Return the value at x of the normal density with the given mean and standard deviation."""
    coefficient = 1 / (np.sqrt(2 * np.pi) * st_dev)
    z = float(x - mean) / st_dev
    return coefficient * np.exp(-0.5 * z ** 2)
[docs]
def linear_kernel(x, y=None):
    """Return the linear kernel np.dot(x, y.T); y defaults to x.

    y (or x, when y is omitted) must provide a ``.T`` attribute.
    """
    other = x if y is None else y
    return np.dot(x, other.T)
[docs]
def polynomial_kernel(x, y=None, degree=2.0):
    """Return the polynomial kernel (1 + np.dot(x, y.T)) ** degree; y defaults to x.

    y (or x, when y is omitted) must provide a ``.T`` attribute.
    """
    other = x if y is None else y
    base = 1.0 + np.dot(x, other.T)
    return base ** degree
[docs]
def rbf_kernel(x, y=None, gamma=None):
    """Radial-basis function kernel (aka squared-exponential kernel).

    y defaults to x, and gamma defaults to 1.0 / x.shape[1]. The squared
    distance between every row of x and every row of y is expanded as
    -2*x.y + |x|^2 + |y|^2 and passed through exp(-gamma * ...).
    """
    if y is None:
        y = x
    if gamma is None:
        gamma = 1.0 / x.shape[1]  # 1.0 / n_features
    cross_term = -2.0 * np.dot(x, y.T)
    x_norms = np.sum(x * x, axis=1).reshape((-1, 1))
    y_norms = np.sum(y * y, axis=1).reshape((1, -1))
    return np.exp(-gamma * (cross_term + x_norms + y_norms))
# part4. Self defined data structures
# ______________________________________________________________________________
# Grid Functions
# The four unit headings as (dx, dy) offsets, in the order EAST, NORTH, WEST, SOUTH.
orientations = EAST, NORTH, WEST, SOUTH = [(1, 0), (0, 1), (-1, 0), (0, -1)]
# Index increments for turn_heading: LEFT (+1) steps forward through orientations, RIGHT (-1) steps backward.
turns = LEFT, RIGHT = (+1, -1)
[docs]
def turn_heading(heading, inc, headings=orientations):
    """Return the heading found inc positions after heading in headings, wrapping around.

    heading must be an element of headings (list.index raises ValueError otherwise).
    """
    position = headings.index(heading)
    wrapped = (position + inc) % len(headings)
    return headings[wrapped]
[docs]
def turn_right(heading):
    """Return the heading obtained by turning right (clockwise) from heading."""
    return turn_heading(heading, inc=RIGHT)
[docs]
def turn_left(heading):
    """Return the heading obtained by turning left (counter-clockwise) from heading."""
    return turn_heading(heading, inc=LEFT)
[docs]
def distance(a, b):
    """Return the straight-line distance between two (x, y) points."""
    (ax, ay), (bx, by) = a, b
    return np.hypot((ax - bx), (ay - by))
[docs]
def distance_squared(a, b):
    """Return the square of the straight-line distance between two (x, y) points."""
    (ax, ay), (bx, by) = a, b
    dx = ax - bx
    dy = ay - by
    return dx ** 2 + dy ** 2
# ______________________________________________________________________________
# Misc Functions
[docs]
class injection:
    """Dependency injection of temporary values for global functions/classes/etc.

    Used as a context manager: on entry the named module globals are
    replaced by the supplied values; on exit the previous values are
    restored. Every name must already exist among this module's globals.
    E.g., `with injection(DataBase=MockDataBase): ...`
    """

    def __init__(self, **kwds):
        self.new = kwds

    def __enter__(self):
        module_globals = globals()
        # Remember the current bindings before overwriting them.
        saved = {}
        for name in self.new:
            saved[name] = module_globals[name]
        self.old = saved
        module_globals.update(self.new)

    def __exit__(self, type, value, traceback):
        module_globals = globals()
        module_globals.update(self.old)
[docs]
def memoize(fn, slot=None, maxsize=32):
    """Memoize fn: make it remember the computed value for any argument list.

    If slot is given, the result is cached as attribute `slot` on the
    first argument, so later calls return that attribute regardless of
    the remaining arguments. If slot is falsy, functools.lru_cache with
    the given maxsize caches results keyed on the positional arguments.
    """
    if not slot:
        @functools.lru_cache(maxsize=maxsize)
        def memoized_fn(*args):
            return fn(*args)

        return memoized_fn

    def memoized_fn(obj, *args):
        if hasattr(obj, slot):
            return getattr(obj, slot)
        result = fn(obj, *args)
        setattr(obj, slot, result)
        return result

    return memoized_fn
[docs]
def name(obj):
    """Try to find some reasonable name for the object.

    Candidates are tried in order and the first truthy one wins:
    obj.name, obj.__name__, obj.__class__.__name__, then str(obj).
    """
    own_name = getattr(obj, 'name', 0)
    if own_name:
        return own_name
    dunder_name = getattr(obj, '__name__', 0)
    if dunder_name:
        return dunder_name
    class_name = getattr(getattr(obj, '__class__', 0), '__name__', 0)
    if class_name:
        return class_name
    return str(obj)
[docs]
def isnumber(x):
    """Is x a number? Decided by whether x has an ``__int__`` attribute."""
    supports_int = hasattr(x, '__int__')
    return supports_int
[docs]
def issequence(x):
    """Is x a sequence? True when x is an instance of collections.abc.Sequence."""
    sequence_type = collections.abc.Sequence
    return isinstance(x, sequence_type)
[docs]
def print_table(table, header=None, sep='   ', numfmt='{}'):
    """Print a list of lists as a table, so that columns line up nicely.

    header, if specified, will be printed as the first row.
    numfmt is the format for all numbers; you might want e.g. '{:.2f}'.
    (If you want different formats in different columns,
    don't use print_table.) sep is the separator between columns.

    The caller's table is never modified. Column justification is taken
    from the first data row (numbers right-justified, everything else
    left-justified). An empty table prints only the header, if any.
    """
    # Work on a copy: inserting the header must not mutate the argument.
    rows = list(table)
    if not rows and not header:
        return
    # Justification follows the first data row; with no data rows, fall
    # back to the header so that it can still be printed.
    sample_row = rows[0] if rows else header
    justs = ['rjust' if isnumber(x) else 'ljust' for x in sample_row]
    if header:
        rows.insert(0, header)
    rows = [[numfmt.format(x) if isnumber(x) else x for x in row]
            for row in rows]
    # Width of each column is the longest rendered cell in that column.
    sizes = [max(map(len, column))
             for column in zip(*[map(str, row) for row in rows])]
    for row in rows:
        print(sep.join(getattr(str(x), j)(size)
                       for (j, size, x) in zip(justs, sizes, row)))
[docs]
def open_data(name, mode='r'):
    """Open and return the file called name inside the 'aima-data' directory.

    The directory is looked up next to this module's own file. The
    caller is responsible for closing the returned file object.
    """
    module_dir = os.path.dirname(__file__)
    data_path = os.path.join(module_dir, 'aima-data', name)
    return open(data_path, mode=mode)
[docs]
def failure_test(algorithm, tests):
    """Grade the given algorithm by the fraction of tests it passes.

    Most algorithms have arbitrary output on correct execution, which is
    difficult to check for correctness, whereas many output something
    particular on failure (for example, False, or None).
    tests is a list with each element in the form (values, failure_output);
    a test counts as passed when algorithm(values) != failure_output.
    Returns the mean of the per-test 0/1 outcomes.
    """
    outcomes = [int(algorithm(values) != failure_output)
                for values, failure_output in tests]
    return mean(outcomes)
# ______________________________________________________________________________
# Expressions
# See https://docs.python.org/3/reference/expressions.html#operator-precedence
# See https://docs.python.org/3/reference/datamodel.html#special-method-names
[docs]
class Expr:
    """A mathematical expression: an operator applied to 0 or more arguments.

    op is stored as a str like '+' or 'sin'; args is the tuple of operands.
    Expr('x') or Symbol('x') creates a symbol (a nullary Expr).
    Expr('-', x) creates a unary; Expr('+', x, 1) creates a binary.
    The Python operators are overloaded so that combining an Expr with
    another operand builds a new Expr rather than computing a value.
    """

    def __init__(self, op, *args):
        self.op = str(op)
        self.args = args

    # Unary operators

    def __neg__(self):
        return Expr('-', self)

    def __pos__(self):
        return Expr('+', self)

    def __invert__(self):
        return Expr('~', self)

    # Binary operators with this Expr on the left

    def __add__(self, rhs):
        return Expr('+', self, rhs)

    def __sub__(self, rhs):
        return Expr('-', self, rhs)

    def __mul__(self, rhs):
        return Expr('*', self, rhs)

    def __pow__(self, rhs):
        return Expr('**', self, rhs)

    def __mod__(self, rhs):
        return Expr('%', self, rhs)

    def __and__(self, rhs):
        return Expr('&', self, rhs)

    def __xor__(self, rhs):
        return Expr('^', self, rhs)

    def __rshift__(self, rhs):
        return Expr('>>', self, rhs)

    def __lshift__(self, rhs):
        return Expr('<<', self, rhs)

    def __truediv__(self, rhs):
        return Expr('/', self, rhs)

    def __floordiv__(self, rhs):
        return Expr('//', self, rhs)

    def __matmul__(self, rhs):
        return Expr('@', self, rhs)

    def __or__(self, rhs):
        """Allow both P | Q, and P |'==>'| Q."""
        if not isinstance(rhs, Expression):
            # rhs is an operator token such as '==>': hold on to this left
            # operand until the following '|' supplies the right operand.
            return PartialExpr(rhs, self)
        return Expr('|', self, rhs)

    # Reflected operators: this Expr is the right-hand operand

    def __radd__(self, lhs):
        return Expr('+', lhs, self)

    def __rsub__(self, lhs):
        return Expr('-', lhs, self)

    def __rmul__(self, lhs):
        return Expr('*', lhs, self)

    def __rdiv__(self, lhs):
        return Expr('/', lhs, self)

    def __rpow__(self, lhs):
        return Expr('**', lhs, self)

    def __rmod__(self, lhs):
        return Expr('%', lhs, self)

    def __rand__(self, lhs):
        return Expr('&', lhs, self)

    def __rxor__(self, lhs):
        return Expr('^', lhs, self)

    def __ror__(self, lhs):
        return Expr('|', lhs, self)

    def __rrshift__(self, lhs):
        return Expr('>>', lhs, self)

    def __rlshift__(self, lhs):
        return Expr('<<', lhs, self)

    def __rtruediv__(self, lhs):
        return Expr('/', lhs, self)

    def __rfloordiv__(self, lhs):
        return Expr('//', lhs, self)

    def __rmatmul__(self, lhs):
        return Expr('@', lhs, self)

    def __call__(self, *args):
        """Call: if 'f' is a Symbol, then f(0) == Expr('f', 0)."""
        if not self.args:
            return Expr(self.op, *args)
        raise ValueError('Can only do a call for a Symbol, not an Expr')

    # Comparison, hashing and display

    def __eq__(self, other):
        """'x == y' evaluates to True or False; does not build an Expr."""
        if not isinstance(other, Expr):
            return False
        return self.op == other.op and self.args == other.args

    def __lt__(self, other):
        """Order Exprs by their string form; False when other is not an Expr."""
        if not isinstance(other, Expr):
            return False
        return str(self) < str(other)

    def __hash__(self):
        return hash(self.op) ^ hash(self.args)

    def __repr__(self):
        op = self.op
        operands = [str(arg) for arg in self.args]
        if op.isidentifier():  # f(x) or f(x, y)
            if not operands:
                return op
            return f"{op}({', '.join(operands)})"
        if len(operands) == 1:  # -x or -(x + 1)
            return op + operands[0]
        # (x - y)
        spaced_op = ' ' + op + ' '
        return '(' + spaced_op.join(operands) + ')'
# An 'Expression' is either an Expr or a Number.
# Symbol is not an explicit type; it is any Expr with 0 args.
# Both names are (nested) tuples of classes, intended for use with isinstance().
Number = (int, float, complex)
Expression = (Expr, Number)
[docs]
def Symbol(name):
    """Build a Symbol, i.e. an Expr whose op is name and which has no args."""
    symbol = Expr(name)
    return symbol
[docs]
def symbols(names):
    """Return a tuple of Symbols; names is a comma/whitespace delimited str."""
    # Turn commas into spaces so that a single split() handles both delimiters.
    tokens = names.replace(',', ' ').split()
    return tuple(map(Symbol, tokens))
[docs]
def subexpressions(x):
    """Yield the subexpressions of an Expression (including x itself).

    Expressions are produced in pre-order: each one before its
    arguments, and the arguments from left to right.
    """
    pending = [x]
    while pending:
        current = pending.pop()
        yield current
        if isinstance(current, Expr):
            # Push in reverse so the leftmost argument is popped first.
            pending.extend(reversed(current.args))
[docs]
def arity(expression):
    """Return the number of arguments of an Expr; anything else (e.g. a number) has arity 0."""
    if not isinstance(expression, Expr):
        return 0
    return len(expression.args)
# For operators that are not defined in Python, we allow new InfixOps:
[docs]
class PartialExpr:
    r"""Given 'P \|'==>'\| Q, first form PartialExpr('==>', P), then combine with Q.

    Holds the operator and the left operand until the second '|'
    supplies the right operand.
    """

    def __init__(self, op, lhs):
        self.op = op
        self.lhs = lhs

    def __or__(self, rhs):
        """Complete the expression: build Expr(op, lhs, rhs)."""
        return Expr(self.op, self.lhs, rhs)

    def __repr__(self):
        return f"PartialExpr('{self.op}', {self.lhs})"
[docs]
def expr(x):
    r"""Shortcut to create an Expression. x is a str in which:
    - identifiers are automatically defined as Symbols.
    - ==> is treated as an infix \|'==>'\|, as are <== and <=>.
    If x is already an Expression, it is returned unchanged. Example:
    >>> expr('P & Q ==> Q')
    ((P & Q) ==> Q)
    """
    if not isinstance(x, str):
        return x
    source = expr_handle_infix_ops(x)
    # Unknown identifiers are looked up in this dict, which creates a Symbol on demand.
    namespace = defaultkeydict(Symbol)
    # NOTE(review): eval runs x as Python code; only pass trusted strings.
    return eval(source, namespace)
# Infix operators that Python lacks; expr_handle_infix_ops rewrites each one as |'op'|.
infix_ops = '==> <== <=>'.split()
[docs]
def expr_handle_infix_ops(x):
    r"""Given a str, return a new str with ==> replaced by \|'==>'\|, etc.
    >>> expr_handle_infix_ops('P ==> Q')
    "P |'==>'| Q"
    """
    rewritten = x
    for infix in infix_ops:
        # repr() adds the quotes, producing e.g. |'==>'|
        rewritten = rewritten.replace(infix, f'|{infix!r}|')
    return rewritten
[docs]
class defaultkeydict(collections.defaultdict):
    """Like defaultdict, but the default_factory is a function of the key.
    >>> d = defaultkeydict(len); d['four']
    4
    """

    def __missing__(self, key):
        """Compute the default from the key itself, store it and return it."""
        value = self.default_factory(key)
        self[key] = value
        return value
[docs]
class hashabledict(dict):
    """A dict that can be hashed, e.g. used as a dict key or set member.

    Every instance hashes to the same constant, so the hash never changes
    when the contents do; equality is still ordinary dict equality.
    NOTE(review): a constant hash makes all instances collide in hash
    tables, so lookups among many hashabledicts fall back to equality tests.
    """

    def __hash__(self):
        constant_hash = 1
        return constant_hash
# ______________________________________________________________________________
# Monte Carlo tree node and ucb function
[docs]
class MCT_Node:
    """Node in the Monte Carlo search tree, keeps track of the children states."""

    def __init__(self, parent=None, state=None, U=0, N=0):
        self.parent = parent
        self.state = state
        self.U = U
        self.N = N
        self.children = {}
        self.actions = None
[docs]
def ucb(n, C=1.4):
    """Return the UCB1 score of node n.

    The score is the mean utility n.U / n.N plus the exploration term
    C * sqrt(log(n.parent.N) / n.N). A node with n.N == 0 scores
    infinity, so unvisited nodes are preferred during selection.
    """
    if n.N == 0:
        return np.inf
    exploitation = n.U / n.N
    exploration = np.sqrt(np.log(n.parent.N) / n.N)
    return exploitation + C * exploration
# ______________________________________________________________________________
# Useful Shorthands
[docs]
class Bool(int):
    """Just like `bool`, except values display as 'T' and 'F' instead of 'True' and 'False'."""

    def __repr__(self):
        return 'T' if self else 'F'

    __str__ = __repr__
# Module-level shorthands for the two Bool values; they display as 'T' and 'F'.
T = Bool(True)
F = Bool(False)