# Source code for crappy.tool.fields

# coding: utf-8

"""More documentation coming soon !"""

import numpy as np
from typing import Union
from .._global import OptionalModule

try:
  import cv2
except (ModuleNotFoundError, ImportError):
  cv2 = OptionalModule("opencv-python")


def ones(h: int, w: int) -> np.ndarray:
  """Returns an (h, w) float32 array filled with 1."""
  shape = (h, w)
  return np.full(shape, 1.0, dtype=np.float32)


def zeros(h: int, w: int) -> np.ndarray:
  """Returns an (h, w) float32 array filled with 0."""
  shape = (h, w)
  return np.full(shape, 0.0, dtype=np.float32)


# Cache for z(): the pair of coordinate grids built by the last call
Z = None


def z(h: int, w: int) -> tuple:
  """Returns normalized (x, y) coordinate grids of shape (h, w).

  Both grids are float32 and symmetric around the image center. They are
  scaled so that the corners lie at distance 1 from the center: x spans
  [-w / d, w / d] and y spans [-h / d, h / d], with d = sqrt(h² + w²).

  The last result is cached in the global Z and returned as-is (not copied)
  while (h, w) stays the same, so callers must not modify the arrays in
  place.

  Returns:
    A tuple (x, y) of two (h, w) float32 arrays, as built by np.meshgrid.
  """
  global Z
  if Z is None or Z[0].shape != (h, w):
    # Half-extents along y and x, chosen so that sw² + sh² = 1
    sh = 1 / (w * w / h / h + 1) ** .5
    sw = w * sh / h
    # Normalized to a tuple, as np.meshgrid may return a list or a tuple
    # depending on the NumPy version
    Z = tuple(np.meshgrid(np.linspace(-sw, sw, w, dtype=np.float32),
                          np.linspace(-sh, sh, h, dtype=np.float32)))
  return Z


def get_field(s: str, h: int, w: int) -> tuple:
  if s == 'x':
    return ones(h, w), zeros(h, w)
  elif s == 'y':
    return zeros(h, w), ones(h, w)
  elif s == 'r':
    u, v = z(h, w)
    # Ratio (angle) of the rotation
    # Should be π/180 to be 1 for 1 deg
    # Z has and amplitude of 1 in the corners
    # 360 because h²+w² is twice the distance center-corner
    r = (h ** 2 + w ** 2) ** .5 * np.pi / 360
    return v * r, -u * r
  elif s == 'exx':
    return (np.concatenate((np.linspace(-w / 200, w / 200, w,
                                        dtype=np.float32)[np.newaxis, :],) * h,
                           axis=0),
            zeros(h, w))
  elif s == 'eyy':
    return (zeros(h, w),
            np.concatenate((np.linspace(-h / 200, h / 200, h,
                                        dtype=np.float32)[:, np.newaxis],) * w,
                           axis=1))
  elif s == 'exy':
    return (np.concatenate((np.linspace(-h / 200, h / 200, h,
                                        dtype=np.float32)[:, np.newaxis],) * w,
                           axis=1),
            zeros(h, w))
  elif s == 'eyx':
    return (zeros(h, w),
            np.concatenate((np.linspace(-w / 200, w / 200, w,
                                        dtype=np.float32)[np.newaxis, :],) * h,
                           axis=0))
  elif s == 'exy2':
    return (np.concatenate((np.linspace(-h / 200, h / 200, h,
                                        dtype=np.float32)[:, np.newaxis],) * w,
                           axis=1),
            (np.concatenate((np.linspace(-w / 200, w / 200, w,
                                         dtype=np.float32)[np.newaxis, :],) *
                            h,
                            axis=0)))

  elif s == 'z':
    u, v = z(h, w)
    # Zoom in %
    r = (h ** 2 + w ** 2) ** .5 / 200
    return u * r, v * r
  else:
    print("Unknown field:", s)
    raise NameError


def get_fields(l: list, h: int, w: int) -> np.ndarray:
  """Stacks several fields into a single float32 array.

  Args:
    l: Fields to stack. Each element is either an array, copied as-is into
      its slot (it must broadcast to (h, w, 2)), or a field name that is
      passed to get_field.
    h: Height of the fields.
    w: Width of the fields.

  Returns:
    An (h, w, 2, len(l)) float32 array, where [:, :, 0, k] and [:, :, 1, k]
    are the two components of the k-th field.
  """
  stacked = np.empty((h, w, 2, len(l)), dtype=np.float32)
  for k, field in enumerate(l):
    if not isinstance(field, np.ndarray):
      stacked[:, :, 0, k], stacked[:, :, 1, k] = get_field(field, h, w)
      continue
    stacked[:, :, :, k] = field
  return stacked


class Fielder:
  """Builds linear combinations of a fixed set of fields.

  Args:
    flist: Fields to combine, in any form accepted by get_fields (field
      names or arrays).
    h: Height of the fields.
    w: Width of the fields.
  """

  def __init__(self, flist: list, h: int, w: int) -> None:
    self.nfields = len(flist)
    self.h = h
    self.w = w
    fields = get_fields(flist, h, w)
    # One (h, w, 2) view per field, taken from the stacked array
    self.fields = [fields[:, :, :, i] for i in range(fields.shape[3])]

  def get(self, *x: float) -> np.ndarray:
    """Returns the sum of x[i] * fields[i].

    Coefficients are paired with the fields in order. Coefficients beyond
    the number of fields are ignored; fields without a coefficient do not
    contribute.

    Returns:
      An (h, w, 2) array. When no coefficient is given, an all-zero float32
      array of that shape.
    """
    terms = [coef * field for coef, field in zip(x, self.fields)]
    if not terms:
      # Return a zero field rather than the int 0 that sum([]) gives
      return np.zeros((self.h, self.w, 2), dtype=np.float32)
    return sum(terms)


class Projector:
  """Projects displacement fields onto a base of fields.

  The projection coefficients are computed field by field as
  <base_i, flow> / <base_i, base_i>. This gives the orthogonal projection
  only when the base is orthogonal, which is why the base can be checked
  for orthogonality at instantiation.

  Args:
    base: Either a list of fields, or a 4-D array whose last axis indexes
      the fields (base[:, :, :, i] is the i-th field). The first two
      dimensions of the fields are used as (h, w).
    check_orthogonality: If True, prints a warning when the absolute inner
      product of some pair of fields, divided by the number of elements of
      a field, exceeds 1e-4.
  """

  def __init__(self,
               base: Union[np.ndarray, list],
               check_orthogonality: bool = True) -> None:
    if isinstance(base, list):
      self.base = base
    else:
      self.base = [base[:, :, :, i] for i in range(base.shape[3])]
    self.fielder = Fielder(self.base, *self.base[0].shape[:2])
    # Squared norms of the fields, used to normalize the projections
    self.norms2 = [np.sum(b * b) for b in self.base]
    if check_orthogonality:
      s = self._cross_products(self.base)
      # A base of fewer than two fields has no pair to compare, and max()
      # would raise on the empty list
      if s and max(s) / self.base[0].size > 1e-4:
        print("WARNING, base does not seem orthogonal!")
        print(s)

  @staticmethod
  def _cross_products(base: list) -> list:
    """Returns |<a, b>| for every unordered pair (a, b) of fields of base."""
    from itertools import combinations
    return [abs(np.sum(a * b)) for a, b in combinations(base, 2)]

  def get_scal(self, flow: np.ndarray) -> list:
    """Returns one projection coefficient per field of the base.

    Each coefficient is <base_i, flow> / <base_i, base_i>. `flow` must
    broadcast against the fields of the base.
    """
    return [np.sum(vec * flow) / n2 for vec, n2 in zip(self.base, self.norms2)]

  def get_full(self, flow: np.ndarray) -> np.ndarray:
    """Returns the projected field, i.e. sum(coef_i * base_i), as built by
    the internal Fielder from the coefficients of get_scal."""
    return self.fielder.get(*self.get_scal(flow))


class OrthoProjector(Projector):
  """Projector whose base is first orthogonalized (Gram-Schmidt).

  Each field of the given base has its projection on the previously
  orthogonalized fields subtracted. The resulting base spans the same space
  as the input, with pairwise orthogonal fields. The fields are not
  normalized; Projector divides by their squared norms.

  Args:
    base: Either a 4-D array whose last axis indexes the fields
      (base[:, :, :, i] is the i-th field), or a list of fields.
  """

  def __init__(self, base: Union[np.ndarray, list]) -> None:
    if isinstance(base, list):
      vec = list(base)
    else:
      vec = [base[:, :, :, i] for i in range(base.shape[3])]
    new_base = [vec[0]]
    for v in vec[1:]:
      # Remove from v its components along the fields already in the base
      p = Projector(new_base, check_orthogonality=False)
      new_base.append(v - p.get_full(v))
    # The orthogonality check compares pairs of fields: with a single field
    # there is no pair, so the check is skipped
    super().__init__(new_base, check_orthogonality=len(new_base) > 1)


def avg_ampl(f: np.ndarray) -> float:
  """Returns the root-mean-square amplitude of a 2-component field.

  The squared norm u² + v² is summed over the field and divided by
  f.size / 2, which is the number of pixels when f has exactly two
  components along its third axis.
  """
  u_comp = f[:, :, 0]
  v_comp = f[:, :, 1]
  total = np.sum(u_comp ** 2 + v_comp ** 2)
  mean_square = total / f.size * 2
  return mean_square ** .5


def remap(a: np.ndarray, r: np.ndarray) -> np.ndarray:
  """Warps `a` by the displacement field `r`, e.g. the result of a
  correlation.

  Each output pixel (y, x) is sampled with bilinear interpolation from `a`
  at (x + r[y, x, 0], y + r[y, x, 1]). Samples falling outside `a` follow
  the default border handling of cv2.remap.

  Args:
    a: Image to warp. Its first two dimensions are (height, width).
    r: Displacement field with the same height and width as `a`:
      r[:, :, 0] along x, r[:, :, 1] along y.

  Returns:
    The warped image, as a float32 array.
  """
  imy, imx = a.shape[:2]
  x, y = np.meshgrid(range(imx), range(imy))
  return cv2.remap(a.astype(np.float32),
                   (x + r[:, :, 0]).astype(np.float32),
                   (y + r[:, :, 1]).astype(np.float32),
                   cv2.INTER_LINEAR)
def get_res(a: np.ndarray, b: np.ndarray, r: np.ndarray) -> np.ndarray:
  """Returns the residual between `a` and `b` warped by the displacement
  field `r`, i.e. a - remap(b, r).

  Args:
    a: Reference image.
    b: Image warped with remap before the subtraction.
    r: Displacement field passed to remap.
  """
  # Alternative (unused) formulation: b - remap(a, -r)
  return a - remap(b, r)