
Commit 271cd0d

Lingkai Kong authored and HyukjinKwon committed
[SPARK-51981][SS] Add JobTags to queryStartedEvent
### What changes were proposed in this pull request?

Add a new `jobTags` parameter to `QueryStartedEvent` so that the event can be connected to the actual Spark Connect command that triggered the streaming query. Besides adding the parameter, this also fixes the `timestamp` field in the Scala `fromJson` path, which previously read the wrong argument.

### Why are the changes needed?

Without this, there is no way to tell where a streaming query originated from.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Unit tests are added.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #50780 from gjxdxh/lingkai-kong_data/SPARK-51981.

Authored-by: Lingkai Kong <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent 0fb5445 commit 271cd0d
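
For context, here is a minimal PySpark sketch of how a listener might consume the new field. It is not part of the patch; it assumes a build that contains this change, plus the `StreamingQueryListener` and `SparkContext.addJobTag` APIs available since Spark 3.5.

    from pyspark.sql import SparkSession
    from pyspark.sql.streaming import StreamingQueryListener

    spark = SparkSession.builder.getOrCreate()

    class TagAwareListener(StreamingQueryListener):
        def onQueryStarted(self, event):
            # event.jobTags is the field added by this commit; it carries the
            # job tags of the thread (e.g. a Spark Connect command) that
            # started the query.
            print(f"query {event.id} started with tags {event.jobTags}")

        def onQueryProgress(self, event):
            pass

        def onQueryIdle(self, event):
            pass

        def onQueryTerminated(self, event):
            pass

    spark.streams.addListener(TagAwareListener())
    spark.sparkContext.addJobTag("my-pipeline")  # tag is captured at query start
    query = spark.readStream.format("rate").load().writeStream.format("console").start()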

File tree

5 files changed: +126 −17 lines changed

python/pyspark/sql/streaming/listener.py

Lines changed: 24 additions & 2 deletions
@@ -16,7 +16,7 @@
 #
 import uuid
 import json
-from typing import Any, Dict, List, Optional, TYPE_CHECKING
+from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING
 from abc import ABC, abstractmethod
 
 from pyspark.sql import Row
@@ -178,29 +178,44 @@ class QueryStartedEvent:
     """
 
     def __init__(
-        self, id: uuid.UUID, runId: uuid.UUID, name: Optional[str], timestamp: str
+        self,
+        id: uuid.UUID,
+        runId: uuid.UUID,
+        name: Optional[str],
+        timestamp: str,
+        jobTags: Set[str],
     ) -> None:
         self._id: uuid.UUID = id
         self._runId: uuid.UUID = runId
         self._name: Optional[str] = name
         self._timestamp: str = timestamp
+        self._jobTags: Set[str] = jobTags
 
     @classmethod
     def fromJObject(cls, jevent: "JavaObject") -> "QueryStartedEvent":
+        job_tags = set()
+        java_iterator = jevent.jobTags().iterator()
+        while java_iterator.hasNext():
+            job_tags.add(java_iterator.next().toString())
+
         return cls(
             id=uuid.UUID(jevent.id().toString()),
             runId=uuid.UUID(jevent.runId().toString()),
             name=jevent.name(),
             timestamp=jevent.timestamp(),
+            jobTags=job_tags,
         )
 
     @classmethod
     def fromJson(cls, j: Dict[str, Any]) -> "QueryStartedEvent":
+        # Json4s will convert jobTags to a list, so we need to convert it back to a set.
+        job_tags = j["jobTags"] if "jobTags" in j else []
         return cls(
             id=uuid.UUID(j["id"]),
             runId=uuid.UUID(j["runId"]),
             name=j["name"],
             timestamp=j["timestamp"],
+            jobTags=set(job_tags),
         )
 
     @property
@@ -233,6 +248,13 @@ def timestamp(self) -> str:
         """
         return self._timestamp
 
+    @property
+    def jobTags(self) -> Set[str]:
+        """
+        The job tags of the query.
+        """
+        return self._jobTags
+
 
 class QueryProgressEvent:
     """

python/pyspark/sql/tests/streaming/test_streaming_listener.py

Lines changed: 27 additions & 8 deletions
@@ -45,6 +45,7 @@ def check_start_event(self, event):
             datetime.strptime(event.timestamp, "%Y-%m-%dT%H:%M:%S.%fZ")
         except ValueError:
             self.fail("'%s' is not in ISO 8601 format.")
+        self.assertTrue(isinstance(event.jobTags, set))
 
     def check_progress_event(self, event, is_stateful):
         """Check QueryProgressEvent"""
@@ -287,7 +288,7 @@ def get_number_of_public_methods(clz):
             get_number_of_public_methods(
                 "org.apache.spark.sql.streaming.StreamingQueryListener$QueryStartedEvent"
             ),
-            15,
+            16,
             msg,
         )
         self.assertEqual(
@@ -451,20 +452,38 @@ def verify(test_listener):
         verify(TestListenerV2())
 
     def test_query_started_event_fromJson(self):
-        start_event = """
+        start_event_old = """
            {
                "id" : "78923ec2-8f4d-4266-876e-1f50cf3c283b",
                "runId" : "55a95d45-e932-4e08-9caa-0a8ecd9391e8",
                "name" : null,
                "timestamp" : "2023-06-09T18:13:29.741Z"
            }
        """
-        start_event = QueryStartedEvent.fromJson(json.loads(start_event))
-        self.check_start_event(start_event)
-        self.assertEqual(start_event.id, uuid.UUID("78923ec2-8f4d-4266-876e-1f50cf3c283b"))
-        self.assertEqual(start_event.runId, uuid.UUID("55a95d45-e932-4e08-9caa-0a8ecd9391e8"))
-        self.assertIsNone(start_event.name)
-        self.assertEqual(start_event.timestamp, "2023-06-09T18:13:29.741Z")
+        start_event_old = QueryStartedEvent.fromJson(json.loads(start_event_old))
+        self.check_start_event(start_event_old)
+        self.assertEqual(start_event_old.id, uuid.UUID("78923ec2-8f4d-4266-876e-1f50cf3c283b"))
+        self.assertEqual(start_event_old.runId, uuid.UUID("55a95d45-e932-4e08-9caa-0a8ecd9391e8"))
+        self.assertIsNone(start_event_old.name)
+        self.assertEqual(start_event_old.timestamp, "2023-06-09T18:13:29.741Z")
+        self.assertEqual(start_event_old.jobTags, set())
+
+        start_event_new = """
+           {
+               "id" : "78923ec2-8f4d-4266-876e-1f50cf3c283b",
+               "runId" : "55a95d45-e932-4e08-9caa-0a8ecd9391e8",
+               "name" : null,
+               "timestamp" : "2023-06-09T18:13:29.741Z",
+               "jobTags": ["jobTag1", "jobTag2"]
+           }
+       """
+        start_event_new = QueryStartedEvent.fromJson(json.loads(start_event_new))
+        self.check_start_event(start_event_new)
+        self.assertEqual(start_event_new.id, uuid.UUID("78923ec2-8f4d-4266-876e-1f50cf3c283b"))
+        self.assertEqual(start_event_new.runId, uuid.UUID("55a95d45-e932-4e08-9caa-0a8ecd9391e8"))
+        self.assertIsNone(start_event_new.name)
+        self.assertEqual(start_event_new.timestamp, "2023-06-09T18:13:29.741Z")
+        self.assertEqual(start_event_new.jobTags, set(["jobTag1", "jobTag2"]))
 
     def test_query_terminated_event_fromJson(self):
         terminated_json = """

sql/api/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala

Lines changed: 23 additions & 4 deletions
@@ -19,10 +19,12 @@ package org.apache.spark.sql.streaming
 
 import java.util.UUID
 
+import scala.jdk.CollectionConverters._
+
 import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper}
 import com.fasterxml.jackson.databind.node.TreeTraversingParser
 import com.fasterxml.jackson.module.scala.{ClassTagExtensions, DefaultScalaModule}
-import org.json4s.{JObject, JString}
+import org.json4s.{JArray, JObject, JString}
 import org.json4s.JsonAST.JValue
 import org.json4s.JsonDSL.{jobject2assoc, pair2Assoc}
 import org.json4s.jackson.JsonMethods.{compact, render}
@@ -123,6 +125,14 @@ object StreamingQueryListener extends Serializable {
     private val tree = mapper.readTree(json)
     def getString(name: String): String = tree.get(name).asText()
     def getUUID(name: String): UUID = UUID.fromString(getString(name))
+    def getStringArray(name: String): List[String] = {
+      val node = tree.get(name)
+      if (node.isArray()) {
+        node.elements().asScala.map(_.asText()).toList
+      } else {
+        List()
+      }
+    }
     def getProgress(name: String): StreamingQueryProgress = {
       val parser = new TreeTraversingParser(tree.get(name), mapper)
       parser.readValueAs(classOf[StreamingQueryProgress])
@@ -146,24 +156,32 @@ object StreamingQueryListener extends Serializable {
   *   User-specified name of the query, null if not specified.
   * @param timestamp
   *   The timestamp to start a query.
+  * @param jobTags
+  *   The job tags that have been assigned to all the jobs started by this thread
   * @since 2.1.0
   */
  @Evolving
  class QueryStartedEvent private[sql] (
      val id: UUID,
      val runId: UUID,
      val name: String,
-     val timestamp: String)
+     val timestamp: String,
+     val jobTags: Set[String])
     extends Event
     with Serializable {
 
+    def this(id: UUID, runId: UUID, name: String, timestamp: String) = {
+      this(id, runId, name, timestamp, Set())
+    }
+
     def json: String = compact(render(jsonValue))
 
     private def jsonValue: JValue = {
       ("id" -> JString(id.toString)) ~
         ("runId" -> JString(runId.toString)) ~
         ("name" -> JString(name)) ~
-        ("timestamp" -> JString(timestamp))
+        ("timestamp" -> JString(timestamp)) ~
+        ("jobTags" -> JArray(jobTags.toList.map(JString)))
     }
   }
@@ -175,7 +193,8 @@ object StreamingQueryListener extends Serializable {
       parser.getUUID("id"),
       parser.getUUID("runId"),
       parser.getString("name"),
-      parser.getString("name"))
+      parser.getString("timestamp"),
+      parser.getStringArray("jobTags").toSet)
   }
 }
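
Two details above are easy to miss: `jobTags` goes on the wire as a JSON array (`JArray`), since JSON has no set type, and the last hunk also fixes the bug called out in the commit message, where the parser read `name` a second time instead of `timestamp`. A sketch of the resulting round trip from the Python side (same assumptions as the earlier sketches; the field values are hypothetical):

    from pyspark.sql.streaming.listener import QueryStartedEvent

    payload = {
        "id": "78923ec2-8f4d-4266-876e-1f50cf3c283b",
        "runId": "55a95d45-e932-4e08-9caa-0a8ecd9391e8",
        "name": None,
        "timestamp": "2023-06-09T18:13:29.741Z",
        "jobTags": ["jobTag1", "jobTag2"],  # an array on the wire...
    }
    event = QueryStartedEvent.fromJson(payload)
    assert event.jobTags == {"jobTag1", "jobTag2"}  # ...a set in memory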

sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala

Lines changed: 8 additions & 1 deletion
@@ -288,7 +288,14 @@ abstract class StreamExecution(
     // `postEvent` does not throw non fatal exception.
     val startTimestamp = triggerClock.getTimeMillis()
     postEvent(
-      new QueryStartedEvent(id, runId, name, progressReporter.formatTimestamp(startTimestamp)))
+      new QueryStartedEvent(
+        id,
+        runId,
+        name,
+        progressReporter.formatTimestamp(startTimestamp),
+        sparkSession.sparkContext.getJobTags()
+      )
+    )
 
     // Unblock starting thread
     startLatch.countDown()
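
The hunk above captures whatever job tags are attached to the launching thread at the moment the query starts. A minimal sketch of that tag lifecycle from the Python side (assuming PySpark 3.5+, where `addJobTag`, `getJobTags`, and `clearJobTags` exist on `SparkContext`):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[2]").getOrCreate()
    sc = spark.sparkContext

    sc.addJobTag("etl-run")  # job tags are tracked per thread
    assert "etl-run" in sc.getJobTags()

    # Any streaming query started on this thread before clearJobTags() is called
    # would report {"etl-run"} in its QueryStartedEvent.jobTags.
    sc.clearJobTags()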

sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala

Lines changed: 44 additions & 2 deletions
@@ -262,12 +262,28 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
       assert(newEvent.id === event.id)
       assert(newEvent.runId === event.runId)
       assert(newEvent.name === event.name)
+      assert(newEvent.timestamp === event.timestamp)
+      assert(newEvent.jobTags === event.jobTags)
     }
 
     testSerialization(
-      new QueryStartedEvent(UUID.randomUUID, UUID.randomUUID, "name", "2016-12-05T20:54:20.827Z"))
+      new QueryStartedEvent(
+        UUID.randomUUID,
+        UUID.randomUUID,
+        "name",
+        "2016-12-05T20:54:20.827Z",
+        Set()
+      )
+    )
     testSerialization(
-      new QueryStartedEvent(UUID.randomUUID, UUID.randomUUID, null, "2016-12-05T20:54:20.827Z"))
+      new QueryStartedEvent(
+        UUID.randomUUID,
+        UUID.randomUUID,
+        null,
+        "2016-12-05T20:54:20.827Z",
+        Set("tag1", "tag2")
+      )
+    )
   }
 
   test("QueryProgressEvent serialization") {
@@ -349,6 +365,32 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter {
     }
   }
 
+  test("QueryStartedEvent has the right jobTags set") {
+    val session = spark.newSession()
+    val collector = new EventCollectorV2
+    val jobTag1 = "test-jobTag1"
+    val jobTag2 = "test-jobTag2"
+
+    def runQuery(session: SparkSession): Unit = {
+      collector.reset()
+      session.sparkContext.addJobTag(jobTag1)
+      session.sparkContext.addJobTag(jobTag2)
+      val mem = MemoryStream[Int](implicitly[Encoder[Int]], session.sqlContext)
+      testStream(mem.toDS())(
+        AddData(mem, 1, 2, 3),
+        CheckAnswer(1, 2, 3)
+      )
+      session.sparkContext.listenerBus.waitUntilEmpty()
+      session.sparkContext.clearJobTags()
+    }
+
+    withListenerAdded(collector, session) {
+      runQuery(session)
+      assert(collector.startEvent !== null)
+      assert(collector.startEvent.jobTags === Set(jobTag1, jobTag2))
+    }
+  }
+
   test("listener only posts events from queries started in the related sessions") {
     val session1 = spark.newSession()
     val session2 = spark.newSession()
