# -*- coding: utf-8 -*-
# -----------------------------------------------------------------------------
# Copyright (c) 2009- Spyder Kernels Contributors
#
# Licensed under the terms of the MIT License
# (see spyder_kernels/__init__.py for details)
# -----------------------------------------------------------------------------

"""
Input/Output Utilities

Note: 'load' functions have to return a ``(data, error_message)`` tuple,
      where ``data`` is a dictionary from which a globals() namespace may
      be updated.
"""

from __future__ import print_function

# Standard library imports
import sys
import os
import os.path as osp
import tarfile
import tempfile
import shutil
import types
import json
import inspect
import dis
import copy
import glob

# Local imports
from spyder_kernels.py3compat import getcwd, pickle, PY2, to_text_string
from spyder_kernels.utils.lazymodules import (
    FakeObject, numpy as np, pandas as pd, PIL, scipy as sp)


class MatlabStruct(dict):
    """
    Matlab style struct, enhanced.

    Supports dictionary and attribute style access.  Can be pickled,
    and supports code completion in a REPL.

    Reading a missing key may create an empty nested MatlabStruct on the
    fly, so that chained assignments such as ``a.c["d"] = ...`` work. This
    only happens when the opcode found 3 bytes after the current
    instruction of the calling frame (or of its caller) is STORE_ATTR,
    LOAD_CONST or STOP_CODE (see ``_is_allowed``). Otherwise a missing key
    raises KeyError (or AttributeError for attribute access).

    Examples
    ========
    >>> from spyder_kernels.utils.iofuncs import MatlabStruct
    >>> a = MatlabStruct()
    >>> a.b = 'spam'  # a["b"] == 'spam'
    >>> a.c["d"] = 'eggs'  # a.c.d == 'eggs'
    >>> print(a)
    {'b': 'spam', 'c': {'d': 'eggs'}}

    NOTE(review): the nested auto-creation relies on the interpreter's
    bytecode layout and may not work on every Python version; verify the
    example above on the interpreter in use.
    """
    def __getattr__(self, attr):
        """Access the dictionary keys for unknown attributes."""
        # Only reached when normal attribute lookup fails. Map it onto a key
        # lookup, and turn a missing key into the AttributeError that
        # attribute access is expected to raise.
        try:
            return self[attr]
        except KeyError:
            msg = "'MatlabStruct' object has no attribute %s" % attr
            raise AttributeError(msg)

    def __getitem__(self, attr):
        """
        Get a dict value; create a MatlabStruct if requesting a submember.

        Do not create a key if the attribute starts with an underscore.
        """
        # Existing keys and names starting with '_' use plain dict lookup
        # (a missing one raises KeyError).
        # NOTE(review): assumes string keys; a missing non-str key raises
        # AttributeError from .startswith().
        if attr in self.keys() or attr.startswith('_'):
            return dict.__getitem__(self, attr)
        frame = inspect.currentframe()
        # step into the function that called us
        # frame.f_back is our direct caller (e.g. __getattr__ for attribute
        # access); frame.f_back.f_back is that caller's caller. The outer
        # frame is checked first, then the direct caller.
        if frame.f_back.f_back and self._is_allowed(frame.f_back.f_back):
            dict.__setitem__(self, attr, MatlabStruct())
        elif self._is_allowed(frame.f_back):
            dict.__setitem__(self, attr, MatlabStruct())
        # Raises KeyError if no empty struct was created above.
        return dict.__getitem__(self, attr)

    def _is_allowed(self, frame):
        """Check for allowed op code in the calling frame"""
        # Opcodes that allow a missing key to be auto-created. STOP_CODE may
        # not exist in dis.opmap, hence the .get() with a 0 fallback.
        allowed = [dis.opmap['STORE_ATTR'], dis.opmap['LOAD_CONST'],
                   dis.opmap.get('STOP_CODE', 0)]
        bytecode = frame.f_code.co_code
        # NOTE(review): the +3 offset assumes the old bytecode layout
        # (1-byte opcode + 2-byte argument). With 2-byte "wordcode"
        # (Python 3.6+) this lands on an argument byte rather than the next
        # opcode, and it can raise IndexError near the end of the code
        # object. Confirm behavior on the supported Python versions.
        instruction = bytecode[frame.f_lasti + 3]
        # Indexing bytes gives a 1-char str on Python 2 and an int on 3.
        instruction = ord(instruction) if PY2 else instruction
        return instruction in allowed

    # Attribute assignment and deletion write to / remove dictionary keys.
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    @property
    def __dict__(self):
        """Allow for code completion in a REPL"""
        # Expose the keys as instance attributes by returning a plain copy
        # of the mapping.
        return self.copy()


def get_matlab_value(val):
    """
    Convert a value read from a Matlab file into plain Python objects.

    - Lists are converted element by element.
    - Non-array values are returned untouched.
    - Structured arrays that carry a ``classname`` become an instance of a
      new class of that name, with the fields as class attributes.
    - Other structured arrays become a ``MatlabStruct``.
    - Object (cell) arrays become lists.
    - Single-element arrays become their scalar value.
    - Empty arrays become ``''`` (string dtypes) or ``[]``.
    - Any other array is returned as is.

    Adapted from the oct2py project, see
    https://pythonhosted.org/oct2py/conversions.html
    """
    if isinstance(val, list):
        return list(map(get_matlab_value, val))

    if not isinstance(val, np.ndarray):
        return val

    field_names = val.dtype.names

    # User defined (Matlab object) classes.
    if hasattr(val, 'classname'):
        attributes = {field: get_matlab_value(val[field].squeeze().tolist())
                      for field in field_names}
        return type(val.classname, (object,), attributes)()

    # Structs.
    if field_names:
        struct = MatlabStruct()
        for field in field_names:
            struct[field] = get_matlab_value(val[field].squeeze().tolist())
        return struct

    # Cells.
    if val.dtype.kind == 'O':
        cells = val.squeeze().tolist()
        return get_matlab_value(cells if isinstance(cells, list) else [cells])

    # Singletons collapse to a scalar.
    if val.size == 1:
        return val.item()

    # Empty values collapse to an empty string or list.
    if val.size == 0:
        return '' if val.dtype.kind in 'US' else []

    return val


def load_matlab(filename):
    """
    Load the variables stored in a Matlab ``.mat`` file.

    Every value returned by ``scipy.io.loadmat`` is converted with
    ``get_matlab_value``. Returns ``(data, None)`` on success and
    ``(None, message)`` if loading or conversion fails. Returns
    ``(None, '')`` when ``sp.io`` is the lazy-import ``FakeObject``
    placeholder (presumably SciPy is not installed).
    """
    if sp.io is FakeObject:
        return None, ''
    try:
        contents = sp.io.loadmat(filename, struct_as_record=True)
        converted = {key: get_matlab_value(value)
                     for key, value in contents.items()}
    except Exception as error:
        return None, str(error)
    return converted, None


def save_matlab(data, filename):
    """
    Save a dictionary of variables to a Matlab ``.mat`` file.

    1-D arrays are written as row vectors (``oned_as='row'``). Returns the
    error message string if ``scipy.io.savemat`` raises, and ``None``
    otherwise.

    NOTE(review): ``None`` is also returned, without writing anything,
    when ``sp.io`` is the lazy-import ``FakeObject`` placeholder.
    """
    if sp.io is not FakeObject:
        try:
            sp.io.savemat(filename, data, oned_as='row')
        except Exception as error:
            return str(error)
    return None


def load_array(filename):
    """
    Load a NumPy ``.npy`` or ``.npz`` file.

    Returns ``(data, None)`` on success, where ``data`` is:

    - a dict of all arrays in the archive, for an ``.npz`` file;
    - the loaded object itself, if it has a ``keys`` attribute;
    - ``{name: array}`` otherwise, with ``name`` being the file's base
      name without extension.

    Returns ``(None, message)`` on failure, and ``(None, '')`` when
    ``np.load`` is the lazy-import ``FakeObject`` placeholder (presumably
    NumPy is not installed).
    """
    if np.load is FakeObject:
        return None, ''

    try:
        name = osp.splitext(osp.basename(filename))[0]
        data = np.load(filename)
        if isinstance(data, np.lib.npyio.NpzFile):
            # NpzFile keeps the archive open. Read every array eagerly, then
            # close it so the file handle is not leaked.
            with data:
                return dict(data), None
        elif hasattr(data, 'keys'):
            return data, None
        else:
            return {name: data}, None
    except Exception as error:
        return None, str(error)


def __save_array(data, basename, index):
    """
    Write *data* with ``np.save`` to ``<basename>_<index>.npy``.

    *index* is zero-padded to at least four digits. Returns the path of the
    written file.
    """
    suffix = '_%04d.npy' % index
    target = basename + suffix
    np.save(target, data)
    return target


# Byte-order prefix for multi-byte NumPy dtypes, matching this machine.
_ENDIAN = '<' if sys.byteorder == 'little' else '>'

# Map of PIL image modes to (NumPy dtype string, bands per pixel).
# A band count of None gives a 2-D (height, width) array; otherwise the
# array gets a trailing axis of that length.
DTYPES = {
    "1": ('|b1', None),
    "L": ('|u1', None),
    "LA": ('|u1', 2),
    "I": ('%si4' % _ENDIAN, None),
    "F": ('%sf4' % _ENDIAN, None),
    "I;16": ('|u2', None),
    "I;16S": ('%si2' % _ENDIAN, None),
    "P": ('|u1', None),
    "RGB": ('|u1', 3),
    "RGBX": ('|u1', 4),
    "RGBA": ('|u1', 4),
    "CMYK": ('|u1', 4),
    # YCbCr has three 8-bit bands (Y, Cb, Cr); 4 made reshape() fail.
    "YCbCr": ('|u1', 3),
    "HSV": ('|u1', 3),
}


def __image_to_array(filename):
    """
    Read an image file with PIL and return its pixels as a NumPy array.

    The array has shape ``(height, width)`` for single-band modes, or
    ``(height, width, bands)`` otherwise. Its dtype is taken from
    ``DTYPES``.

    Raises RuntimeError if the image mode is not listed in ``DTYPES``.
    """
    # Use the image as a context manager so the file handle opened by
    # Image.open is released, even if decoding or reshaping fails.
    with PIL.Image.open(filename) as img:
        try:
            dtype, extra = DTYPES[img.mode]
        except KeyError:
            raise RuntimeError("%s mode is not supported" % img.mode)
        width, height = img.size
        shape = (height, width)
        if extra is not None:
            shape += (extra,)
        return np.array(img.getdata(), dtype=np.dtype(dtype)).reshape(shape)


def load_image(filename):
    """
    Load an image file into a ``{name: array}`` dictionary.

    ``name`` is the file's base name without extension, and the array comes
    from ``__image_to_array``. Returns ``(data, None)`` on success and
    ``(None, message)`` on failure. Returns ``(None, '')`` when
    ``PIL.Image`` or ``np.array`` is the lazy-import ``FakeObject``
    placeholder (presumably PIL or NumPy is not installed).
    """
    dependencies_missing = PIL.Image is FakeObject or np.array is FakeObject
    if dependencies_missing:
        return None, ''

    try:
        var_name = osp.splitext(osp.basename(filename))[0]
        image_array = __image_to_array(filename)
    except Exception as error:
        return None, str(error)
    return {var_name: image_array}, None


def load_pickle(filename):
    """
    Load the object stored in a pickle file.

    Uses ``pd.read_pickle`` unless it is the lazy-import ``FakeObject``
    placeholder, in which case it falls back to the ``pickle`` module.
    Returns ``(obj, None)`` on success or ``(None, message)`` on failure.

    NOTE(review): the unpickled object is returned as-is, so it is only a
    dictionary if a dictionary was pickled. Unpickling can execute
    arbitrary code; only load files from trusted sources.
    """
    try:
        if pd.read_pickle is FakeObject:
            with open(filename, 'rb') as stream:
                return pickle.load(stream), None
        return pd.read_pickle(filename), None
    except Exception as err:
        return None, str(err)


def load_json(filename):
    """
    Load a JSON file.

    Returns ``(data, None)`` on success, where ``data`` is the decoded JSON
    value (a dictionary when the document is an object). Returns
    ``(None, message)`` if the file can't be read or parsed.
    """
    import io
    try:
        # JSON text is UTF-8 (RFC 8259). Decode it explicitly instead of
        # relying on the platform's locale encoding, which breaks non-ASCII
        # files on e.g. Windows. io.open accepts `encoding` on both
        # Python 2 and 3.
        with io.open(filename, 'r', encoding='utf-8') as fid:
            data = json.load(fid)
        return data, None
    except Exception as err:
        return None, str(err)


def save_dictionary(data, filename):
    """
    Save a dictionary of variables into a single ``.spydata`` file.

    The file is a PAX-format tar archive containing:

    - one ``.pickle`` file with the picklable values;
    - when NumPy is available, one ``.npy`` file per non-empty array found
      either at the top level of ``data`` or directly inside a top-level
      list/dict value. These are listed under the ``'__saved_arrays__'``
      key of the pickled dict.

    The intermediate files are written next to ``filename`` and removed
    once added to the archive, or after an error.

    Modules and callables are skipped silently. Values that can't be
    deep-copied or pickled are skipped, and their names are reported in the
    returned message.

    Parameters
    ----------
    data : dict
        Mapping of variable names to values. It is not modified; its values
        are deep-copied first.
    filename : str
        Path of the ``.spydata`` file to write.

    Returns
    -------
    str or None
        An error message if saving failed with RuntimeError, PicklingError
        or TypeError (including when there is nothing to save), or a message
        listing the objects that had to be skipped. ``None`` on full
        success. Other exceptions (e.g. OSError) propagate.
    """
    filename = osp.abspath(filename)
    old_cwd = getcwd()
    os.chdir(osp.dirname(filename))
    error_message = None
    skipped_keys = []
    data_copy = {}
    # Intermediate (.pickle and .npy) files created for the archive. They
    # are normally deleted right after being added to the tar file; any
    # left over after an error are removed in the `finally` clause.
    temp_files = []

    try:
        # Copy dictionary before modifying it to fix #6689
        for obj_name, obj_value in data.items():
            # Skip modules, since they can't be pickled, users virtually never
            # would want them to be and so they don't show up in the skip list.
            # Skip callables, since they are only pickled by reference and thus
            # must already be present in the user's environment anyway.
            if not (callable(obj_value) or isinstance(obj_value,
                                                      types.ModuleType)):
                # If an object cannot be deepcopied, then it cannot be pickled.
                # Ergo, we skip it and list it later.
                try:
                    data_copy[obj_name] = copy.deepcopy(obj_value)
                except Exception:
                    skipped_keys.append(obj_name)
        data = data_copy
        if not data:
            raise RuntimeError('No supported objects to save')

        saved_arrays = {}
        if np.ndarray is not FakeObject:
            # Saving numpy arrays with np.save
            arr_fname = osp.splitext(filename)[0]
            for name in list(data.keys()):
                try:
                    if (isinstance(data[name], np.ndarray) and
                            data[name].size > 0):
                        # Save arrays at data root
                        fname = __save_array(data[name], arr_fname,
                                             len(saved_arrays))
                        temp_files.append(fname)
                        saved_arrays[(name, None)] = osp.basename(fname)
                        data.pop(name)
                    elif isinstance(data[name], (list, dict)):
                        # Save arrays nested in lists or dictionaries
                        if isinstance(data[name], list):
                            iterator = enumerate(data[name])
                        else:
                            iterator = iter(list(data[name].items()))
                        to_remove = []
                        for index, value in iterator:
                            if (isinstance(value, np.ndarray) and
                                    value.size > 0):
                                fname = __save_array(value, arr_fname,
                                                     len(saved_arrays))
                                temp_files.append(fname)
                                saved_arrays[(name, index)] = (
                                    osp.basename(fname))
                                to_remove.append(index)
                        # Remove from the end so that earlier list indices
                        # stay valid while popping.
                        for index in sorted(to_remove, reverse=True):
                            data[name].pop(index)
                except (RuntimeError, pickle.PicklingError, TypeError,
                        AttributeError, IndexError):
                    # If an array can't be saved with numpy for some reason,
                    # leave the object intact and try to save it normally.
                    pass
            if saved_arrays:
                data['__saved_arrays__'] = saved_arrays

        pickle_filename = osp.splitext(filename)[0] + '.pickle'
        # Attempt to pickle everything.
        # If pickling fails, iterate through to eliminate problem objs & retry.
        with open(pickle_filename, 'w+b') as fdesc:
            temp_files.append(pickle_filename)
            try:
                pickle.dump(data, fdesc, protocol=2)
            except (pickle.PicklingError, AttributeError, TypeError,
                    ImportError, IndexError, RuntimeError):
                data_filtered = {}
                for obj_name, obj_value in data.items():
                    try:
                        pickle.dumps(obj_value, protocol=2)
                    except Exception:
                        skipped_keys.append(obj_name)
                    else:
                        data_filtered[obj_name] = obj_value
                if not data_filtered:
                    raise RuntimeError('No supported objects to save')
                # The failed dump may already have written part of a pickle
                # (e.g. with a pure-Python pickler). Discard it so the file
                # holds only the filtered pickle, which load_dictionary
                # reads from the start of the file.
                fdesc.seek(0)
                fdesc.truncate()
                pickle.dump(data_filtered, fdesc, protocol=2)

        # Use PAX (POSIX.1-2001) format instead of default GNU.
        # This improves interoperability and UTF-8/long variable name support.
        with tarfile.open(filename, "w", format=tarfile.PAX_FORMAT) as tar:
            for fname in [pickle_filename] + list(saved_arrays.values()):
                tar.add(osp.basename(fname))
                os.remove(fname)
    except (RuntimeError, pickle.PicklingError, TypeError) as error:
        error_message = to_text_string(error)
    else:
        if skipped_keys:
            skipped_keys.sort()
            error_message = ('Some objects could not be saved: '
                             + ', '.join(skipped_keys))
    finally:
        os.chdir(old_cwd)
        # Best-effort removal of intermediate files left behind by an error.
        # Failing to delete one must not mask the original outcome.
        for temp_file in temp_files:
            if osp.isfile(temp_file):
                try:
                    os.remove(temp_file)
                except OSError:
                    pass
    return error_message


def is_within_directory(directory, target):
    """
    Check if a file is within a directory.

    Both paths are made absolute and normalized with ``os.path.abspath``
    (symlinks are not resolved). Returns True if *target* is *directory*
    itself or lies below it. The comparison works on whole path
    components, so a sibling such as ``/tmp/dir-evil`` is *not* considered
    to be inside ``/tmp/dir``; a plain string-prefix check would accept it.
    """
    abs_directory = os.path.abspath(directory)
    abs_target = os.path.abspath(target)
    if abs_target == abs_directory:
        return True
    # Require a separator right after the directory part so only whole
    # components match. rstrip() first so a root such as '/' doesn't turn
    # into '//'.
    prefix = abs_directory.rstrip(os.sep) + os.sep
    return abs_target.startswith(prefix)


def safe_extract(tar, path=".", members=None, numeric_owner=False):
    """
    Safely extract a tar file.

    Raises ``Exception`` before extracting anything if any member's path,
    or the target of any symbolic or hard link member, would land outside
    *path*. When the running Python supports tarfile extraction filters
    (PEP 706), the ``'data'`` filter is applied as well. Among other things,
    it refuses links resolving outside the destination and device files,
    and sanitizes permissions.

    The parameters are those of ``TarFile.extractall``.
    """
    for member in tar.getmembers():
        member_path = os.path.join(path, member.name)
        targets = [member_path]
        if member.issym():
            # Symlink targets are relative to the link's own directory.
            targets.append(os.path.join(os.path.dirname(member_path),
                                        member.linkname))
        elif member.islnk():
            # Hard link targets are archive member names (relative to the
            # extraction root).
            targets.append(os.path.join(path, member.linkname))
        if not all(is_within_directory(path, t) for t in targets):
            raise Exception(
                "Attempted path traversal in tar file {}".format(
                    repr(tar.name)
                )
            )
    extract_kwargs = {'numeric_owner': numeric_owner}
    if hasattr(tarfile, 'data_filter'):
        # The checks above don't follow symlinks created by earlier members.
        # The 'data' filter does, where it is available.
        extract_kwargs['filter'] = 'data'
    tar.extractall(path, members, **extract_kwargs)


def load_dictionary(filename):
    """
    Load a dictionary of variables from a ``.spydata`` file.

    The archive is extracted into a fresh temporary directory and its pickle
    is loaded. Arrays stored separately as ``.npy`` files, listed under the
    ``'__saved_arrays__'`` key, are put back in place. The temporary
    directory is removed and the working directory restored afterwards.

    Security note: this unpickles the archive contents, and loads ``.npy``
    files with ``allow_pickle=True``, either of which can execute arbitrary
    code. Only load ``.spydata`` files from trusted sources.

    Returns
    -------
    tuple
        ``(data, error_message)``. ``error_message`` is ``None`` on success.
        Otherwise it describes the failure, and ``data`` may be ``None`` or
        only partially restored.
    """
    filename = osp.abspath(filename)
    old_cwd = getcwd()
    tmp_folder = tempfile.mkdtemp()
    os.chdir(tmp_folder)
    data = None
    error_message = None
    try:
        with tarfile.open(filename, "r") as tar:
            if PY2:
                # NOTE(review): no path-traversal checks on Python 2.
                tar.extractall()
            else:
                safe_extract(tar)

        # 'New' format (Spyder >=2.2 for Python 2 and Python 3)
        pickle_files = glob.glob('*.pickle')
        if not pickle_files:
            raise ValueError("No pickle file found in %s" % filename)
        with open(pickle_files[0], 'rb') as fdesc:
            data = pickle.loads(fdesc.read())
        if np.load is not FakeObject:
            # Loading numpy arrays saved with np.save. '__saved_arrays__'
            # maps (variable name, index/key inside it or None for a
            # top-level array) to the .npy file name.
            saved_arrays = data.pop('__saved_arrays__', {})
            for (name, index), fname in list(saved_arrays.items()):
                if index is not None and name not in data:
                    # The container this array belonged to wasn't saved
                    # (e.g. it was skipped as unpicklable). There is nothing
                    # to put it back into, but keep restoring the others.
                    continue
                arr = np.load(osp.join(tmp_folder, fname), allow_pickle=True)
                if index is None:
                    data[name] = arr
                elif isinstance(data[name], dict):
                    data[name][index] = arr
                else:
                    data[name].insert(index, arr)
    # Report any failure through the returned message, like the other load_*
    # functions do: a corrupt or foreign archive, a missing pickle, a path
    # traversal attempt, or objects whose class/function no longer exists.
    except Exception as error:
        error_message = to_text_string(error)
    # To ensure working dir gets changed back and temp dir wiped no matter what
    finally:
        os.chdir(old_cwd)
        try:
            shutil.rmtree(tmp_folder)
        except OSError as error:
            error_message = to_text_string(error)
    return data, error_message


class IOFunctions(object):
    """
    Registry of the file formats that can be loaded or saved.

    After :meth:`setup`:

    - ``load_funcs``/``save_funcs`` map a file extension (e.g. ``'.npy'``)
      to its loader/saver;
    - ``load_filters``/``save_filters`` are newline-separated file-dialog
      filter strings;
    - ``load_extensions``/``save_extensions`` map each filter string back
      to its extension.
    """
    def __init__(self):
        self.load_extensions = None
        self.save_extensions = None
        self.load_filters = None
        self.save_filters = None
        self.load_funcs = None
        self.save_funcs = None

    def setup(self):
        """Build the extension, function and filter tables for all formats."""
        iofuncs = self.get_internal_funcs() + self.get_3rd_party_funcs()
        load_extensions = {}
        save_extensions = {}
        load_funcs = {}
        save_funcs = {}
        load_filters = []
        save_filters = []
        load_ext = []
        for ext, name, loadfunc, savefunc in iofuncs:
            filter_str = to_text_string(name + " (*%s)" % ext)
            if loadfunc is not None:
                load_filters.append(filter_str)
                load_extensions[filter_str] = ext
                load_funcs[ext] = loadfunc
                load_ext.append(ext)
            if savefunc is not None:
                save_extensions[filter_str] = ext
                save_filters.append(filter_str)
                save_funcs[ext] = savefunc
        # A combined filter for every loadable extension goes first.
        load_filters.insert(0, to_text_string(
            "Supported files (*" + " *".join(load_ext) + ")"))
        load_filters.append(to_text_string("All files (*.*)"))
        self.load_filters = "\n".join(load_filters)
        self.save_filters = "\n".join(save_filters)
        self.load_funcs = load_funcs
        self.save_funcs = save_funcs
        self.load_extensions = load_extensions
        self.save_extensions = save_extensions

    def get_internal_funcs(self):
        """
        Return the built-in formats.

        Each entry is ``(extension, name, load function, save function)``;
        ``None`` means the operation is unsupported. The string
        ``'import_wizard'`` is a marker rather than a callable.
        """
        return [
            ('.spydata', "Spyder data files",
             load_dictionary, save_dictionary),
            ('.npy', "NumPy arrays", load_array, None),
            ('.npz', "NumPy zip arrays", load_array, None),
            ('.mat', "Matlab files", load_matlab, save_matlab),
            ('.csv', "CSV text files", 'import_wizard', None),
            ('.txt', "Text files", 'import_wizard', None),
            ('.jpg', "JPEG images", load_image, None),
            ('.png', "PNG images", load_image, None),
            ('.gif', "GIF images", load_image, None),
            ('.tif', "TIFF images", load_image, None),
            ('.pkl', "Pickle files", load_pickle, None),
            ('.pickle', "Pickle files", load_pickle, None),
            ('.json', "JSON files", load_json, None),
        ]

    def get_3rd_party_funcs(self):
        """
        Return formats provided by Spyder I/O plugins, if any.

        Plugin modules missing one of the FORMAT_* attributes are reported on
        stderr and skipped. If ``spyder.otherplugins`` can't be imported, an
        empty list is returned.
        """
        other_funcs = []
        try:
            from spyder.otherplugins import get_spyderplugins_mods
            for mod in get_spyderplugins_mods(io=True):
                try:
                    other_funcs.append((mod.FORMAT_EXT, mod.FORMAT_NAME,
                                        mod.FORMAT_LOAD, mod.FORMAT_SAVE))
                except AttributeError as error:
                    print("%s: %s" % (mod, str(error)), file=sys.stderr)
        except ImportError:
            pass
        return other_funcs

    def save(self, data, filename):
        """
        Save *data* with the saver registered for *filename*'s extension.

        The extension is matched case-insensitively. Returns whatever the
        saver returns, or an error message for unsupported extensions.
        """
        ext = osp.splitext(filename)[1].lower()
        if ext in self.save_funcs:
            return self.save_funcs[ext](data, filename)
        else:
            return "<b>Unsupported file type '%s'</b>" % ext

    def load(self, filename):
        """
        Load *filename* with the loader registered for its extension.

        The extension is matched case-insensitively. Returns the loader's
        ``(data, error_message)`` tuple, or ``(None, <error message>)`` when
        no loader is registered. The same applies when the registered entry
        is not callable, such as the ``'import_wizard'`` marker, which the
        caller has to handle itself.
        """
        ext = osp.splitext(filename)[1].lower()
        if ext not in self.load_funcs:
            return None, "<b>Unsupported file type '%s'</b>" % ext
        loadfunc = self.load_funcs[ext]
        if not callable(loadfunc):
            # Calling a marker string such as 'import_wizard' would raise
            # TypeError; report it instead.
            return None, "<b>File type '%s' can't be loaded directly</b>" % ext
        return loadfunc(filename)

# Module-level instance with every load/save handler registered; setup()
# runs at import time.
iofunctions = IOFunctions()
iofunctions.setup()


def save_auto(data, filename):
    """
    Save data into filename, choosing the format from its file extension.

    This used to be an empty stub that silently did nothing. It now
    delegates to the module-level ``iofunctions`` registry and returns its
    result: whatever the selected save function returns (``None`` for the
    built-in savers on success), or an error message string for
    unsupported extensions.
    """
    return iofunctions.save(data, filename)


if __name__ == "__main__":
    # Manual round-trip benchmark of save_dictionary/load_dictionary.
    import datetime
    import time

    testdict = {'d': 1, 'a': np.random.rand(10, 10), 'b': [1, 2]}
    testdate = datetime.date(1945, 5, 8)
    example = {'str': 'kjkj kj k j j kj k jkj',
               'unicode': u'éù',
               'list': [1, 3, [4, 5, 6], 'kjkj', None],
               'tuple': ([1, testdate, testdict], 'kjkj', None),
               'dict': testdict,
               'float': 1.2233,
               'array': np.random.rand(4000, 400),
               'empty_array': np.array([]),
               'date': testdate,
               'datetime': datetime.datetime(1945, 5, 8),
               }
    t0 = time.time()
    save_error = save_dictionary(example, "test.spydata")
    print(" Data saved in %.3f seconds" % (time.time()-t0))  # spyder: test-skip
    if save_error:
        print("Save error: %s" % save_error)  # spyder: test-skip
    t0 = time.time()
    example2, load_error = load_dictionary("test.spydata")
    # The archive may not exist if saving failed.
    if osp.isfile("test.spydata"):
        os.remove("test.spydata")

    print("Data loaded in %.3f seconds" % (time.time()-t0))  # spyder: test-skip
    if load_error:
        print("Load error: %s" % load_error)  # spyder: test-skip
