# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from datetime import timedelta
import pyarrow.fs as fs
import pyarrow as pa
import pytest

# Set to True below when the Parquet encryption imports fail; the tests in
# this module use it in their ``skipif`` conditions.
encryption_unavailable = False

# ``ds`` falls back to None when pyarrow.dataset cannot be imported.
try:
    import pyarrow.dataset as ds
except ImportError:
    ds = None

# The in-memory KMS test client and the encryption module are imported
# together; if either import fails, encryption is flagged as unavailable.
try:
    from pyarrow.tests.parquet.encryption import InMemoryKmsClient
    import pyarrow.parquet.encryption as pe
except ImportError:
    encryption_unavailable = True


# Marks all of the tests in this module
pytestmark = pytest.mark.dataset


# Key names and their 16-byte key values. create_encryption_config() refers
# to the keys by name, and create_kms_connection_config() maps each name to
# its UTF-8 decoded key.
FOOTER_KEY = b"0123456789112345"
FOOTER_KEY_NAME = "footer_key"
COL_KEY = b"1234567890123450"
COL_KEY_NAME = "col_key"


def create_sample_table():
    """Build the six-row table (year, n_legs, animal) used by the tests."""
    years = [2020, 2022, 2021, 2022, 2019, 2021]
    leg_counts = [2, 2, 4, 4, 5, 100]
    animals = [
        "Flamingo",
        "Parrot",
        "Dog",
        "Horse",
        "Brittle stars",
        "Centipede",
    ]
    # Column order is preserved by the dict: year, n_legs, animal.
    columns = {"year": years, "n_legs": leg_counts, "animal": animals}
    return pa.table(columns)


def create_encryption_config():
    """Return the encryption configuration used when writing the dataset.

    The footer is encrypted (``plaintext_footer=False``) under
    ``FOOTER_KEY_NAME``; the "n_legs" and "animal" columns are assigned to
    ``COL_KEY_NAME``.
    """
    # requires timedelta or an assertion is raised
    five_minutes = timedelta(minutes=5.0)
    settings = {
        "footer_key": FOOTER_KEY_NAME,
        "plaintext_footer": False,
        "column_keys": {COL_KEY_NAME: ["n_legs", "animal"]},
        "encryption_algorithm": "AES_GCM_V1",
        "cache_lifetime": five_minutes,
        "data_key_length_bits": 256,
    }
    return pe.EncryptionConfiguration(**settings)


def create_decryption_config():
    """Return the decryption configuration used when reading the dataset.

    The cache lifetime is passed as a ``timedelta`` (300 seconds) rather
    than a bare integer, in line with how ``create_encryption_config``
    specifies its cache lifetime.
    """
    return pe.DecryptionConfiguration(cache_lifetime=timedelta(seconds=300))


def create_kms_connection_config():
    """Return a KMS connection config mapping each key name to its key.

    The byte keys are decoded as UTF-8 so they can be stored as strings in
    ``custom_kms_conf``.
    """
    named_keys = ((FOOTER_KEY_NAME, FOOTER_KEY), (COL_KEY_NAME, COL_KEY))
    key_material = {name: key.decode("UTF-8") for name, key in named_keys}
    return pe.KmsConnectionConfig(custom_kms_conf=key_material)


def kms_factory(kms_connection_configuration):
    """Create an ``InMemoryKmsClient`` for the given connection config.

    Passed to ``pe.CryptoFactory`` as the KMS client factory callable.
    """
    client = InMemoryKmsClient(kms_connection_configuration)
    return client


@pytest.mark.skipif(
    encryption_unavailable, reason="Parquet Encryption is not currently enabled"
)
def test_dataset_encryption_decryption():
    """Round-trip a table through an encrypted Parquet dataset.

    Writes the sample table with an encryption config to a mock filesystem,
    checks that opening it without a decryption config raises, then reads it
    back with a decryption config and compares it to the original table.
    """
    table = create_sample_table()

    encryption_config = create_encryption_config()
    decryption_config = create_decryption_config()
    kms_connection_config = create_kms_connection_config()

    crypto_factory = pe.CryptoFactory(kms_factory)
    parquet_encryption_cfg = ds.ParquetEncryptionConfig(
        crypto_factory, kms_connection_config, encryption_config
    )
    parquet_decryption_cfg = ds.ParquetDecryptionConfig(
        crypto_factory, kms_connection_config, decryption_config
    )

    # create write_options with dataset encryption config
    # (use the module's ``ds`` alias consistently instead of ``pa.dataset``)
    pformat = ds.ParquetFileFormat()
    write_options = pformat.make_write_options(
        encryption_config=parquet_encryption_cfg
    )

    mockfs = fs._MockFileSystem()
    mockfs.create_dir("/")

    ds.write_dataset(
        data=table,
        base_dir="sample_dataset",
        format=pformat,
        file_options=write_options,
        filesystem=mockfs,
    )

    # read without decryption config -> should error if the dataset was
    # properly encrypted
    pformat = ds.ParquetFileFormat()
    with pytest.raises(IOError, match=r"no decryption"):
        ds.dataset("sample_dataset", format=pformat, filesystem=mockfs)

    # set decryption config for parquet fragment scan options
    pq_scan_opts = ds.ParquetFragmentScanOptions(
        decryption_config=parquet_decryption_cfg
    )
    pformat = ds.ParquetFileFormat(default_fragment_scan_options=pq_scan_opts)
    dataset = ds.dataset("sample_dataset", format=pformat, filesystem=mockfs)

    assert table.equals(dataset.to_table())


@pytest.mark.skipif(
    not encryption_unavailable, reason="Parquet Encryption is currently enabled"
)
def test_write_dataset_parquet_without_encryption():
    """Check that make_write_options rejects encryption_config when Parquet
    encryption is unavailable.

    Only runs when the encryption imports failed; passing any
    ``encryption_config`` to ``ParquetFileFormat.make_write_options`` is then
    expected to raise ``NotImplementedError``.
    """

    # Use the module's ``ds`` alias, as the rest of the file does.
    pformat = ds.ParquetFileFormat()

    with pytest.raises(NotImplementedError):
        pformat.make_write_options(encryption_config="some value")
