@@ -4,11 +4,8 @@
 # pylint: disable=protected-access, missing-function-docstring, line-too-long
 
 original_torch_bmm = torch.bmm
-def torch_bmm(input, mat2, *, out=None):
-    if input.dtype != mat2.dtype:
-        mat2 = mat2.to(input.dtype)
-
-    #ARC GPUs can't allocate more than 4GB to a single block, Slice it:
+def torch_bmm_32_bit(input, mat2, *, out=None):
+    # ARC GPUs can't allocate more than 4GB to a single block, Slice it:
     batch_size_attention, input_tokens, mat2_shape = input.shape[0], input.shape[1], mat2.shape[2]
     block_multiply = input.element_size()
     slice_block_size = input_tokens * mat2_shape / 1024 / 1024 * block_multiply
@@ -17,7 +14,7 @@ def torch_bmm(input, mat2, *, out=None):
     split_slice_size = batch_size_attention
     if block_size > 4:
         do_split = True
-        #Find something divisible with the input_tokens
+        # Find something divisible with the input_tokens
         while (split_slice_size * slice_block_size) > 4:
             split_slice_size = split_slice_size // 2
             if split_slice_size <= 1:
@@ -30,7 +27,7 @@ def torch_bmm(input, mat2, *, out=None):
     if split_slice_size * slice_block_size > 4:
         slice_block_size2 = split_slice_size * mat2_shape / 1024 / 1024 * block_multiply
         do_split_2 = True
-        #Find something divisible with the input_tokens
+        # Find something divisible with the input_tokens
         while (split_2_slice_size * slice_block_size2) > 4:
             split_2_slice_size = split_2_slice_size // 2
             if split_2_slice_size <= 1:
@@ -64,8 +61,8 @@ def torch_bmm(input, mat2, *, out=None):
     return hidden_states
 
 original_scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
-def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
-    #ARC GPUs can't allocate more than 4GB to a single block, Slice it:
+def scaled_dot_product_attention_32_bit(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
+    # ARC GPUs can't allocate more than 4GB to a single block, Slice it:
     if len(query.shape) == 3:
         batch_size_attention, query_tokens, shape_four = query.shape
         shape_one = 1
@@ -74,19 +71,14 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
         shape_one, batch_size_attention, query_tokens, shape_four = query.shape
         no_shape_one = False
 
-    if query.dtype != key.dtype:
-        key = key.to(dtype=query.dtype)
-    if query.dtype != value.dtype:
-        value = value.to(dtype=query.dtype)
-
     block_multiply = query.element_size()
     slice_block_size = shape_one * query_tokens * shape_four / 1024 / 1024 * block_multiply
     block_size = batch_size_attention * slice_block_size
 
     split_slice_size = batch_size_attention
     if block_size > 6:
         do_split = True
-        #Find something divisible with the shape_one
+        # Find something divisible with the shape_one
         while (split_slice_size * slice_block_size) > 4:
             split_slice_size = split_slice_size // 2
             if split_slice_size <= 1:
@@ -99,7 +91,7 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
     if split_slice_size * slice_block_size > 6:
         slice_block_size2 = shape_one * split_slice_size * shape_four / 1024 / 1024 * block_multiply
         do_split_2 = True
-        #Find something divisible with the batch_size_attention
+        # Find something divisible with the batch_size_attention
         while (split_2_slice_size * slice_block_size2) > 4:
             split_2_slice_size = split_2_slice_size // 2
             if split_2_slice_size <= 1:
@@ -155,8 +147,3 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
             query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
         )
     return hidden_states
-
-def attention_init():
-    #ARC GPUs can't allocate more than 4GB to a single block:
-    torch.bmm = torch_bmm
-    torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention
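Note: since this commit removes attention_init(), the renamed *_32_bit functions no longer install themselves; the caller is expected to assign them over the PyTorch ops. A minimal sketch of how that hook-up could look, assuming the functions are importable from this module (the import path, the apply_32_bit_attention_hijack name, and the needs_32_bit_slicing flag are illustrative, not part of this commit):

# Sketch only: wiring up the sliced 32-bit variants from the caller side.
import torch

# Hypothetical import path; adjust to wherever this module lives.
from attention import torch_bmm_32_bit, scaled_dot_product_attention_32_bit

def apply_32_bit_attention_hijack(needs_32_bit_slicing: bool) -> None:
    # Only override the PyTorch ops when the device cannot allocate
    # more than 4GB in a single block; otherwise keep the originals.
    if needs_32_bit_slicing:
        torch.bmm = torch_bmm_32_bit
        torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention_32_bit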