Coverage for tests\test_get_module_device.py: 70%
27 statements
coverage.py v7.6.12, created at 2025-02-14 12:57 -0700
from __future__ import annotations

import pytest
import torch

from zanj.torchutil import get_module_device


def test_get_module_device_single_device():
    # Create a model and move it to a device
    model = torch.nn.Linear(10, 2)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Run the function
    is_single, device_or_dict = get_module_device(model)

    # Assert that all parameters are on the same device and that device is returned
    assert is_single
    assert device_or_dict == device


def test_get_module_device_multiple_devices():
    # Create a model with parameters on different devices
    if torch.cuda.device_count() < 1:
        pytest.skip("This test requires at least one CUDA device")

    model = torch.nn.Linear(10, 2)
    # Tensor.to() is not in-place, so reassign the parameter data to actually move the
    # weight onto the GPU while the bias stays on the CPU
    model.weight.data = model.weight.data.to("cuda:0")
    model.bias.data = model.bias.data.to("cpu")

    # Run the function
    is_single, device_or_dict = get_module_device(model)

    # Assert that not all parameters are on the same device and a dict is returned
    assert not is_single
    assert isinstance(device_or_dict, dict)

    # Check that the dict maps parameter names to the correct devices
    assert device_or_dict["weight"] == torch.device("cuda:0")
    assert device_or_dict["bias"] == torch.device("cpu")


def test_get_module_device_no_parameters():
    # Create a model with no parameters
    model = torch.nn.Sequential()

    # Run the function
    is_single, device_or_dict = get_module_device(model)

    # Assert that an empty dict is returned
    assert not is_single
    assert device_or_dict == {}
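

# The tests above pin down the contract of get_module_device: (True, device) when all
# parameters share one device, otherwise (False, dict mapping parameter names to devices),
# including an empty dict for a module with no parameters. The function below is a
# minimal sketch satisfying that contract, added here only for illustration; the name
# _get_module_device_sketch is hypothetical and this is not the actual
# zanj.torchutil.get_module_device implementation.
def _get_module_device_sketch(
    module: torch.nn.Module,
) -> tuple[bool, torch.device | dict[str, torch.device]]:
    # Map each named parameter to the device it currently lives on
    devices: dict[str, torch.device] = {
        name: param.device for name, param in module.named_parameters()
    }
    unique_devices = set(devices.values())
    if len(unique_devices) == 1:
        # All parameters share one device: report (True, that device)
        return True, unique_devices.pop()
    # No parameters, or parameters spread across devices: report (False, per-name dict)
    return False, devices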