diff --git a/arrayfire_wrapper/_backend.py b/arrayfire_wrapper/_backend.py index 585a92c..642d000 100644 --- a/arrayfire_wrapper/_backend.py +++ b/arrayfire_wrapper/_backend.py @@ -11,8 +11,6 @@ from pathlib import Path from typing import Iterator -from arrayfire_wrapper.defines import AFArray - from .version import ARRAYFIRE_VER_MAJOR VERBOSE_LOADS = os.environ.get("AF_VERBOSE_LOADS", "") == "1" @@ -37,7 +35,7 @@ def is_cygwin(cls, name: str) -> bool: class _BackendPathConfig: lib_prefix: str lib_postfix: str - af_path: Path + af_path: Path | None af_is_user_path: bool cuda_found: bool @@ -175,7 +173,7 @@ def __iter__(self) -> Iterator: class Backend: - _backend_type: BackendType + _backend_type: BackendType | None _clibs: dict[BackendType, ctypes.CDLL] def __init__(self) -> None: @@ -297,51 +295,17 @@ def _find_nvrtc_builtins_lib_name(self, search_path: Path) -> str | None: return f.name return None - # unified backend functions - def get_active_backend(self) -> str: - if self._backend_type == BackendType.unified: - from arrayfire_wrapper.lib.unified_api_functions import get_active_backend as unified_get_active_backend - - return unified_get_active_backend() - raise RuntimeError("Using unified function on non-unified backend") - - def get_available_backends(self) -> list[int]: - if self._backend_type == BackendType.unified: - from arrayfire_wrapper.lib.unified_api_functions import ( - get_available_backends as unified_get_available_backends, - ) - - return unified_get_available_backends() - raise RuntimeError("Using unified function on non-unified backend") - - def get_backend_count(self) -> int: - if self._backend_type == BackendType.unified: - from arrayfire_wrapper.lib.unified_api_functions import get_backend_count as unified_get_backend_count - - return unified_get_backend_count() - raise RuntimeError("Using unified function on non-unified backend") - - def get_backend_id(self, arr: AFArray, /) -> int: - if self._backend_type == BackendType.unified: - from arrayfire_wrapper.lib.unified_api_functions import get_backend_id as unified_get_backend_id - - return unified_get_backend_id(arr) - raise RuntimeError("Using unified function on non-unified backend") - - def get_device_id(self, arr: AFArray, /) -> int: - if self._backend_type == BackendType.unified: - from arrayfire_wrapper.lib.unified_api_functions import get_device_id as unified_get_device_id - - return unified_get_device_id(arr) - raise RuntimeError("Using unified function on non-unified backend") - @property def backend_type(self) -> BackendType: - return self._backend_type + if self._backend_type: + return self._backend_type + raise RuntimeError("No valid _backend_type") @property def clib(self) -> ctypes.CDLL: - return self._clibs[self._backend_type] + if self._backend_type: + return self._clibs[self._backend_type] + raise RuntimeError("No valid _backend_type") # Initialize the backend diff --git a/arrayfire_wrapper/lib/mathematical_functions/numeric_functions.py b/arrayfire_wrapper/lib/mathematical_functions/numeric_functions.py index 7cbb00c..7d96198 100644 --- a/arrayfire_wrapper/lib/mathematical_functions/numeric_functions.py +++ b/arrayfire_wrapper/lib/mathematical_functions/numeric_functions.py @@ -1,16 +1,13 @@ import ctypes +import arrayfire_wrapper.dtypes as dtype +import arrayfire_wrapper.lib as wrapper from arrayfire_wrapper.defines import AFArray -from arrayfire_wrapper.dtypes import float32 from arrayfire_wrapper.lib._utility import binary_op, call_from_clib, unary_op from arrayfire_wrapper.lib.create_and_modify_array.create_array import create_constant_array from arrayfire_wrapper.lib.mathematical_functions.arithmetic_operations import sub -import arrayfire_wrapper.dtypes as dtype -import arrayfire_wrapper.lib as wrapper - - def abs_(arr: AFArray, /) -> AFArray: """ source: https://arrayfire.org/docs/group__arith__func__abs.htm#ga7e8b3c848e6cda3d1f3b0c8b2b4c3f8f diff --git a/scripts/build_package_without_binaries.sh b/scripts/build_package_without_binaries.sh index a4e9f40..a73a917 100644 --- a/scripts/build_package_without_binaries.sh +++ b/scripts/build_package_without_binaries.sh @@ -1,15 +1,15 @@ #!/bin/bash -# Run the Python script and capture the output and error +# Run the Python script and capture the output or error output=$(python -m build 2>&1) -# Define the expected error message -expected_error="Could not load any ArrayFire libraries." +# Define the expected output message +expected_output="Successfully built" -# Check if the output contains the expected error message -if echo "$output" | grep -q "$expected_error"; then - echo "Expected error received." - exit 0 # Exit with success as the error is expected +# Check if the output contains the expected output message +if echo "$output" | grep -q "$expected_output"; then + echo "Expected output received." + exit 0 # Exit with success as the output is expected else echo "Unexpected output: $output" exit 1 # Exit with failure as the output was not expected diff --git a/tests/test_numeric.py b/tests/test_numeric.py index 46715f0..26847de 100644 --- a/tests/test_numeric.py +++ b/tests/test_numeric.py @@ -5,7 +5,7 @@ import arrayfire_wrapper.dtypes as dtype import arrayfire_wrapper.lib as wrapper -from tests.utility_functions import check_type_supported, get_all_types, get_real_types, get_complex_types +from tests.utility_functions import check_type_supported, get_all_types, get_complex_types, get_real_types @pytest.mark.parametrize(