#!/usr/bin/env python3

import sys
import numpy as np
import tifffile


def get_gradient(im, fd_type, axis):
    """Return a finite-difference gradient of *im* along *axis*.

    Neighbouring samples are obtained with ``np.roll``, so the array is
    treated as periodic: differences at the border wrap around to the
    opposite edge.

    Args:
        im: Input array.
        fd_type: Finite-difference scheme, one of ``'forward'``
            (``im[i + 1] - im[i]``), ``'backward'`` (``im[i] - im[i - 1]``)
            or ``'central'`` (``(im[i + 1] - im[i - 1]) / 2``).
        axis: Axis along which the difference is taken.

    Returns:
        Array with the same shape as *im*.

    Raises:
        ValueError: If *fd_type* is not one of the supported schemes.
    """
    # NOTE: forward/backward subtract in the dtype of `im`; unsigned integer
    # input would wrap on negative differences. Callers are expected to pass
    # a signed or floating-point image.
    if fd_type == 'forward':
        return np.roll(im, -1, axis=axis) - im
    if fd_type == 'backward':
        return im - np.roll(im, 1, axis=axis)
    if fd_type == 'central':
        return (np.roll(im, -1, axis=axis) - np.roll(im, 1, axis=axis)) / 2

    # Fail loudly instead of hitting an UnboundLocalError on the return.
    raise ValueError(
        "unknown fd_type {!r}; expected 'forward', 'backward' or 'central'"
        .format(fd_type)
    )


def main(direction, fd_type):
    """Compare a NumPy reference gradient against the precomputed result.

    Reads ``gradient-input.tif`` and ``ufo-gradient.tif`` from the current
    working directory, computes the gradient of the input with NumPy and
    checks that it equals the contents of ``ufo-gradient.tif``.

    Args:
        direction: Which gradient to compute:
            ``'horizontal'`` (axis 1), ``'vertical'`` (axis 0),
            ``'both'`` (sum of the two), ``'both_abs'`` (sum of their
            absolute values) or ``'both_mag'`` (Euclidean magnitude).
        fd_type: Finite-difference scheme passed to ``get_gradient``.

    Returns:
        0 if both arrays are equal, 1 otherwise (after printing both arrays).

    Raises:
        ValueError: If *direction* is not one of the supported values.
    """
    im = tifffile.imread('gradient-input.tif')
    ufo = tifffile.imread('ufo-gradient.tif')

    if direction == 'horizontal':
        g = get_gradient(im, fd_type, 1)
    elif direction == 'vertical':
        g = get_gradient(im, fd_type, 0)
    elif direction == 'both':
        g = get_gradient(im, fd_type, 1)
        g += get_gradient(im, fd_type, 0)
    elif direction == 'both_abs':
        g = np.abs(get_gradient(im, fd_type, 1))
        g += np.abs(get_gradient(im, fd_type, 0))
    elif direction == 'both_mag':
        horiz = get_gradient(im, fd_type, 1)
        vert = get_gradient(im, fd_type, 0)
        g = np.sqrt(horiz ** 2 + vert ** 2)
    else:
        # Without this branch `g` would be unbound and the comparison below
        # would fail with a misleading UnboundLocalError.
        raise ValueError(
            "unknown direction {!r}; expected 'horizontal', 'vertical', "
            "'both', 'both_abs' or 'both_mag'".format(direction)
        )

    try:
        np.testing.assert_equal(ufo, g)
    except AssertionError:
        # Only a genuine mismatch is reported as a test failure; any other
        # exception (including KeyboardInterrupt) propagates unchanged.
        print("Ufo:\n", ufo)
        print("Numpy:\n", g)
        return 1

    return 0


if __name__ == '__main__':
    # main() takes exactly two positional arguments; check the count up front
    # so a wrong invocation prints a usage line instead of a TypeError
    # traceback. Exit code 2 keeps usage errors distinct from a mismatch (1).
    if len(sys.argv) != 3:
        print("usage: {} DIRECTION FD_TYPE".format(sys.argv[0]),
              file=sys.stderr)
        sys.exit(2)

    sys.exit(main(*sys.argv[1:]))
