Commit 0fe2764

IPEX fix dtype errors when GPU supports 64 bit
1 parent 0d805e3 commit 0fe2764

3 files changed, 42 insertions(+), 30 deletions(-)

modules/intel/ipex/__init__.py

Lines changed: 0 additions & 5 deletions
```diff
@@ -165,11 +165,6 @@ def ipex_init(): # pylint: disable=too-many-statements
 
     ipex_hijacks()
     if not torch.xpu.has_fp64_dtype():
-        try:
-            from .attention import attention_init
-            attention_init()
-        except Exception: # pylint: disable=broad-exception-caught
-            pass
         try:
             from .diffusers import ipex_diffusers
             ipex_diffusers()
```
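The attention hijack no longer needs its own `attention_init()` entry point; the same `torch.xpu.has_fp64_dtype()` capability check now gates everything inside hijacks.py (see the third file below). A minimal sketch of the gate, assuming intel_extension_for_pytorch is installed and an XPU device is visible:

```python
# Minimal sketch of the capability gate, assuming intel_extension_for_pytorch
# is installed and an XPU device is visible; printed strings are illustrative.
import torch
import intel_extension_for_pytorch  # noqa: F401  # registers the torch.xpu namespace

if not torch.xpu.has_fp64_dtype():
    # Arc/Alchemist GPUs report no native fp64 support.
    print("no fp64: 32-bit attention workarounds will be selected")
else:
    print("fp64 available: native torch.bmm / SDPA are kept")
```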

modules/intel/ipex/attention.py

Lines changed: 8 additions & 21 deletions
```diff
@@ -4,11 +4,8 @@
 # pylint: disable=protected-access, missing-function-docstring, line-too-long
 
 original_torch_bmm = torch.bmm
-def torch_bmm(input, mat2, *, out=None):
-    if input.dtype != mat2.dtype:
-        mat2 = mat2.to(input.dtype)
-
-    #ARC GPUs can't allocate more than 4GB to a single block, Slice it:
+def torch_bmm_32_bit(input, mat2, *, out=None):
+    # ARC GPUs can't allocate more than 4GB to a single block, Slice it:
     batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2]
     block_multiply = input.element_size()
     slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * block_multiply
@@ -17,7 +14,7 @@ def torch_bmm(input, mat2, *, out=None):
     split_slice_size = batch_size_attention
     if block_size > 4:
         do_split = True
-        #Find something divisible with the input_tokens
+        # Find something divisible with the input_tokens
         while (split_slice_size * slice_block_size) > 4:
             split_slice_size = split_slice_size // 2
             if split_slice_size <= 1:
@@ -30,7 +27,7 @@ def torch_bmm(input, mat2, *, out=None):
     if split_slice_size * slice_block_size > 4:
         slice_block_size2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply
         do_split_2 = True
-        #Find something divisible with the input_tokens
+        # Find something divisible with the input_tokens
         while (split_2_slice_size * slice_block_size2) > 4:
             split_2_slice_size = split_2_slice_size // 2
             if split_2_slice_size <= 1:
@@ -64,8 +61,8 @@ def torch_bmm(input, mat2, *, out=None):
     return hidden_states
 
 original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
-def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
-    #ARC GPUs can't allocate more than 4GB to a single block, Slice it:
+def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
+    # ARC GPUs can't allocate more than 4GB to a single block, Slice it:
     if len(query.shape) == 3:
         batch_size_attention, query_tokens, shape_four = query.shape
         shape_one = 1
@@ -74,19 +71,14 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
         shape_one, batch_size_attention, query_tokens, shape_four = query.shape
         no_shape_one = False
 
-    if query.dtype != key.dtype:
-        key = key.to(dtype=query.dtype)
-    if query.dtype != value.dtype:
-        value = value.to(dtype=query.dtype)
-
     block_multiply = query.element_size()
     slice_block_size = shape_one * query_tokens * shape_four / 1024 / 1024 * block_multiply
     block_size = batch_size_attention * slice_block_size
 
     split_slice_size = batch_size_attention
     if block_size > 6:
         do_split = True
-        #Find something divisible with the shape_one
+        # Find something divisible with the shape_one
         while (split_slice_size * slice_block_size) > 4:
             split_slice_size = split_slice_size // 2
             if split_slice_size <= 1:
@@ -99,7 +91,7 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
     if split_slice_size * slice_block_size > 6:
         slice_block_size2 = shape_one * split_slice_size * shape_four / 1024 / 1024 * block_multiply
         do_split_2 = True
-        #Find something divisible with the batch_size_attention
+        # Find something divisible with the batch_size_attention
         while (split_2_slice_size * slice_block_size2) > 4:
             split_2_slice_size = split_2_slice_size // 2
             if split_2_slice_size <= 1:
@@ -155,8 +147,3 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
             query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
         )
     return hidden_states
-
-def attention_init():
-    #ARC GPUs can't allocate more than 4GB to a single block:
-    torch.bmm = torch_bmm
-    torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
```
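To make the slicing heuristic concrete, here is a worked run of the `torch_bmm_32_bit` arithmetic for a hypothetical fp16 bmm of shapes (8, 4096, 40) @ (8, 40, 4096); the shapes are illustrative, not from this commit:

```python
# Worked example of the 4GB slicing heuristic in torch_bmm_32_bit.
batch_size_attention, input_tokens, mat2_shape = 8, 4096, 4096
block_multiply = 2  # input.element_size() for fp16

slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * block_multiply  # 32.0
block_size = batch_size_attention * slice_block_size                          # 256.0

split_slice_size = batch_size_attention
if block_size > 4:
    # Halve the batch slice until it fits under the threshold (bottoms out at 1):
    while (split_slice_size * slice_block_size) > 4:
        split_slice_size = split_slice_size // 2
        if split_slice_size <= 1:
            split_slice_size = 1
            break
print(split_slice_size)  # 1 -> the bmm would run one batch row at a time
```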

modules/intel/ipex/hijacks.py

Lines changed: 34 additions & 4 deletions
```diff
@@ -93,6 +93,31 @@ def linalg_solve(A, B, *args, **kwargs): # pylint: disable=invalid-name
     else:
         return original_linalg_solve(A, B, *args, **kwargs)
 
+if torch.xpu.has_fp64_dtype():
+    original_torch_bmm = torch.bmm
+    original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
+else:
+    # 64 bit attention workarounds for Alchemist:
+    try:
+        from .attention import torch_bmm_32_bit as original_torch_bmm
+        from .attention import scaled_dot_product_attention_32_bit as original_scaled_dot_product_attention
+    except Exception: # pylint: disable=broad-exception-caught
+        original_torch_bmm = torch.bmm
+        original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
+
+# dtype errors:
+def torch_bmm(input, mat2, *, out=None):
+    if input.dtype != mat2.dtype:
+        mat2 = mat2.to(input.dtype)
+    return original_torch_bmm(input, mat2, out=out)
+
+def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
+    if query.dtype != key.dtype:
+        key = key.to(dtype=query.dtype)
+    if query.dtype != value.dtype:
+        value = value.to(dtype=query.dtype)
+    return original_scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
+
 @property
 def is_cuda(self):
     return self.device.type == 'xpu'
```
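The two added wrappers exist only to absorb dtype mismatches before delegating. A runnable sketch of the `torch_bmm` wrapper in isolation, using CPU tensors and illustrative shapes (the real wrapper delegates to the XPU-selected `original_torch_bmm`):

```python
# Runnable sketch of the dtype-cast behavior; shapes/dtypes are illustrative.
import torch

def torch_bmm(input, mat2, *, out=None):
    # A mixed-dtype bmm raises an "expected scalar type" error, so cast first:
    if input.dtype != mat2.dtype:
        mat2 = mat2.to(input.dtype)
    return torch.bmm(input, mat2, out=out)

a = torch.randn(2, 4, 8)                       # float32
b = torch.randn(2, 8, 4, dtype=torch.float16)  # mismatched dtype
print(torch_bmm(a, b).dtype)  # torch.float32: b was cast to a's dtype
```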
```diff
@@ -184,11 +209,16 @@ def ipex_hijacks():
         lambda orig_func, *args, **kwargs: True)
 
     # Functions that make compile mad with CondFunc:
-    torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers
     torch.nn.DataParallel = DummyDataParallel
+    torch.utils.data.dataloader._MultiProcessingDataLoaderIter._shutdown_workers = _shutdown_workers
+
     torch.autocast = ipex_autocast
-    torch.cat = torch_cat
-    torch.linalg.solve = linalg_solve
+    torch.backends.cuda.sdp_kernel = return_null_context
     torch.UntypedStorage.is_cuda = is_cuda
+
     torch.nn.functional.interpolate = interpolate
-    torch.backends.cuda.sdp_kernel = return_null_context
+    torch.linalg.solve = linalg_solve
+
+    torch.bmm = torch_bmm
+    torch.cat = torch_cat
+    torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
```
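Taken together, the hijacks layer the dtype fix on top of the Arc slicing workaround. A comment-only sketch of how the entry points resolve after `ipex_hijacks()` runs (names from the diff; the fp64 branch is decided once at import time):

```python
# torch.bmm(input, mat2)
#   -> torch_bmm                      # always: casts mat2 to input.dtype
#      -> original_torch_bmm
#         == torch_bmm_32_bit         # no fp64 (Arc/Alchemist): 4GB slicing
#         == native torch.bmm         # fp64 present, or attention import failed
#
# torch.nn.functional.scaled_dot_product_attention(...)
#   -> scaled_dot_product_attention   # always: casts key/value to query.dtype
#      -> original_scaled_dot_product_attention
#         == scaled_dot_product_attention_32_bit or the native SDPA
```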
