Book a Demo!
CoCalc Logo Icon
Store · Features · Docs · Share · Support · News · About · Policies · Sign Up · Sign In
automatic1111
GitHub Repository: automatic1111/stable-diffusion-webui
Path: blob/master/test/test_torch_utils.py
3055 views
1
import types
2
3
import pytest
4
import torch
5
6
from modules import torch_utils
7
8
9
@pytest.mark.parametrize("wrapped", [True, False])
def test_get_param(wrapped):
    """Check that ``torch_utils.get_param`` reports the module's dtype and device.

    Runs once with a bare ``torch.nn.Linear`` and once with that module
    wrapped in an object exposing it via a ``.model`` attribute.
    """
    cpu_device = torch.device("cpu")
    # Module.to() returns the module itself, so the chained call is the same
    # module converted in place to float16 on the CPU.
    linear = torch.nn.Linear(1, 1).to(dtype=torch.float16, device=cpu_device)

    # Roughly mimic how spandrel wraps a model: the real torch module
    # lives on a ``.model`` attribute of a descriptor object.
    target = types.SimpleNamespace(model=linear) if wrapped else linear

    param = torch_utils.get_param(target)
    assert param.dtype == torch.float16
    assert param.device == cpu_device
20
21