# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Extra utils depending on types that are shared between sync and async modules."""

import inspect
import logging
import sys
import typing
from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin

import pydantic

from . import _common
from . import errors
from . import types

# `UnionType` is what `get_origin` is compared against for `X | Y` style
# annotations further down. The standard library only exposes it from `types`
# on Python 3.10+, so older interpreters fall back to the private typing alias.
if sys.version_info >= (3, 10):
  from types import UnionType
else:
  UnionType = typing._UnionGenericAlias

# Number of remote calls returned by `get_max_remote_calls_afc` when the config
# does not provide `automatic_function_calling.maximum_remote_calls`.
_DEFAULT_MAX_REMOTE_CALLS_AFC = 10

# Module logger; it is registered under the 'google_genai.models' name.
logger = logging.getLogger('google_genai.models')


def _create_generate_content_config_model(
    config: types.GenerateContentConfigOrDict,
) -> types.GenerateContentConfig:
  """Normalizes a config given as a dict or as a model into a model.

  Args:
    config: Either a `types.GenerateContentConfig` instance or a dict holding
      its fields.

  Returns:
    A `types.GenerateContentConfig` built from the dict's items as keyword
    arguments, or `config` itself, untouched, when it is not a dict.
  """
  return (
      types.GenerateContentConfig(**config)
      if isinstance(config, dict)
      else config
  )


def format_destination(
    src: str,
    config: Optional[types.CreateBatchJobConfigOrDict] = None,
) -> types.CreateBatchJobConfig:
  """Fills in the display name and destination of a batch job config.

  Args:
    src: The source uri of the batch job.
    config: Optional batch job config (model or dict). It is passed through
      `types._CreateBatchJobParameters`; a fresh `types.CreateBatchJobConfig`
      is used when that yields nothing.

  Returns:
    The config with `display_name` and `dest` populated. Values that are
    already set are left as they are.

  Raises:
    ValueError: If `dest` is unset and `src` is neither a `gs://...jsonl` uri
      nor a `bq://` uri.
  """
  job_config = types._CreateBatchJobParameters(config=config).config
  if not job_config:
    job_config = types.CreateBatchJobConfig()

  # The same generated name is shared by the display name and, for BigQuery
  # sources, by the destination table suffix.
  name_suffix = None
  if not job_config.display_name:
    name_suffix = _common.timestamped_unique_name()
    job_config.display_name = f'genai_batch_job_{name_suffix}'

  # An explicit destination always wins.
  if job_config.dest:
    return job_config

  is_gcs_jsonl = src.startswith('gs://') and src.endswith('.jsonl')
  if is_gcs_jsonl:
    # "gs://bucket/path/to/src.jsonl" becomes "gs://bucket/path/to/src/dest":
    # the 6 characters of ".jsonl" are dropped before appending "/dest".
    job_config.dest = f'{src[:-6]}/dest'
    return job_config

  if src.startswith('bq://'):
    # "bq://project.dataset.src" becomes
    # "bq://project.dataset.src_dest_TIMESTAMP_UUID".
    if not name_suffix:
      name_suffix = _common.timestamped_unique_name()
    job_config.dest = f'{src}_dest_{name_suffix}'
    return job_config

  raise ValueError(f'Unsupported source: {src}')


def get_function_map(
    config: Optional[types.GenerateContentConfigOrDict] = None,
) -> dict[str, Callable]:
  """Collects the callable tools of a config, keyed by their `__name__`.

  Args:
    config: Optional generate-content config (model or dict).

  Returns:
    A new dict mapping each callable tool's `__name__` to the callable. It is
    empty when there is no config or the config has no tools. Tools that are
    not callable are ignored.

  Raises:
    errors.UnsupportedFunctionError: If a callable tool is a coroutine
      function.
  """
  if not config:
    return {}

  tools = _create_generate_content_config_model(config).tools
  callables_by_name: dict[str, Callable] = {}
  for candidate in tools or ():
    if not callable(candidate):
      continue
    if inspect.iscoroutinefunction(candidate):
      raise errors.UnsupportedFunctionError(
          f'Function {candidate.__name__} is a coroutine function, which is not'
          ' supported for automatic function calling. Please manually'
          f' invoke {candidate.__name__} to get the function response.'
      )
    callables_by_name[candidate.__name__] = candidate
  return callables_by_name


def convert_number_values_for_dict_function_call_args(
    args: dict[str, Any],
) -> dict[str, Any]:
  """Returns a copy of `args` with every value number-normalized.

  Each value is run through `convert_number_values_for_function_call_args`;
  the keys are kept as they are.

  Args:
    args: The function call arguments keyed by parameter name.

  Returns:
    A new dict with the same keys and the converted values.
  """
  normalized: dict[str, Any] = {}
  for name, raw_value in args.items():
    normalized[name] = convert_number_values_for_function_call_args(raw_value)
  return normalized


def convert_number_values_for_function_call_args(
    args: Union[dict[str, object], list[object], object],
) -> Union[dict[str, object], list[object], object]:
  """Recursively turns floats without a fractional part into ints.

  Args:
    args: A scalar, a list or a dict. Lists and dicts are walked recursively.

  Returns:
    For a dict, a new dict with converted values. For a list, a new list with
    converted items. For a float whose `is_integer()` is true, the matching
    int. Anything else is returned unchanged.
  """
  if isinstance(args, dict):
    converted_mapping = {}
    for name, item in args.items():
      converted_mapping[name] = convert_number_values_for_function_call_args(
          item
      )
    return converted_mapping

  if isinstance(args, list):
    converted_items = []
    for item in args:
      converted_items.append(convert_number_values_for_function_call_args(item))
    return converted_items

  is_whole_float = isinstance(args, float) and args.is_integer()
  return int(args) if is_whole_float else args


def is_annotation_pydantic_model(annotation: Any) -> bool:
  """Returns whether `annotation` is a class deriving from pydantic.BaseModel.

  Args:
    annotation: The parameter annotation to inspect.

  Returns:
    True only for classes that subclass `pydantic.BaseModel`; False for
    anything that is not a class or that `issubclass` rejects.
  """
  try:
    if not inspect.isclass(annotation):
      return False
    return issubclass(annotation, pydantic.BaseModel)
  except TypeError:
    # On python 3.10 and below inspect.isclass() disagrees with newer versions:
    # inspect.isclass(dict[str, int]) is True there but False on 3.11 and
    # above, and issubclass() then rejects the generic alias.
    return False


def convert_if_exist_pydantic_model(
    value: Any, annotation: Any, param_name: str, func_name: str
) -> Any:
  if isinstance(value, dict) and is_annotation_pydantic_model(annotation):
    try:
      return annotation(**value)
    except pydantic.ValidationError as e:
      raise errors.UnknownFunctionCallArgumentError(
          f'Failed to parse parameter {param_name} for function'
          f' {func_name} from function call part because function call argument'
          f' value {value} is not compatible with parameter annotation'
          f' {annotation}, due to error {e}'
      )
  if isinstance(value, list) and get_origin(annotation) == list:
    item_type = get_args(annotation)[0]
    return [
        convert_if_exist_pydantic_model(item, item_type, param_name, func_name)
        for item in value
    ]
  if isinstance(value, dict) and get_origin(annotation) == dict:
    _, value_type = get_args(annotation)
    return {
        k: convert_if_exist_pydantic_model(v, value_type, param_name, func_name)
        for k, v in value.items()
    }
  # example 1: typing.Union[int, float]
  # example 2: int | float equivalent to UnionType[int, float]
  if get_origin(annotation) in (Union, UnionType):
    for arg in get_args(annotation):
      if (
          (get_args(arg) and get_origin(arg) is list)
          or isinstance(value, arg)
          or (isinstance(value, dict) and is_annotation_pydantic_model(arg))
      ):
        try:
          return convert_if_exist_pydantic_model(
              value, arg, param_name, func_name
          )
        # do not raise here because there could be multiple pydantic model types
        # in the union type.
        except pydantic.ValidationError:
          continue
    # if none of the union type is matched, raise error
    raise errors.UnknownFunctionCallArgumentError(
        f'Failed to parse parameter {param_name} for function'
        f' {func_name} from function call part because function call argument'
        f' value {value} cannot be converted to parameter annotation'
        f' {annotation}.'
    )
  # the only exception for value and annotation type to be different is int and
  # float. see convert_number_values_for_function_call_args function for context
  if isinstance(value, int) and annotation is float:
    return value
  if not isinstance(value, annotation):
    raise errors.UnknownFunctionCallArgumentError(
        f'Failed to parse parameter {param_name} for function {func_name} from'
        f' function call part because function call argument value {value} is'
        f' not compatible with parameter annotation {annotation}.'
    )
  return value


def invoke_function_from_dict_args(
    args: Dict[str, Any], function_to_invoke: Callable
) -> Any:
  """Calls `function_to_invoke` with keyword arguments taken from `args`.

  Only parameters that appear in the function's signature are forwarded; each
  forwarded value is first passed through `convert_if_exist_pydantic_model`
  together with the parameter's annotation.

  Args:
    args: Function call arguments keyed by parameter name.
    function_to_invoke: The callable to invoke.

  Returns:
    Whatever `function_to_invoke` returns.

  Raises:
    errors.FunctionInvocationError: If the call itself raises. Errors raised
      while converting the arguments are not wrapped.
  """
  parameters = inspect.signature(function_to_invoke).parameters
  func_name = function_to_invoke.__name__

  # Conversion happens outside the try block below on purpose: only failures
  # of the actual call are reported as FunctionInvocationError.
  call_kwargs = {
      name: convert_if_exist_pydantic_model(
          args[name], spec.annotation, name, func_name
      )
      for name, spec in parameters.items()
      if name in args
  }

  try:
    outcome = function_to_invoke(**call_kwargs)
  except Exception as exc:
    raise errors.FunctionInvocationError(
        f'Failed to invoke function {func_name} with converted arguments'
        f' {call_kwargs} from model returned function call argument'
        f' {args} because of error {exc}'
    )
  else:
    return outcome


def get_function_response_parts(
    response: types.GenerateContentResponse,
    function_map: dict[str, Callable],
) -> list[types.Part]:
  """Returns the function response parts from the response.

  Only the first candidate is inspected. Every part of it that carries a
  function call with a name and arguments is executed through `function_map`
  and turned into a function response part.

  Args:
    response: The model response to scan for function calls.
    function_map: Callables keyed by function name.

  Returns:
    One `types.Part` per executed function call, in the order of the parts.
    The response payload is `{'result': ...}` on success and
    `{'error': str(exception)}` if converting the arguments or invoking the
    function raised. The list is empty when the response has no candidates
    (None or an empty list) or the first candidate has no content parts.
  """
  func_response_parts: list[types.Part] = []

  # `candidates` can be None or an empty list; indexing [0] on an empty list
  # would raise IndexError, so both are treated as "nothing to do".
  candidates = response.candidates
  if not candidates:
    return func_response_parts
  content = candidates[0].content
  if not isinstance(content, types.Content) or content.parts is None:
    return func_response_parts

  for part in content.parts:
    if not part.function_call:
      continue
    func_name = part.function_call.name
    if func_name is None or part.function_call.args is None:
      continue
    # NOTE(review): a name that is missing from function_map raises KeyError
    # here, outside of the try block below.
    func = function_map[func_name]
    args = convert_number_values_for_dict_function_call_args(
        part.function_call.args
    )
    func_response: dict[str, Any]
    try:
      func_response = {'result': invoke_function_from_dict_args(args, func)}
    except Exception as e:  # pylint: disable=broad-except
      # Failures are reported in the function response instead of being
      # raised to the caller.
      func_response = {'error': str(e)}
    func_response_parts.append(
        types.Part.from_function_response(
            name=func_name, response=func_response
        )
    )
  return func_response_parts


def should_disable_afc(
    config: Optional[types.GenerateContentConfigOrDict] = None,
) -> bool:
  """Returns whether automatic function calling (AFC) should be disabled.

  Args:
    config: Optional generate-content config (model or dict).

  Returns:
    False when there is no config, no `automatic_function_calling` section, or
    `disable` is unset. True when `maximum_remote_calls` is set to a value
    less than or equal to 0. Otherwise the value of
    `automatic_function_calling.disable`.
  """
  if not config:
    return False

  afc_config = _create_generate_content_config_model(
      config
  ).automatic_function_calling
  # Default to enable AFC if not specified.
  if not afc_config:
    return False

  max_calls = afc_config.maximum_remote_calls

  # If max_remote_calls is less or equal to 0, warn and disable AFC.
  if max_calls is not None and int(max_calls) <= 0:
    logger.warning(
        'max_remote_calls in automatic_function_calling_config'
        f' {max_calls} is'
        ' less than or equal to 0. Disabling automatic function calling.'
        ' Please set max_remote_calls to a positive integer.'
    )
    return True

  # Default to enable AFC if `disable` is not specified.
  if afc_config.disable is None:
    return False

  # Warn about the contradictory setup "disabled, but with an explicit positive
  # call budget". The model_fields_set check excludes the case where
  # max_remote_calls is set to 10 by default.
  explicit_positive_budget = (
      afc_config.disable
      and max_calls is not None
      and 'maximum_remote_calls' in afc_config.model_fields_set
      and int(max_calls) > 0
  )
  if explicit_positive_budget:
    logger.warning(
        '`automatic_function_calling.disable` is set to `True`. And'
        ' `automatic_function_calling.maximum_remote_calls` is a'
        ' positive number'
        f' {max_calls}.'
        ' Disabling automatic function calling. If you want to enable'
        ' automatic function calling, please set'
        ' `automatic_function_calling.disable` to `False` or leave it unset,'
        ' and set `automatic_function_calling.maximum_remote_calls` to a'
        ' positive integer or leave'
        ' `automatic_function_calling.maximum_remote_calls` unset.'
    )

  return afc_config.disable


def get_max_remote_calls_afc(
    config: Optional[types.GenerateContentConfigOrDict] = None,
) -> int:
  """Returns the maximum number of remote calls for automatic function calling.

  Args:
    config: Optional generate-content config (model or dict).

  Returns:
    `automatic_function_calling.maximum_remote_calls` converted to int, or
    `_DEFAULT_MAX_REMOTE_CALLS_AFC` when there is no config, no
    `automatic_function_calling` section, or no `maximum_remote_calls` value.

  Raises:
    ValueError: If `should_disable_afc(config)` reports automatic function
      calling as disabled.
  """
  if not config:
    return _DEFAULT_MAX_REMOTE_CALLS_AFC
  if should_disable_afc(config):
    raise ValueError(
        'automatic function calling is not enabled, but SDK is trying to get'
        ' max remote calls.'
    )
  afc_config = _create_generate_content_config_model(
      config
  ).automatic_function_calling
  if not afc_config or afc_config.maximum_remote_calls is None:
    return _DEFAULT_MAX_REMOTE_CALLS_AFC
  return int(afc_config.maximum_remote_calls)


def should_append_afc_history(
    config: Optional[types.GenerateContentConfigOrDict] = None,
) -> bool:
  """Returns whether the automatic function calling history should be kept.

  Args:
    config: Optional generate-content config (model or dict).

  Returns:
    True when there is no config or no `automatic_function_calling` section;
    otherwise the negation of
    `automatic_function_calling.ignore_call_history`.
  """
  if config:
    afc_config = _create_generate_content_config_model(
        config
    ).automatic_function_calling
    if afc_config:
      return not afc_config.ignore_call_history
  return True
