
Commit fa414ba

spark: SQL-related
1 parent a5da4ce commit fa414ba

9 files changed, +326 −0 lines changed

middleware/spark/SparkSql.md (+7)

@@ -0,0 +1,7 @@
# Operation modes
- SQL: operate using SQL syntax
- DSL: operate programmatically through the DataFrame/Dataset API (see the sketch below)

# User-defined function objects
- UDF: user-defined function, applied per row, similar to a map operation
- UDAF: user-defined aggregate function, similar to a reduce operation
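A minimal sketch of the two modes against the data written by sql.Write, for illustration only; the class name Sketch is a placeholder and is not part of this commit. The full runnable examples are the Java files (ReadBySql, ReadByDsl, Udf, UdafSingle, UdafMulti) added below.

    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SparkSession;

    import static org.apache.spark.sql.functions.col;
    import static org.apache.spark.sql.functions.sum;

    public class Sketch { // placeholder class name, not part of this commit

        public static void main(String[] args) {
            SparkSession spark = SparkSession.builder().appName("sketch").master("local").getOrCreate();
            // Data produced by sql.Write
            Dataset<Row> df = spark.read().json("output");

            // SQL mode: register a temporary view, then query it with SQL text
            df.createOrReplaceTempView("order");
            spark.sql("select userId, sum(count) from order group by userId").show();

            // DSL mode: the same aggregation expressed with the DataFrame API
            df.groupBy(col("userId")).agg(sum("count")).show();

            spark.close();
        }
    }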

middleware/spark/pom.xml (+5)

@@ -23,6 +23,11 @@
         <artifactId>spark-sql_2.12</artifactId>
         <version>3.5.1</version>
     </dependency>
+    <dependency>
+        <groupId>com.fasterxml.jackson.core</groupId>
+        <artifactId>jackson-core</artifactId>
+        <version>2.17.1</version>
+    </dependency>
 </dependencies>

 </project>
sql/Order.java (+53)

@@ -0,0 +1,53 @@
package sql;

import java.io.Serializable;

public class Order implements Serializable {

    private String id;
    private Long userId;
    private Long goodId;
    private Long count;

    public Order() {
    }

    public Order(String id, Long userId, Long goodId, Long count) {
        this.id = id;
        this.userId = userId;
        this.goodId = goodId;
        this.count = count;
    }

    public String getId() {
        return id;
    }

    public void setId(String id) {
        this.id = id;
    }

    public Long getUserId() {
        return userId;
    }

    public void setUserId(Long userId) {
        this.userId = userId;
    }

    public Long getGoodId() {
        return goodId;
    }

    public void setGoodId(Long goodId) {
        this.goodId = goodId;
    }

    public Long getCount() {
        return count;
    }

    public void setCount(Long count) {
        this.count = count;
    }
}
sql/ReadByDsl.java (+27)

@@ -0,0 +1,27 @@
package sql;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.sum;

public class ReadByDsl {

    public static void main(String[] args) {
        SparkSession spark = SparkSession
                .builder()
                .appName("test")
                .master("local")
                .getOrCreate();
        // Run sql.Write first to generate the data under "output"
        Dataset<Row> dataset = spark.read().json("output");
        dataset.select(col("userId"), col("count"))
                .groupBy(col("userId"))
                .agg(sum("count"))
                .show();
        spark.close();
    }

}
sql/ReadBySql.java (+19)

@@ -0,0 +1,19 @@
package sql;

import org.apache.spark.sql.SparkSession;

public class ReadBySql {

    public static void main(String[] args) {
        SparkSession spark = SparkSession
                .builder()
                .appName("test")
                .master("local")
                .getOrCreate();
        // Run sql.Write first to generate the data under "output"
        spark.read().json("output").createOrReplaceTempView("order");
        spark.sql("select userId, sum(count) from order group by userId").show();
        spark.close();
    }

}
sql/UdafMulti.java (+96)

@@ -0,0 +1,96 @@
package sql;

import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.functions;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;

public class UdafMulti {

    public static void main(String[] args) {
        SparkSession spark = SparkSession
                .builder()
                .appName("test")
                .master("local")
                .getOrCreate();
        // Run sql.Write first to generate the data under "output"
        spark.read().json("output").createOrReplaceTempView("order");
        // Multi-column UDAF: the input columns are decoded into Order beans
        spark.udf().register("userBuy", functions.udaf(new UserBuy(), Encoders.bean(Order.class)));
        spark.sql("select goodId, userBuy(id, userId, goodId, count) as userBuy from order group by goodId").show();
        spark.close();
    }

    public static class UdafBuffer implements Serializable {
        private Map<Long, Long> map = new HashMap<>();

        public Map<Long, Long> getMap() {
            return map;
        }

        public void setMap(Map<Long, Long> map) {
            this.map = map;
        }
    }

    public static class UserBuy extends Aggregator<Order, UdafBuffer, String> {

        @Override
        public UdafBuffer zero() {
            return new UdafBuffer();
        }

        @Override
        public UdafBuffer reduce(UdafBuffer b, Order a) {
            // Add this order's count to the running total for its user
            b.map.compute(a.getUserId(), (user, count) -> {
                if (count == null) {
                    return a.getCount();
                } else {
                    return count + a.getCount();
                }
            });
            return b;
        }

        @Override
        public UdafBuffer merge(UdafBuffer b1, UdafBuffer b2) {
            // Combine two partial buffers by summing per-user totals
            UdafBuffer res = new UdafBuffer();
            b1.map.forEach((k, v) -> res.map.compute(k, (user, count) -> {
                if (count == null) {
                    return v;
                } else {
                    return count + v;
                }
            }));
            b2.map.forEach((k, v) -> res.map.compute(k, (user, count) -> {
                if (count == null) {
                    return v;
                } else {
                    return count + v;
                }
            }));
            return res;
        }

        @Override
        public String finish(UdafBuffer reduction) {
            return reduction.map.toString();
        }

        @Override
        public Encoder<UdafBuffer> bufferEncoder() {
            return Encoders.bean(UdafBuffer.class);
        }

        @Override
        public Encoder<String> outputEncoder() {
            return Encoders.STRING();
        }
    }
}
sql/UdafSingle.java (+59)

@@ -0,0 +1,59 @@
package sql;

import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.functions;
import scala.Tuple2;

public class UdafSingle {

    public static void main(String[] args) {
        SparkSession spark = SparkSession
                .builder()
                .appName("test")
                .master("local")
                .getOrCreate();
        // Run sql.Write first to generate the data under "output"
        spark.read().json("output").createOrReplaceTempView("order");
        // Single-column UDAF computing the average count per goodId
        spark.udf().register("a", functions.udaf(new UdafAggregator(), Encoders.LONG()));
        spark.sql("select goodId, a(count) as avg from order group by goodId").show();
        spark.close();
    }

    public static class UdafAggregator extends Aggregator<Long, Tuple2<Long, Long>, Double> {

        @Override
        public Tuple2<Long, Long> zero() {
            return new Tuple2<>(0L, 0L);
        }

        @Override
        public Tuple2<Long, Long> reduce(Tuple2<Long, Long> b, Long a) {
            // Buffer holds (running sum, number of rows)
            return new Tuple2<>(b._1() + a, b._2() + 1);
        }

        @Override
        public Tuple2<Long, Long> merge(Tuple2<Long, Long> b1, Tuple2<Long, Long> b2) {
            return new Tuple2<>(b1._1() + b2._1(), b1._2() + b2._2());
        }

        @Override
        public Double finish(Tuple2<Long, Long> reduction) {
            return 1.0 * reduction._1() / reduction._2();
        }

        @Override
        public Encoder<Tuple2<Long, Long>> bufferEncoder() {
            return Encoders.tuple(Encoders.LONG(), Encoders.LONG());
        }

        @Override
        public Encoder<Double> outputEncoder() {
            return Encoders.DOUBLE();
        }
    }
}
sql/Udf.java (+22)

@@ -0,0 +1,22 @@
package sql;

import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF2;
import org.apache.spark.sql.types.DataTypes;

public class Udf {

    public static void main(String[] args) {
        SparkSession spark = SparkSession
                .builder()
                .appName("test")
                .master("local")
                .getOrCreate();
        // Run sql.Write first to generate the data under "output"
        spark.read().json("output").createOrReplaceTempView("order");
        // UDF with two arguments (prefix string, column value) returning a string
        spark.udf().register("prefix", (UDF2<String, Long, String>) (pre, column) -> pre + column, DataTypes.StringType);
        spark.sql("select prefix('user', userId) as user, prefix('good', goodId) as good, count from order").show();
        spark.close();
    }

}
sql/Write.java (+38)

@@ -0,0 +1,38 @@
package sql;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.jetbrains.annotations.NotNull;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

public class Write {

    public static void main(String[] args) {
        SparkSession spark = SparkSession
                .builder()
                .appName("test")
                .master("local")
                .getOrCreate();
        // Generate 1000 random orders and write them as JSON to the "output" directory
        Dataset<Row> dataFrame = spark.createDataFrame(getData(), Order.class);
        dataFrame.write().json("output");
        spark.close();
    }

    @NotNull
    private static List<Order> getData() {
        List<Order> orderList = new ArrayList<>(1000);
        Random random = new Random();
        for (int i = 0; i < 1000; i++) {
            orderList.add(new Order(Integer.toString(i),
                    Math.abs(random.nextLong() % 10),
                    Math.abs(random.nextLong() % 100),
                    Math.abs(random.nextLong() % 10)));
        }
        return orderList;
    }

}
