| try: |
| import jitfields |
| available = True |
| except (ImportError, ModuleNotFoundError): |
| jitfields = None |
| available = False |
| from .utils import make_list |
| import torch |
|
|
|
|
def first2last(input, ndim):
    """Bring the dimension at index ``-ndim-1`` to the last position.

    Parameters
    ----------
    input : torch.Tensor
        Tensor to rearrange.
    ndim : int
        Number of trailing dimensions that sit after the dimension
        being moved (presumably the spatial dimensions, with the moved
        one being a channel axis -- confirm against callers).

    Returns
    -------
    output : torch.Tensor
        If ``input`` has more than ``ndim`` dimensions, ``input`` with
        dimension ``-ndim-1`` moved to the end. Otherwise ``input``
        with a new singleton dimension appended.
    inserted : bool
        ``True`` when a singleton dimension was appended rather than an
        existing dimension moved.
    """
    lacks_leading_dim = input.dim() <= ndim
    if lacks_leading_dim:
        # Nothing to move: append a size-1 dimension instead.
        rearranged = input.unsqueeze(-1)
    else:
        rearranged = torch.movedim(input, -ndim - 1, -1)
    return rearranged, lacks_leading_dim
|
|
|
|
def last2first(input, ndim, inserted, grad=False):
    """Undo a last-position rearrangement of a dimension.

    Parameters
    ----------
    input : torch.Tensor
        Tensor whose last dimension (or second-to-last when ``grad`` is
        true) should be removed or moved back.
    ndim : int
        Number of dimensions the moved dimension must end up in front
        of (not counting the extra trailing one when ``grad`` is true).
    inserted : bool
        If true, the dimension is a singleton that gets squeezed out;
        otherwise it is moved back to index ``-ndim-1`` (shifted by one
        more when ``grad`` is true).
    grad : bool, default=False
        When true, the dimension of interest is at index ``-2`` instead
        of ``-1``, i.e. one extra trailing dimension is left in place.

    Returns
    -------
    torch.Tensor
        The rearranged tensor.
    """
    # Index of the dimension to squeeze or move: -1, or -2 when `grad`.
    source_axis = -1 - grad
    if inserted:
        return input.squeeze(source_axis)
    return torch.movedim(input, source_axis, source_axis - ndim)
|
|
|
|
def grid_pull(input, grid, interpolation='linear', bound='zero',
              extrapolate=False, prefilter=False):
    """Call ``jitfields.pull`` on ``input`` with a rearranged layout.

    The number of dimensions is read from the last dimension of
    ``grid``. ``input`` is passed through ``first2last`` before the
    call and the result is passed through ``last2first`` afterwards.

    Parameters
    ----------
    input : torch.Tensor
        Tensor handed to ``jitfields.pull`` after rearrangement.
    grid : torch.Tensor
        Tensor whose last dimension has size ``ndim``.
    interpolation
        Forwarded to ``jitfields.pull`` as ``order``.
    bound, extrapolate, prefilter
        Forwarded unchanged to ``jitfields.pull``.

    Returns
    -------
    torch.Tensor
        Output of ``jitfields.pull`` in the caller's original layout.
    """
    ndim = grid.shape[-1]
    moved, inserted = first2last(input, ndim)
    options = dict(order=interpolation, bound=bound,
                   extrapolate=extrapolate, prefilter=prefilter)
    pulled = jitfields.pull(moved, grid, **options)
    return last2first(pulled, ndim, inserted)
|
|
|
|
def grid_push(input, grid, shape=None, interpolation='linear', bound='zero',
              extrapolate=False, prefilter=False):
    """Call ``jitfields.push`` on ``input`` with a rearranged layout.

    The number of dimensions is read from the last dimension of
    ``grid``. ``input`` is passed through ``first2last`` before the
    call and the result is passed through ``last2first`` afterwards.

    Parameters
    ----------
    input : torch.Tensor
        Tensor handed to ``jitfields.push`` after rearrangement.
    grid : torch.Tensor
        Tensor whose last dimension has size ``ndim``.
    shape : optional
        Forwarded as the third positional argument of
        ``jitfields.push``.
    interpolation
        Forwarded to ``jitfields.push`` as ``order``.
    bound, extrapolate, prefilter
        Forwarded unchanged to ``jitfields.push``.

    Returns
    -------
    torch.Tensor
        Output of ``jitfields.push`` in the caller's original layout.
    """
    ndim = grid.shape[-1]
    moved, inserted = first2last(input, ndim)
    options = dict(order=interpolation, bound=bound,
                   extrapolate=extrapolate, prefilter=prefilter)
    pushed = jitfields.push(moved, grid, shape, **options)
    return last2first(pushed, ndim, inserted)
|
|
|
|
def grid_count(grid, shape=None, interpolation='linear', bound='zero',
               extrapolate=False):
    """Forward to ``jitfields.count``.

    Parameters
    ----------
    grid
        First positional argument of ``jitfields.count``.
    shape : optional
        Second positional argument of ``jitfields.count``.
    interpolation
        Forwarded as ``order``.
    bound, extrapolate
        Forwarded unchanged.

    Returns
    -------
    Whatever ``jitfields.count`` returns.
    """
    options = dict(order=interpolation, bound=bound, extrapolate=extrapolate)
    return jitfields.count(grid, shape, **options)
|
|
|
|
def grid_grad(input, grid, interpolation='linear', bound='zero',
              extrapolate=False, prefilter=False):
    """Call ``jitfields.grad`` on ``input`` with a rearranged layout.

    The number of dimensions is read from the last dimension of
    ``grid``. ``input`` is passed through ``first2last`` before the
    call. The result is passed through ``last2first`` with
    ``grad=True``, so the dimension that is moved back (or squeezed)
    is the second-to-last one and the final dimension stays in place.

    Parameters
    ----------
    input : torch.Tensor
        Tensor handed to ``jitfields.grad`` after rearrangement.
    grid : torch.Tensor
        Tensor whose last dimension has size ``ndim``.
    interpolation
        Forwarded to ``jitfields.grad`` as ``order``.
    bound, extrapolate, prefilter
        Forwarded unchanged to ``jitfields.grad``.

    Returns
    -------
    torch.Tensor
        Output of ``jitfields.grad`` after the layout is restored.
    """
    ndim = grid.shape[-1]
    moved, inserted = first2last(input, ndim)
    options = dict(order=interpolation, bound=bound,
                   extrapolate=extrapolate, prefilter=prefilter)
    gradient = jitfields.grad(moved, grid, **options)
    return last2first(gradient, ndim, inserted, True)
|
|
|
|
def spline_coeff(input, interpolation='linear', bound='dct2', dim=-1,
                 inplace=False):
    """Forward to ``jitfields.spline_coeff`` or its in-place variant.

    Parameters
    ----------
    input
        First positional argument of the selected function.
    interpolation
        Second positional argument of the selected function.
    bound, dim
        Forwarded unchanged as keyword arguments.
    inplace : bool, default=False
        If true, ``jitfields.spline_coeff_`` is called instead of
        ``jitfields.spline_coeff``.

    Returns
    -------
    Whatever the selected ``jitfields`` function returns.
    """
    if inplace:
        compute = jitfields.spline_coeff_
    else:
        compute = jitfields.spline_coeff
    return compute(input, interpolation, bound=bound, dim=dim)
|
|
|
|
def spline_coeff_nd(input, interpolation='linear', bound='dct2', dim=None,
                    inplace=False):
    """Forward to ``jitfields.spline_coeff_nd`` or its in-place variant.

    Parameters
    ----------
    input
        First positional argument of the selected function.
    interpolation
        Second positional argument of the selected function.
    bound
        Forwarded unchanged as a keyword argument.
    dim : optional
        Forwarded to the selected function under the keyword ``ndim``.
    inplace : bool, default=False
        If true, ``jitfields.spline_coeff_nd_`` is called instead of
        ``jitfields.spline_coeff_nd``.

    Returns
    -------
    Whatever the selected ``jitfields`` function returns.
    """
    if inplace:
        compute = jitfields.spline_coeff_nd_
    else:
        compute = jitfields.spline_coeff_nd
    # This wrapper's `dim` maps onto the backend's `ndim` keyword.
    return compute(input, interpolation, bound=bound, ndim=dim)
|
|
|
|
def resize(image, factor=None, shape=None, anchor='c',
           interpolation=1, prefilter=True, **kwargs):
    """Forward to ``jitfields.resize`` with an inferred ``ndim``.

    ``ndim`` is the largest length among ``make_list(factor)``,
    ``make_list(shape)`` and ``make_list(anchor)`` (falsy values count
    as empty). If all three are empty, it falls back to
    ``image.dim() - 2``.

    Parameters
    ----------
    image : torch.Tensor
        Forwarded as the first positional argument.
    factor, shape, anchor
        Forwarded unchanged; also used to infer ``ndim``.
    interpolation
        Forwarded as ``order``.
    prefilter
        Forwarded unchanged.
    **kwargs
        Only ``bound`` is read (default ``'nearest'``).
        NOTE(review): any other keyword is silently ignored -- confirm
        this is intended.

    Returns
    -------
    Whatever ``jitfields.resize`` returns.
    """
    bound = kwargs.setdefault('bound', 'nearest')
    # NOTE(review): if `make_list('c')` has length 1, the default anchor
    # alone makes `ndim >= 1`, so the `image.dim() - 2` fallback would
    # only apply when `anchor` is falsy -- verify against `make_list`.
    lengths = [len(make_list(arg or [])) for arg in (factor, shape, anchor)]
    ndim = max(lengths) or (image.dim() - 2)
    return jitfields.resize(
        image,
        factor=factor,
        shape=shape,
        ndim=ndim,
        anchor=anchor,
        order=interpolation,
        bound=bound,
        prefilter=prefilter,
    )
|
|
|
|
def restrict(image, factor=None, shape=None, anchor='c',
             interpolation=1, reduce_sum=False, **kwargs):
    """Forward to ``jitfields.restrict`` with an inferred ``ndim``.

    ``ndim`` is the largest length among ``make_list(factor)``,
    ``make_list(shape)`` and ``make_list(anchor)`` (falsy values count
    as empty). If all three are empty, it falls back to
    ``image.dim() - 2``.

    Parameters
    ----------
    image : torch.Tensor
        Forwarded as the first positional argument.
    factor, shape, anchor
        Forwarded unchanged; also used to infer ``ndim``.
    interpolation
        Forwarded as ``order``.
    reduce_sum
        Forwarded unchanged.
    **kwargs
        Only ``bound`` is read (default ``'nearest'``).
        NOTE(review): any other keyword is silently ignored -- confirm
        this is intended.

    Returns
    -------
    Whatever ``jitfields.restrict`` returns.
    """
    bound = kwargs.setdefault('bound', 'nearest')
    # NOTE(review): if `make_list('c')` has length 1, the default anchor
    # alone makes `ndim >= 1`, so the `image.dim() - 2` fallback would
    # only apply when `anchor` is falsy -- verify against `make_list`.
    lengths = [len(make_list(arg or [])) for arg in (factor, shape, anchor)]
    ndim = max(lengths) or (image.dim() - 2)
    return jitfields.restrict(
        image,
        factor=factor,
        shape=shape,
        ndim=ndim,
        anchor=anchor,
        order=interpolation,
        bound=bound,
        reduce_sum=reduce_sum,
    )
|
|