Using the Spark window functions lead and lag to compute online time

Brief introduction

In data analysis we often need to compute durations, such as how long users stay online. Some durations are easy to compute and others take more work. One of the harder cases is computing each user's online time from login and logout logs.

The window functions lead and lag make this convenient. lead copies the value of a column from the row n rows after the current row into the current row; lag copies the value from the row n rows before the current row.


Both functions take the form lead(column, n, default) and lag(column, n, default). column is the column whose value is fetched. n is the number of rows to move; it defaults to 1 and is usually left at 1. default is the value returned when the target row does not exist: for lag, when there is no row n rows before the current one; for lead, when there is no row n rows after it. If default is omitted, null is returned.
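To make the offset and default semantics concrete, here is a minimal plain-Java sketch (not Spark code) that applies lag and lead to a single column that is already in order. The class OffsetSemantics and its methods are hypothetical names chosen for illustration; in Spark the same lookup happens inside each window partition.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Plain-Java illustration of lag/lead on one already-ordered partition.
// Hypothetical helper, not Spark's implementation.
final class OffsetSemantics {

    // lag(column, n, default): value n rows BEFORE the current row, else default
    static <T> List<T> lag(List<T> column, int n, T defaultValue) {
        List<T> out = new ArrayList<>(column.size());
        for (int i = 0; i < column.size(); i++) {
            out.add(i - n >= 0 ? column.get(i - n) : defaultValue);
        }
        return out;
    }

    // lead(column, n, default): value n rows AFTER the current row, else default
    static <T> List<T> lead(List<T> column, int n, T defaultValue) {
        List<T> out = new ArrayList<>(column.size());
        for (int i = 0; i < column.size(); i++) {
            out.add(i + n < column.size() ? column.get(i + n) : defaultValue);
        }
        return out;
    }

    public static void main(String[] args) {
        List<Long> times = Arrays.asList(100L, 200L, 300L);
        System.out.println(lag(times, 1, -1L));  // [-1, 100, 200]
        System.out.println(lead(times, 1, -1L)); // [200, 300, -1]
    }
}
```

The first row has no predecessor, so lag returns the default there; the last row has no successor, so lead returns the default there.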

The key points when using these two functions are partitioning and sorting:

select  gid, 
        lag(time,1,'0') over (partition by gid order by time) as lag_time, 
        lead(time,1,'0') over (partition by gid order by time) as lead_time
from  table_name;

partition by groups the rows; lead and lag never look outside the current group. To partition by several columns, separate them with commas.

order by specifies the row order within each partition, which determines what "previous" and "next" mean. Separate multiple sort columns with commas.
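The plain-Java sketch below mimics lead(time, 1, -1) over (partition by gid order by time), assuming a hypothetical Event type with gid and time fields. Rows are grouped by gid, each group is sorted by time, and the next row is looked up only inside the same group, so one user's row never picks up another user's time.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

// Plain-Java sketch of "partition by gid order by time" + lead(time, 1, -1).
// Hypothetical names; not the Spark implementation.
final class PartitionedLead {

    static final class Event {
        final String gid;
        final long time;
        Event(String gid, long time) { this.gid = gid; this.time = time; }
    }

    // Returns, per gid, the lead(time, 1, -1) values in time order.
    static Map<String, List<Long>> leadTimeByGid(List<Event> events) {
        // partition by gid (TreeMap only to keep the printed order stable)
        Map<String, List<Event>> partitions = new TreeMap<>();
        for (Event e : events) {
            partitions.computeIfAbsent(e.gid, k -> new ArrayList<>()).add(e);
        }
        Map<String, List<Long>> result = new TreeMap<>();
        for (Map.Entry<String, List<Event>> p : partitions.entrySet()) {
            List<Event> rows = p.getValue();
            rows.sort(Comparator.comparingLong((Event e) -> e.time)); // order by time
            List<Long> leads = new ArrayList<>();
            for (int i = 0; i < rows.size(); i++) {
                // the "next row" is taken only from the same partition
                leads.add(i + 1 < rows.size() ? rows.get(i + 1).time : -1L);
            }
            result.put(p.getKey(), leads);
        }
        return result;
    }

    public static void main(String[] args) {
        List<Event> events = Arrays.asList(
                new Event("10002", 30), new Event("10001", 20),
                new Event("10001", 10), new Event("10002", 40));
        System.out.println(leadTimeByGid(events)); // {10001=[20, -1], 10002=[40, -1]}
    }
}
```

Without the partition, the last row of user 10001 would wrongly receive the first time of user 10002 as its lead value.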

Used together, lead and lag can handle cases that neither function handles alone.

Take online-time statistics from login and logout logs. If the requirements are loose, the approach is straightforward: group by user ID, sort by time ascending, and use lead to bring the next row's logout time onto the current login row.
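A minimal plain-Java sketch of this naive approach for one user's rows, already sorted by time. The class NaiveOnlineTime and the {result, timeMillis} row layout are hypothetical; every login row simply takes lead(time, 1) as its logout time without checking what the next row is.

```java
import java.util.Arrays;
import java.util.List;

// Naive approach: each login row uses lead(time, 1) as its logout time.
// Hypothetical helper for illustration, not Spark code.
final class NaiveOnlineTime {
    static final int LOGIN = 0, LOGOUT = 1;

    // rows[i] = {result, timeMillis}, sorted by time for one user
    static long total(List<long[]> rows) {
        long sum = 0;
        for (int i = 0; i + 1 < rows.size(); i++) {
            if (rows.get(i)[0] == LOGIN) {
                long leadTime = rows.get(i + 1)[1]; // lead(time, 1)
                sum += leadTime - rows.get(i)[1];
            }
        }
        return sum;
    }

    public static void main(String[] args) {
        // clean, alternating log: (10 - 0) + (70 - 50) = 30
        System.out.println(total(Arrays.asList(
                new long[]{LOGIN, 0}, new long[]{LOGOUT, 10},
                new long[]{LOGIN, 50}, new long[]{LOGOUT, 70})));
        // one logout lost: the login(0) -> login(100) gap is counted too, prints 130
        System.out.println(total(Arrays.asList(
                new long[]{LOGIN, 0}, new long[]{LOGIN, 100}, new long[]{LOGOUT, 130})));
    }
}
```

The second call shows the weakness: with one logout missing, the gap between two logins is counted as online time.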

However, sessions can span midnight and log entries can be lost, so we cannot assume that a user's first record is a login or that the record after a login is a logout.

By combining lead and lag we can detect these cases and filter out the invalid records.
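A sketch of the filtering idea in plain Java, assuming one user's rows sorted by time and -1 as the lag/lead default. rowContribution is a hypothetical helper that mirrors the branches of the accTimes UDF in the code: a first row that is a logout counts from midnight, a last row that is a login counts to midnight, and only a login directly followed by a logout counts in between. total derives lag(result), lead(result) and lead(date) by index, as the window functions would.

```java
import java.util.Arrays;
import java.util.List;

// Plain-Java mirror of the per-row rule used by the accTimes UDF.
// Days are UTC days; -1 means "no such row" (the lag/lead default).
final class FilteredOnlineTime {
    static final long DAY_MILLIS = 86_400_000L;
    static final int LOGIN = 0, LOGOUT = 1, NONE = -1;

    static long rowContribution(int result, long time, int leadResult, long leadTime, int lagResult) {
        if (lagResult == NONE && result == LOGOUT) {
            // first row is a logout: login lost or on the previous day
            return time - (time / DAY_MILLIS) * DAY_MILLIS;
        }
        if (leadResult == NONE && result == LOGIN) {
            // last row is a login: count up to the end of that day
            return (time / DAY_MILLIS + 1) * DAY_MILLIS - time;
        }
        if (result == LOGIN && leadResult == LOGOUT && leadTime > time) {
            return leadTime - time; // well-formed login -> logout pair
        }
        return 0L; // login -> login, logout -> logout, ... are ignored
    }

    // rows[i] = {result, timeMillis}, sorted by time for one user
    static long total(List<long[]> rows) {
        long sum = 0;
        for (int i = 0; i < rows.size(); i++) {
            boolean hasNext = i + 1 < rows.size();
            int result = (int) rows.get(i)[0];
            long time = rows.get(i)[1];
            int lagResult = i > 0 ? (int) rows.get(i - 1)[0] : NONE;  // lag(result, 1, -1)
            int leadResult = hasNext ? (int) rows.get(i + 1)[0] : NONE; // lead(result, 1, -1)
            long leadTime = hasNext ? rows.get(i + 1)[1] : -1L;         // lead(date, 1, -1)
            sum += rowContribution(result, time, leadResult, leadTime, lagResult);
        }
        return sum;
    }

    public static void main(String[] args) {
        long h = 3_600_000L; // one hour; times are offsets from 1970-01-01T00:00Z
        List<long[]> rows = Arrays.asList(
                new long[]{LOGOUT, h / 2},  // first row is a logout: 0.5 h since midnight
                new long[]{LOGIN, 10 * h},  // login whose logout was lost: ignored
                new long[]{LOGIN, 11 * h},  // login ...
                new long[]{LOGOUT, 12 * h}, // ... logout: 1 h
                new long[]{LOGIN, 23 * h}); // last row is a login: 1 h to midnight
        System.out.println(total(rows));    // 9000000 (2.5 h)
    }
}
```

On the same rows the naive lead-only approach would count the 10:00 to 11:00 login-to-login gap and miss both edge rows.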

Full code

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF6;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.DataTypes;
import org.junit.Before;
import org.junit.Test;

import java.io.Serializable;
import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.time.format.DateTimeFormatter;
import java.time.temporal.ChronoUnit;
import java.util.LinkedList;
import java.util.List;

public class SparkLoginTimeTest implements Serializable {

    private SparkSession sparkSession;

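    @Before
    public void setUp() {
        // Local session for the test; appName/master are assumed (the original builder chain was truncated)
        sparkSession = SparkSession
                .builder()
                .appName("SparkLoginTimeTest")
                .master("local[*]")
                .getOrCreate();
    }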

    private static List<Info> getInfos() {
        String[] gids = {"10001","10001","10002","10002","10003","10003","10004","10004","10005","10005"};
        LocalDateTime base = LocalDateTime.of(2020, 1, 1,0,0,0);
        LinkedList<Info> infos = new LinkedList<>();
        for(int i=0;i<50;i++){
            Info info = new Info();
            // Assumed (this line was lost): cycle through gids so each user gets alternating login/exit rows
            info.setGid(gids[i % gids.length]);
            info.setResult(i % 2); // 0 = login, 1 = exit
            info.setDate(base.plus(i * 5, ChronoUnit.MINUTES).toInstant(ZoneOffset.UTC).toEpochMilli());
            infos.add(info);
        }
        return infos;
    }

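    @Test
    public void lag(){
        List<Info> infos = getInfos();
        sparkSession.udf().register("accTimes",accTimes(), DataTypes.LongType);

        Dataset<Info> dataset = sparkSession.createDataset(infos, Encoders.bean(Info.class));
        // Register the dataset under the table name used in the SQL below
        dataset.createOrReplaceTempView("temp");

        String sql = "select gid,result,date," +
                "lead(date,1,-1) over(partition by gid order by date) lead_date," +
                "lead(result,1,-1) over(partition by gid order by date) lead_result," +
                "lag(result,1,-1) over(partition by gid order by date) lag_result," +
                "lag(date,1,-1) over(partition by gid order by date) lag_date" +
                " from temp";

        Dataset<Row> baseDs = sparkSession.sql(sql);

        // Online time per row from the UDF, then summed per user.
        // Reconstructed: the UDF arguments follow its call() signature; the groupBy/sum step is assumed.
        Dataset<Row> rs = baseDs.withColumn("acc_times",
                        functions.callUDF("accTimes",
                                functions.col("result"), functions.col("date"),
                                functions.col("lead_result"), functions.col("lead_date"),
                                functions.col("lag_result"), functions.col("lag_date")))
                .groupBy("gid")
                .agg(functions.sum("acc_times").alias("accTimes"))
                .select("gid", "accTimes");

        rs.show();
    }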


    private static UDF6<Integer,Long,Integer,Long,Integer,Long,Long> accTimes(){
        return new UDF6<Integer, Long, Integer, Long, Integer, Long, Long>() {
            long dayMill = 86400000; // milliseconds per day

            // Arguments: result, date, lead_result, lead_date, lag_result, lag_date (-1 = no such row)
            @Override
            public Long call(Integer result, Long time, Integer headResult, Long headTime, Integer lagResult, Long lagTime) {
                if(lagResult == -1){//First row of the user: there is no previous row
                    if(result == 1){//Exit: count from the start of this day (UTC) to the exit
                        return time - (time / dayMill) * dayMill ;
                    }
                }
                if(headResult == -1){//Last row of the user: there is no next row
                    if(result == 0){//Login: count from the login to the end of the day (UTC)
                        return (time / dayMill + 1) * dayMill - time;
                    }
                }
                if(result == 0 && headResult == 1){//Current row is a login and the next row is an exit
                    long rs;
                    rs = headTime - time;
                    if(rs > 0) {
                        return rs;
                    }
                }
                return 0L;//Any other combination (e.g. login followed by login) contributes nothing
            }
        };
    }

    public static class Info implements Serializable {
        /**
         * User unique identification
         */
        private String gid;
        /**
         * Login or exit time (epoch milliseconds)
         */
        private Long date;
        /**
         * 0 = login, 1 = exit
         */
        private Integer result;

        public Integer getResult() {
            return result;
        }

        public void setResult(Integer result) {
            this.result = result;
        }

        public String getGid() {
            return gid;
        }

        public void setGid(String gid) {
            this.gid = gid;
        }

        public Long getDate() {
            return date;
        }

        public void setDate(Long date) {
            this.date = date;
        }
    }
}
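The accTimes UDF treats days as UTC days and uses integer division for the first-row and last-row cases. A small plain-Java cross-check against java.time (the DayBoundary class is hypothetical) confirms the formulas for ordinary timestamps and shows one caveat: for epoch milliseconds before 1970, / truncates toward zero, so Math.floorMod would be the safer form.

```java
import java.time.Duration;
import java.time.Instant;
import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.time.temporal.ChronoUnit;

// Cross-check of the UDF's day-boundary arithmetic (UTC days) against java.time.
final class DayBoundary {
    static final long DAY_MILLIS = 86_400_000L;

    // UDF formula: milliseconds since the start of the UTC day
    static long sinceMidnight(long t) { return t - (t / DAY_MILLIS) * DAY_MILLIS; }

    // UDF formula: milliseconds until the start of the next UTC day
    static long untilMidnight(long t) { return (t / DAY_MILLIS + 1) * DAY_MILLIS - t; }

    // Same "since midnight" quantity via java.time, for comparison
    static long sinceMidnightJavaTime(long t) {
        Instant i = Instant.ofEpochMilli(t);
        return Duration.between(i.truncatedTo(ChronoUnit.DAYS), i).toMillis();
    }

    // Variant that is also correct for negative epoch millis
    static long sinceMidnightFloorMod(long t) { return Math.floorMod(t, DAY_MILLIS); }

    public static void main(String[] args) {
        long t = LocalDateTime.of(2020, 1, 1, 23, 0)
                .toInstant(ZoneOffset.UTC).toEpochMilli();
        System.out.println(sinceMidnight(t) + " " + sinceMidnightJavaTime(t)); // 82800000 82800000
        System.out.println(untilMidnight(t));                                 // 3600000
        // Before 1970 the UDF formula goes negative, floorMod does not
        System.out.println(sinceMidnight(-1L) + " " + sinceMidnightFloorMod(-1L)); // -1 86399999
    }
}
```

For the 2020 data generated in getInfos the two forms agree, so the caveat only matters if timestamps before 1970 can occur.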


Tags: Programming SQL Java Apache Spark

Posted on Tue, 09 Jun 2020 23:56:54 -0400 by dpiland