# Source code for utils4e

"""Provides some utilities widely used by other modules"""

from __future__ import annotations

import bisect
import collections
import collections.abc
import functools
import heapq
import os.path
import random
from itertools import chain, combinations
from statistics import mean

import numpy as np


# part1. General data structures and their functions
# ______________________________________________________________________________
# Queues: Stack, FIFOQueue, PriorityQueue
# Stack and FIFOQueue are implemented as list and collections.deque
# PriorityQueue is implemented here


class PriorityQueue:
    """A Queue in which the minimum (or maximum) element (as determined by f and
    order) is returned first.
    If order is 'min', the item with minimum f(x) is
    returned first; if order is 'max', then it is the item with maximum f(x).
    Also supports dict-like lookup.

    Entries are kept in ``self.heap`` as ``(key, item)`` pairs, where ``key`` is
    ``f(item)`` for 'min' order and ``-f(item)`` for 'max' order. When two keys
    tie, heapq falls back to comparing the items themselves, so items with
    equal keys must be mutually orderable.
    """

    def __init__(self, order='min', f=lambda x: x):
        self.heap = []
        if order == 'min':
            self.f = f
        elif order == 'max':  # now item with max f(x)
            self.f = lambda x: -f(x)  # will be popped first
        else:
            raise ValueError("Order must be either 'min' or 'max'.")

    def append(self, item):
        """Insert item at its correct position."""
        heapq.heappush(self.heap, (self.f(item), item))

    def extend(self, items):
        """Insert each item in items at its correct position."""
        for item in items:
            self.append(item)

    def pop(self):
        """Pop and return the item (with min or max f(x) value)
        depending on the order.

        Raises:
            IndexError: if the queue is empty (IndexError is an Exception
                subclass, so existing ``except Exception`` handlers still work).
        """
        if not self.heap:
            raise IndexError('Trying to pop from empty PriorityQueue.')
        return heapq.heappop(self.heap)[1]

    def __len__(self):
        """Return the number of items currently in the PriorityQueue."""
        return len(self.heap)

    def __contains__(self, key):
        """Return True if the key is in PriorityQueue."""
        # Generator (not a list) so the scan stops at the first match.
        return any(item == key for _, item in self.heap)

    def __getitem__(self, key):
        """Return the stored priority key of the first entry equal to key.

        The stored key is f(key) for 'min' order and -f(key) for 'max' order.
        Raises KeyError if key is not present.
        """
        for value, item in self.heap:
            if item == key:
                return value
        raise KeyError(str(key) + " is not in the priority queue")

    def __delitem__(self, key):
        """Delete the first occurrence (in heap-list order) of key.

        Raises KeyError if key is not present.
        """
        for index, (_, item) in enumerate(self.heap):
            if item == key:
                break
        else:
            raise KeyError(str(key) + " is not in the priority queue")
        del self.heap[index]
        # Removing an arbitrary slot breaks the heap invariant; restore it.
        heapq.heapify(self.heap)
# ______________________________________________________________________________
# Functions on Sequences and Iterables
def sequence(iterable):
    """Return iterable unchanged if it is already a Sequence; otherwise wrap it.

    A non-Sequence argument (a set, a generator, a scalar, ...) is NOT
    expanded: it becomes the single element of a 1-tuple.
    """
    if isinstance(iterable, collections.abc.Sequence):
        return iterable
    return (iterable,)
def remove_all(item, seq):
    """Return a copy of seq (or string) with all occurrences of item removed.

    - str: every occurrence of the substring ``item`` is removed.
    - set: a copy without ``item``. A missing item is ignored, consistent
      with the str and list branches, instead of raising KeyError.
    - anything else: a new list of the elements that are != item.
    The input is never modified.
    """
    if isinstance(seq, str):
        return seq.replace(item, '')
    if isinstance(seq, set):
        rest = seq.copy()
        rest.discard(item)  # discard, not remove: absent item is a no-op
        return rest
    return [x for x in seq if x != item]
def unique(seq):
    """Remove duplicate elements from seq. Assumes hashable elements.

    Returns a list that keeps the first occurrence of each element, in the
    order of first appearance. dict preserves insertion order, so the result
    is deterministic, unlike list(set(seq)), whose order can change between
    runs (e.g. for strings under hash randomization).
    """
    return list(dict.fromkeys(seq))
def count(seq):
    """Return how many items of seq are truthy."""
    return sum(1 for element in seq if element)
def multimap(items):
    """Group (key, val) pairs into a plain dict {key: [val, ...], ...}.

    Values keep the order in which they appear in items.
    """
    grouped = {}
    for key, val in items:
        grouped.setdefault(key, []).append(val)
    return grouped
def multimap_items(mmap):
    """Lazily yield every (key, val) pair stored in the multimap mmap."""
    for key, vals in mmap.items():
        yield from ((key, val) for val in vals)
def product(numbers):
    """Return the product of the numbers, e.g. product([2, 3, 10]) == 60.

    An empty input yields 1 (the multiplicative identity).
    """
    return functools.reduce(lambda acc, n: acc * n, numbers, 1)
def first(iterable, default=None):
    """Return the first element of iterable, or default if it is empty."""
    for element in iterable:
        return element
    return default
def is_in(elt, seq):
    """Like (elt in seq), but test identity with 'is' instead of '=='."""
    for candidate in seq:
        if candidate is elt:
            return True
    return False
def mode(data):
    """Return the most common data item.

    If there are ties, the tied item encountered first in data is returned
    (Counter.most_common orders equal counts by first occurrence).

    Raises:
        ValueError: if data is empty (previously an obscure unpacking error).
    """
    most_common = collections.Counter(data).most_common(1)
    if not most_common:
        raise ValueError('mode() of empty data')
    item, _ = most_common[0]
    return item
def power_set(iterable):
    """Return every non-empty subset of iterable as a list of tuples.

    Subsets are ordered by size, then in combinations() order:
    power_set([1,2,3]) --> (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
    """
    s = list(iterable)
    subsets = []
    for size in range(1, len(s) + 1):
        subsets.extend(combinations(s, size))
    return subsets
def extend(s, var, val):
    """Return a shallow copy of dict s in which var is bound to val; s is untouched."""
    updated = {**s}
    updated[var] = val
    return updated
def flatten(seqs):
    """Flatten a sequence of sequences into a single flat list.

    Only one level is flattened. Uses itertools.chain.from_iterable rather
    than sum(seqs, []), which is quadratic in the total length and only
    accepts lists; any iterables (lists, tuples, ranges, ...) work here.
    """
    return list(chain.from_iterable(seqs))
# ______________________________________________________________________________
# argmin and argmax


def identity(x):
    """Return x unchanged; used as the default key for the argmin/argmax helpers."""
    return x
def argmin_random_tie(seq, key=identity):
    """Return an element of seq with the smallest key(element); ties broken at random.

    seq is shuffled first, so min() (which returns the first minimal element
    it meets) picks uniformly among tied elements.
    """
    candidates = shuffled(seq)
    return min(candidates, key=key)
def argmax_random_tie(seq, key=identity):
    """Return an element of seq with the largest key(element); ties broken at random.

    seq is shuffled first, so max() (which returns the first maximal element
    it meets) picks uniformly among tied elements.
    """
    candidates = shuffled(seq)
    return max(candidates, key=key)
def shuffled(iterable):
    """Return a new list holding the elements of iterable in random order."""
    result = [*iterable]
    random.shuffle(result)
    return result
# part2. Mathematical and Statistical util functions
# ______________________________________________________________________________
def histogram(values, mode=0, bin_function=None):
    """Return a list of (value, count) pairs, summarizing the input values.

    Sorted by increasing value, or if mode is truthy, by decreasing count
    (ties broken by decreasing value).
    If bin_function is given, map it over values first.
    """
    if bin_function:
        values = map(bin_function, values)
    counts = collections.defaultdict(int)
    for value in values:
        counts[value] += 1
    if not mode:
        return sorted(counts.items())
    return sorted(counts.items(), key=lambda pair: (pair[1], pair[0]), reverse=True)
def element_wise_product(x, y):
    """Return the element-wise product of x and y, recursing into nested iterables.

    Scalars are multiplied directly; iterables must have matching lengths.
    Mixing an iterable with a scalar raises Exception. Strings are iterable
    (each character is itself a string), so passing them recurses without end.
    """
    x_iterable = hasattr(x, '__iter__')
    y_iterable = hasattr(y, '__iter__')
    if x_iterable != y_iterable:
        raise Exception('Inputs must be in the same size!')
    if not x_iterable:
        return x * y
    assert len(x) == len(y)
    return [element_wise_product(a, b) for a, b in zip(x, y)]
def vector_add(a, b):
    """Component-wise addition of two (possibly nested) vectors.

    If either argument is falsy (0, an empty list, ...), the other one is
    returned unchanged (a or b). Incompatible operands raise Exception.
    """
    if not a or not b:
        return a or b
    if not (hasattr(a, '__iter__') and hasattr(b, '__iter__')):
        try:
            return a + b
        except TypeError:
            raise Exception('Inputs must be in the same size!')
    assert len(a) == len(b)
    return [vector_add(left, right) for left, right in zip(a, b)]
def scalar_vector_product(x, y):
    """Multiply the scalar x into every (possibly nested) element of y."""
    if not hasattr(y, '__iter__'):
        return x * y
    return [scalar_vector_product(x, component) for component in y]
def map_vector(f, x):
    """Apply f to every leaf of the (possibly nested) iterable x.

    A non-iterable x is treated as a single leaf and f(x) is returned.
    """
    if hasattr(x, '__iter__'):
        return [map_vector(f, element) for element in x]
    return f(x)
def probability(p: float) -> bool:
    """Return True with probability p, otherwise False.

    p <= 0 never yields True; p >= 1 always does.
    """
    draw = random.uniform(0.0, 1.0)
    return draw < p
def weighted_sample_with_replacement(n, seq, weights):
    """Return a list of n items drawn from seq with replacement.

    Each draw picks an element with probability proportional to its
    corresponding weight. The sampler is built once, via weighted_sampler.
    """
    draw = weighted_sampler(seq, weights)
    return [draw() for _ in range(n)]
def weighted_sampler(seq, weights):
    """Return a random-sample function that picks from seq weighted by weights.

    Each call draws u = random.uniform(0, total) and returns the element whose
    cumulative-weight interval contains u. random.uniform may return its upper
    bound exactly. The search is therefore capped at the first element whose
    cumulative weight reaches the total, instead of running past the end
    (IndexError). Assumes non-negative weights, so that the cumulative totals
    are non-decreasing; this is not validated.
    """
    totals = []
    for w in weights:
        totals.append(w + totals[-1] if totals else w)
    # Highest index bisect may return. Elements after it all have zero weight.
    last = bisect.bisect_left(totals, totals[-1]) if totals else 0
    return lambda: seq[bisect.bisect(totals, random.uniform(0, totals[-1]), 0, last)]
def weighted_choice(choices):
    """A weighted version of random.choice.

    choices is a re-iterable collection of (choice, weight) pairs; it is
    iterated twice. Returns the selected (choice, weight) pair, or None when
    no pair is selected (e.g. choices is empty).
    """
    # NOTE: random.choices (Python 3.6+) covers most of this, but it returns
    # only the choice, not the (choice, weight) pair.
    total = sum(weight for _, weight in choices)
    threshold = random.uniform(0, total)
    cumulative = 0
    for option, weight in choices:
        if cumulative + weight >= threshold:
            return option, weight
        cumulative += weight
    return None
def rounder(numbers, d=4):
    """Round a single number, or a (possibly nested) collection of numbers, to d decimal places.

    Collections are rebuilt with their own type (list, set, tuple, ...).
    """
    if not isinstance(numbers, (int, float)):
        container_type = type(numbers)
        return container_type(rounder(value, d) for value in numbers)
    return round(numbers, d)
def num_or_str(x: str) -> int | float | str:  # TODO: rename as `atom`
    """Convert the string x to an int, else a float; if neither parses, return it stripped."""
    for convert in (int, float):
        try:
            return convert(x)
        except ValueError:
            pass
    return str(x).strip()
def euclidean_distance(x, y) -> float:
    """Return the Euclidean (L2) distance between vectors x and y."""
    squared_sum = sum((a - b) ** 2 for a, b in zip(x, y))
    return np.sqrt(squared_sum)
def manhattan_distance(x, y) -> float:
    """Return the Manhattan (L1) distance: the sum of |x_i - y_i| over paired elements."""
    pairs = zip(x, y)
    return sum(abs(a - b) for a, b in pairs)
def hamming_distance(x, y) -> int:
    """Return how many paired positions of x and y hold different values."""
    pairs = zip(x, y)
    return sum(a != b for a, b in pairs)
def rms_error(x, y) -> float:
    """Return the root-mean-square error: the square root of ms_error(x, y)."""
    mse = ms_error(x, y)
    return np.sqrt(mse)
def ms_error(x, y) -> float:
    """Return the mean squared error: the mean of (x_i - y_i) ** 2 over paired elements."""
    return mean((a - b) ** 2 for a, b in zip(x, y))
def mean_error(x, y) -> float:
    """Return the mean absolute error: the mean of |x_i - y_i| over paired elements."""
    return mean(abs(a - b) for a, b in zip(x, y))
def mean_boolean_error(x, y) -> float:
    """Return the fraction of paired positions at which x and y differ."""
    mismatches = (a != b for a, b in zip(x, y))
    return mean(mismatches)
# part3. Neural network util functions
# ______________________________________________________________________________
def cross_entropy_loss(x, y) -> float:
    """Binary cross-entropy loss averaged over len(x) elements.

    x and y are 1D iterables paired element-wise. Presumably x holds the
    targets and y the predicted probabilities (TODO confirm with callers).
    y values must lie strictly inside (0, 1) for the logarithms to be finite.
    """
    n = len(x)
    total = sum(_x * np.log(_y) + (1 - _x) * np.log(1 - _y) for _x, _y in zip(x, y))
    return (-1.0 / n) * total
def mean_squared_error_loss(x, y) -> float:
    """Mean squared error loss: sum of (x_i - y_i) ** 2 divided by len(x). x and y are 1D iterables."""
    n = len(x)
    squared = sum((a - b) ** 2 for a, b in zip(x, y))
    return (1.0 / n) * squared
def normalize(dist):
    """Scale the numbers in dist so that they sum to 1.0.

    A dict is normalized IN PLACE and returned (each resulting value is
    asserted to lie in [0, 1]). Any other iterable yields a new list.
    A zero total raises ZeroDivisionError for plain Python numbers.
    """
    if not isinstance(dist, dict):
        total = sum(dist)
        return [(n / total) for n in dist]
    total = sum(dist.values())
    for key in dist:
        dist[key] = dist[key] / total
        assert 0 <= dist[key] <= 1  # probabilities must be between 0 and 1
    return dist
def random_weights(min_value: float, max_value: float, num_weights: int) -> list:
    """Return num_weights floats drawn independently via random.uniform(min_value, max_value)."""
    weights = []
    for _ in range(num_weights):
        weights.append(random.uniform(min_value, max_value))
    return weights
def conv1D(x, k):
    """1D convolution of input vector x with kernel vector k.

    Uses numpy's 'same' mode, so the output has length max(len(x), len(k)).
    """
    return np.convolve(x, k, 'same')
def gaussian_kernel(size=3):
    """Return a length-size 1D Gaussian kernel centred at the middle, with a fixed st_dev of 0.1.

    Equivalent to gaussian_kernel_1D(size, sigma=0.1). The values are not normalized.
    """
    return gaussian_kernel_1D(size, sigma=0.1)
def gaussian_kernel_1D(size=3, sigma=0.5):
    """Return a length-size 1D Gaussian kernel centred at index (size - 1) / 2 with st_dev sigma.

    Each entry is the Gaussian density at its index; the values are not normalized.
    """
    centre = (size - 1) / 2
    return [gaussian(centre, sigma, i) for i in range(size)]
def gaussian_kernel_2D(size=3, sigma=0.5):
    """Return a size x size 2D Gaussian kernel with st_dev sigma, normalized to sum to 1.

    Grid offsets run from (-size) // 2 + 1 to size // 2. They are symmetric
    about 0 for odd size; for even size they are shifted by one
    (e.g. size 4 -> -1..2).
    """
    lo = -size // 2 + 1  # parses as ((-size) // 2) + 1
    hi = size // 2 + 1
    x, y = np.mgrid[lo:hi, lo:hi]
    weights = np.exp(-((x ** 2 + y ** 2) / (2.0 * sigma ** 2)))
    return weights / weights.sum()
def step(x: float) -> int:
    """Heaviside step activation: 1 when x >= 0, else 0."""
    return int(x >= 0)
def gaussian(mean: float, st_dev: float, x: float) -> float:
    """Return the probability density at x of a normal distribution with the given mean and st_dev.

    Note: the parameter name `mean` shadows the module-level statistics.mean
    inside this function only.
    """
    coefficient = 1 / (np.sqrt(2 * np.pi) * st_dev)
    z = float(x - mean) / st_dev
    return coefficient * np.exp(-0.5 * z ** 2)
def linear_kernel(x, y=None):
    """Return the linear kernel np.dot(x, y.T); y defaults to x (an array with a .T attribute)."""
    other = x if y is None else y
    return np.dot(x, other.T)
def polynomial_kernel(x, y=None, degree=2.0):
    """Return the polynomial kernel (1 + x . y.T) ** degree; y defaults to x."""
    other = x if y is None else y
    return (1.0 + np.dot(x, other.T)) ** degree
def rbf_kernel(x, y=None, gamma=None):
    """Radial-basis function kernel (aka squared-exponential kernel).

    Computes exp(-gamma * ||x_i - y_j||^2) for every row pair, expanding the
    squared distance as -2 x.y + |x|^2 + |y|^2. x and y are 2D arrays with
    rows as samples; y defaults to x, and gamma defaults to 1 / n_features
    (x.shape[1]).
    """
    if y is None:
        y = x
    if gamma is None:
        gamma = 1.0 / x.shape[1]  # 1.0 / n_features
    x_sq = np.sum(x * x, axis=1).reshape((-1, 1))
    y_sq = np.sum(y * y, axis=1).reshape((1, -1))
    sq_dists = -2.0 * np.dot(x, y.T) + x_sq + y_sq
    return np.exp(-gamma * sq_dists)
# part4. Self defined data structures
# ______________________________________________________________________________
# Grid Functions

# Unit (dx, dy) steps; list order is counter-clockwise starting from EAST.
EAST, NORTH, WEST, SOUTH = (1, 0), (0, 1), (-1, 0), (0, -1)
orientations = [EAST, NORTH, WEST, SOUTH]
# Turn increments into `orientations`: +1 moves counter-clockwise, -1 clockwise.
LEFT, RIGHT = +1, -1
turns = (LEFT, RIGHT)
def turn_heading(heading, inc, headings=orientations):
    """Return the heading reached by moving inc steps (wrapping around) through headings from heading."""
    position = headings.index(heading)
    return headings[(position + inc) % len(headings)]
def turn_right(heading):
    """Return the heading one step clockwise from heading (e.g. EAST -> SOUTH)."""
    return turn_heading(heading, inc=RIGHT)
def turn_left(heading):
    """Return the heading one step counter-clockwise from heading (e.g. EAST -> NORTH)."""
    return turn_heading(heading, inc=LEFT)
def distance(a, b):
    """Return the Euclidean distance between two (x, y) points."""
    (ax, ay), (bx, by) = a, b
    return np.hypot(ax - bx, ay - by)
def distance_squared(a, b):
    """Return the squared Euclidean distance between two (x, y) points."""
    (ax, ay), (bx, by) = a, b
    dx = ax - bx
    dy = ay - by
    return dx ** 2 + dy ** 2
# ______________________________________________________________________________
# Misc Functions
class injection:
    """Dependency injection of temporary values for global functions/classes/etc.
    E.g., `with injection(DataBase=MockDataBase): ...`

    Only globals of the module that defines this class are affected. Every
    injected name must already exist there, otherwise __enter__ raises
    KeyError. The previous values are restored on exit, and exceptions are
    not suppressed.
    """

    def __init__(self, **kwds):
        self.new = kwds

    def __enter__(self):
        module_globals = globals()
        self.old = {key: module_globals[key] for key in self.new}
        module_globals.update(self.new)

    def __exit__(self, type, value, traceback):
        globals().update(self.old)
def memoize(fn, slot=None, maxsize=32):
    """Memoize fn: make it remember the computed value for any argument list.
    If slot is specified, store result in that slot of first argument.
    If slot is false, use lru_cache for caching the values.

    Args:
        fn: the function to memoize.
        slot: attribute name. When truthy, the first call fn(obj, *args)
            stores its result as ``obj.<slot>``; later calls return that
            attribute whatever the remaining args are.
        maxsize: cache size for functools.lru_cache (used only when slot is falsy).

    Returns:
        The memoized function. It carries fn's metadata (__name__, __doc__, ...).
    """
    if slot:
        @functools.wraps(fn)
        def memoized_fn(obj, *args):
            if hasattr(obj, slot):
                return getattr(obj, slot)
            val = fn(obj, *args)
            setattr(obj, slot, val)
            return val
        return memoized_fn
    # lru_cache copies fn's metadata onto its wrapper and also accepts kwargs.
    return functools.lru_cache(maxsize=maxsize)(fn)
def name(obj):
    """Return a reasonable display name for obj.

    Tries, in order, obj.name, obj.__name__ and the name of obj's class,
    returning the first truthy one; falls back to str(obj).
    """
    label = getattr(obj, 'name', 0)
    if label:
        return label
    label = getattr(obj, '__name__', 0)
    if label:
        return label
    label = getattr(getattr(obj, '__class__', 0), '__name__', 0)
    if label:
        return label
    return str(obj)
def isnumber(x):
    """Return True if x has an __int__ attribute (a rough 'is x a number?' check)."""
    try:
        x.__int__
    except AttributeError:
        return False
    return True
def issequence(x):
    """Return True if x is a collections.abc.Sequence (list, tuple, str, range, ...)."""
    sequence_abc = collections.abc.Sequence
    return isinstance(x, sequence_abc)
def open_data(name, mode='r'):
    """Open and return the file `name` from the aima-data directory next to this module.

    The caller is responsible for closing the returned file object.
    """
    module_dir = os.path.dirname(__file__)
    data_path = os.path.join(module_dir, 'aima-data', name)
    return open(data_path, mode=mode)
def failure_test(algorithm, tests):
    """Grade algorithm by the fraction of tests on which it avoids the failure output.

    Many algorithms have arbitrary output on success, which is hard to check
    for correctness, but produce something specific on failure (for example
    False or None). tests is a list of (values, failure_output) pairs. A test
    passes when algorithm(values) != failure_output. The score is the mean of
    the passes (1 per pass, 0 per failure).
    """
    outcomes = [int(algorithm(values) != failure_output) for values, failure_output in tests]
    return mean(outcomes)
# ______________________________________________________________________________
# Expressions
# See https://docs.python.org/3/reference/expressions.html#operator-precedence
# See https://docs.python.org/3/reference/datamodel.html#special-method-names
class Expr:
    """A mathematical expression with an operator and 0 or more arguments.
    op is a str like '+' or 'sin'; args are Expressions.
    Expr('x') or Symbol('x') creates a symbol (a nullary Expr).
    Expr('-', x) creates a unary; Expr('+', x, 1) creates a binary.

    The arithmetic/bitwise operators below are overloaded to BUILD a new Expr
    tree instead of computing a value. `==` is the exception: it compares
    structurally and returns a bool. Calling a Symbol, as in f(x), builds an
    application Expr.
    """

    def __init__(self, op, *args):
        self.op = str(op)  # operator/symbol name is always stored as a str
        self.args = args   # tuple of sub-expressions (Exprs and/or numbers)

    # Operator overloads
    def __neg__(self):
        return Expr('-', self)

    def __pos__(self):
        return Expr('+', self)

    def __invert__(self):
        return Expr('~', self)

    def __add__(self, rhs):
        return Expr('+', self, rhs)

    def __sub__(self, rhs):
        return Expr('-', self, rhs)

    def __mul__(self, rhs):
        return Expr('*', self, rhs)

    def __pow__(self, rhs):
        return Expr('**', self, rhs)

    def __mod__(self, rhs):
        return Expr('%', self, rhs)

    def __and__(self, rhs):
        return Expr('&', self, rhs)

    def __xor__(self, rhs):
        return Expr('^', self, rhs)

    def __rshift__(self, rhs):
        return Expr('>>', self, rhs)

    def __lshift__(self, rhs):
        return Expr('<<', self, rhs)

    def __truediv__(self, rhs):
        return Expr('/', self, rhs)

    def __floordiv__(self, rhs):
        return Expr('//', self, rhs)

    def __matmul__(self, rhs):
        return Expr('@', self, rhs)

    def __or__(self, rhs):
        """Allow both P | Q, and P |'==>'| Q."""
        if isinstance(rhs, Expression):
            return Expr('|', self, rhs)
        else:
            # rhs is neither an Expr nor a number (e.g. the string '==>'):
            # begin a user-defined infix operator; the following `| Q`
            # is handled by PartialExpr.__or__.
            return PartialExpr(rhs, self)

    # Reverse operator overloads (Python calls these when the left operand
    # is not an Expr, e.g. 1 + x); note the lhs/self argument order.
    def __radd__(self, lhs):
        return Expr('+', lhs, self)

    def __rsub__(self, lhs):
        return Expr('-', lhs, self)

    def __rmul__(self, lhs):
        return Expr('*', lhs, self)

    def __rdiv__(self, lhs):
        # Python 2 hook; Python 3 never calls it (it uses __rtruediv__).
        return Expr('/', lhs, self)

    def __rpow__(self, lhs):
        return Expr('**', lhs, self)

    def __rmod__(self, lhs):
        return Expr('%', lhs, self)

    def __rand__(self, lhs):
        return Expr('&', lhs, self)

    def __rxor__(self, lhs):
        return Expr('^', lhs, self)

    def __ror__(self, lhs):
        return Expr('|', lhs, self)

    def __rrshift__(self, lhs):
        return Expr('>>', lhs, self)

    def __rlshift__(self, lhs):
        return Expr('<<', lhs, self)

    def __rtruediv__(self, lhs):
        return Expr('/', lhs, self)

    def __rfloordiv__(self, lhs):
        return Expr('//', lhs, self)

    def __rmatmul__(self, lhs):
        return Expr('@', lhs, self)

    def __call__(self, *args):
        """Call: if 'f' is a Symbol, then f(0) == Expr('f', 0)."""
        if self.args:
            raise ValueError('Can only do a call for a Symbol, not an Expr')
        else:
            return Expr(self.op, *args)

    # Equality and repr
    def __eq__(self, other):
        """'x == y' evaluates to True or False; does not build an Expr."""
        return isinstance(other, Expr) and self.op == other.op and self.args == other.args

    def __lt__(self, other):
        # Orders Exprs by their printed form so they can be sorted;
        # comparing with a non-Expr always gives False.
        return isinstance(other, Expr) and str(self) < str(other)

    def __hash__(self):
        # Consistent with __eq__ (same op and args -> same hash); requires
        # every element of args to be hashable.
        return hash(self.op) ^ hash(self.args)

    def __repr__(self):
        op = self.op
        args = [str(arg) for arg in self.args]
        if op.isidentifier():  # f(x) or f(x, y)
            return '{}({})'.format(op, ', '.join(args)) if args else op
        elif len(args) == 1:  # -x or -(x + 1)
            return op + args[0]
        else:  # (x - y)
            opp = (' ' + op + ' ')
            return '(' + opp.join(args) + ')'
# An 'Expression' is either an Expr or a Number.
# Symbol is not an explicit type; it is any Expr with 0 args.

Number = int, float, complex  # built-in numeric types allowed as Expression leaves
Expression = Expr, Number  # nested tuple is fine: isinstance() accepts nested tuples of types
def Symbol(name):
    """Return a Symbol: an Expr whose op is name and which has no args."""
    symbol = Expr(name)
    return symbol
def symbols(names):
    """Return a tuple of Symbols from names, a comma- and/or whitespace-delimited str."""
    cleaned = names.replace(',', ' ')
    return tuple(map(Symbol, cleaned.split()))
def subexpressions(x):
    """Yield x and then every nested subexpression, depth-first, left to right.

    Non-Expr leaves (numbers) are yielded but not descended into. Uses an
    explicit stack, so deeply nested expressions do not hit the recursion limit.
    """
    stack = [x]
    while stack:
        node = stack.pop()
        yield node
        if isinstance(node, Expr):
            # Push in reverse so the leftmost argument is popped first.
            stack.extend(reversed(node.args))
def arity(expression):
    """Return the number of direct arguments of expression; non-Expr values (numbers) have arity 0."""
    return len(expression.args) if isinstance(expression, Expr) else 0
# For operators that are not defined in Python, we allow new InfixOps:
class PartialExpr:
    r"""Given 'P \|'==>'\| Q, first form PartialExpr('==>', P), then combine with Q.

    Holds the operator and left operand of a user-defined infix operator
    until the right operand arrives through the second `|`.
    """

    def __init__(self, op, lhs):
        self.op = op
        self.lhs = lhs

    def __or__(self, rhs):
        # (P |'==>') | Q  ->  Expr('==>', P, Q)
        return Expr(self.op, self.lhs, rhs)

    def __repr__(self):
        return f"PartialExpr('{self.op}', {self.lhs})"
def expr(x):
    r"""Shortcut to create an Expression. x is a str in which:
    - identifiers are automatically defined as Symbols.
    - ==> is treated as an infix \|'==>'\|, as are <== and <=>.
    If x is already an Expression, it is returned unchanged. Example:
    >>> expr('P & Q ==> Q')
    ((P & Q) ==> Q)
    """
    if not isinstance(x, str):
        return x
    # SECURITY: eval() executes arbitrary Python code. Only pass trusted
    # strings. Unknown names resolve to Symbols via defaultkeydict(Symbol).
    return eval(expr_handle_infix_ops(x), defaultkeydict(Symbol))
# User-defined infix operators recognised by expr(); applied in this order.
infix_ops = ['==>', '<==', '<=>']
def expr_handle_infix_ops(x):
    r"""Given a str, return a new str with ==> replaced by \|'==>'\|, etc.
    >>> expr_handle_infix_ops('P ==> Q')
    "P |'==>'| Q"
    """
    for op in infix_ops:
        x = x.replace(op, f"|{op!r}|")
    return x
class defaultkeydict(collections.defaultdict):
    """Like defaultdict, but the default_factory is a function of the key.
    >>> d = defaultkeydict(len); d['four']
    4
    """

    def __missing__(self, key):
        # Build the default from the key itself, cache it, and return it.
        value = self.default_factory(key)
        self[key] = value
        return value
class hashabledict(dict):
    """A dict subclass that can be used as a set member or a dict key.

    Every instance hashes to the same constant (1), so ordinary dict equality
    alone decides membership. A constant hash cannot go stale when an
    instance is mutated after being stored, so lookups stay correct. The
    cost is that all hashabledicts collide in one hash bucket, so lookups
    among many of them degrade to linear scans.
    """

    def __hash__(self):
        # Constant rather than derived from the contents: a content-based
        # hash would change on mutation and break set/dict membership.
        return 1
# ______________________________________________________________________________
# Monte Carlo tree node and ucb function
class MCT_Node:
    """Node in the Monte Carlo search tree, keeps track of the children states.

    Attributes:
        parent: parent node (None for the root).
        state: game state this node represents.
        U: accumulated utility; N: visit count (read by ucb as U / N).
        children: dict of child nodes, initially empty.
        actions: initially None.
    """

    def __init__(self, parent=None, state=None, U=0, N=0):
        self.parent = parent
        self.state = state
        self.U = U
        self.N = N
        self.children = {}
        self.actions = None
def ucb(n, C=1.4):
    """Return the UCB1 score of node n: U/N + C * sqrt(ln(parent.N) / N).

    Unvisited nodes (n.N == 0) score infinity so that selection in Monte
    Carlo tree search tries them first. A visited node must have a parent
    (n.parent.N is read).
    """
    if n.N == 0:
        return np.inf
    exploitation = n.U / n.N
    exploration = C * np.sqrt(np.log(n.parent.N) / n.N)
    return exploitation + exploration
# ______________________________________________________________________________
# Useful Shorthands
class Bool(int):
    """Just like `bool`, except values display as 'T' and 'F' instead of 'True' and 'False'."""

    def __str__(self):
        return 'T' if self else 'F'

    __repr__ = __str__


T = Bool(True)
F = Bool(False)