[jvm-packages] java.lang.NullPointerException: null at ml.dmlc.xgboost4j.java.Booster.predict #5957
Because the model was trained with Python scikit-learn, incompatibilities occurred later. To save time, the algorithm team wrapped the scikit-learn-trained XGBoost model in a layer over the Python XGBoost package. I wonder if that is what caused this.
Which version of XGBoost are you using? We previously fixed a bug where the JVM package did not throw an exception correctly when prediction failed and instead continued with an empty prediction buffer.
Version 1.0 is used on the company's algorithm platform, while the algorithm project uses version 0.9.0 because of version compatibility issues. Algorithm colleagues used Python to convert the 1.0 model file to 0.9.0. I wonder if this transformation is the cause.
I would suggest waiting for 1.2 (#5734) and trying again; we have some important bug fixes in that release. I would also suggest using the same or a later XGBoost version for prediction. XGBoost's binary model format is backward compatible; going forward, the JSON-based model format is recommended.
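For reference, switching to the JSON format from the Python package is a one-line change; a minimal sketch on toy data (the file name is hypothetical):

```python
import numpy as np
import xgboost as xgb

X = np.random.rand(100, 2)
y = np.random.rand(100)
bst = xgb.train({"max_depth": 3}, xgb.DMatrix(X, label=y), num_boost_round=10)

# A ".json" suffix selects the JSON serialization (XGBoost >= 1.0);
# other suffixes keep the legacy binary format.
bst.save_model("model.json")

# The JSON file loads back the same way as a binary model.
bst2 = xgb.Booster(model_file="model.json")
```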
I hit the same problem with 1.2.0, so the problem is still here.
I also got the same problem. Is there a workaround?
This is a big problem for me; it has failed jobs in production.
@ranInc Are you using the latest version of XGBoost? So far we are not aware of the exact cause of this issue. We will address it on a best-effort basis, and since there is no guarantee as to when it will be fixed, I suggest that you investigate an alternative in the meanwhile.
@ranInc You can help us by providing a small example program we (developers) can run on our own machines.
I am running 1.2.0, the latest JAR on the Maven repository. As for the example: later, I will try to pinpoint the specific model/data that is causing the job to fail.
Hi, this is how you recreate the bug (keep in mind that if you do not do the repartition here, it works, so it has something to do with the amount of data or the type of data in each partition):
Do you have any idea when this can be addressed?
@ranInc Not yet. We'll let you know when we get around to fixing the bug. Also, can you post the code in Scala? I don't think we ever officially supported the use of PySpark with XGBoost.
```scala
import org.apache.spark.ml.{Pipeline, PipelineModel}

val df = spark.read.parquet("/tmp/6620294785024229130_features").repartition(200).persist()
df.count()

val model = PipelineModel.read.load("/tmp/6620294785024229130_model_xg_only")
val predictions = model.transform(df)
predictions.persist()
predictions.count()
predictions.show()
```
Another pointer:
I guess no one is working on this?
Yeah, sorry, our hands are quite full right now. We'll get around to this issue at some point. I respectfully ask for your patience. Thanks.
@ranInc I had some time today, so I tried running the script you provided here. I have reproduced the NullPointerException. Strangely, the latest development version does not reproduce it. I'll investigate further.
I think the error message makes sense now: your input has more features than the model expects for prediction. Previously, the JVM package would continue after an XGBoost failure, resulting in an empty prediction buffer. I added a check guard recently.
Just make sure the number of columns in your training dataset is greater than or equal to the number in your prediction dataset.
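A quick way to sanity-check the counts from the Python side; a sketch assuming a hypothetical saved model file and stand-in prediction data:

```python
import numpy as np
import xgboost as xgb

bst = xgb.Booster(model_file="model.json")  # hypothetical model file
X_pred = np.random.rand(10, 2)              # stand-in for the real prediction data
dpred = xgb.DMatrix(X_pred)

# num_features() reports how many features the booster was trained with;
# prediction input wider than this trips the feature-mismatch check.
assert dpred.num_col() <= bst.num_features(), (
    f"model expects {bst.num_features()} features, got {dpred.num_col()}"
)
preds = bst.predict(dpred)
```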
Hi. This has something to do with rows whose features are all zero/missing.
@ranInc Can you post the full Scala program that generated the model? The error message seems to suggest that your model was trained with a single feature.
I don't think it will help much, as the code is very generic and has some proprietary transformers. The best way to see that the number of features is not the problem is simply to filter out the rows with all-zero features and use the model; this works without a problem.
@ranInc I filtered out rows with zeros and am still facing the same error:

```scala
...
df.na.drop(minNonNulls = 1)
...
```

Is this not the right way to do it?
I want to see how many features are being used at training time and at prediction time. The error message suggests that the model was trained with a single feature while prediction is being made with two features. Right now, I only have access to the data frame and the serialized model you uploaded. I lack insight into what went into the model training and what went wrong, which hinders me from troubleshooting the issue any further. If your program has some proprietary information, is it possible to produce a clean example?
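One way to check the prediction side is to read the vector column's width directly; a minimal PySpark sketch, reusing the parquet path and column name quoted elsewhere in this thread:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Path and column name are the ones quoted elsewhere in this thread.
df = spark.read.parquet("/tmp/6620294785024229130_features")
first_vec = df.select("features_6620294785024229130").first()[0]

# .size is the total slot count of the ml Vector, i.e. the feature count
# that will be handed to the model at prediction time.
print(first_vec.size)
```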
```scala
import org.apache.spark.ml.linalg.{SparseVector, Vector}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{callUDF, col}

.......

// Flags rows whose feature vector has no non-zero entries.
val isVectorAllZeros = (col: Vector) => {
  col match {
    case sparse: SparseVector => sparse.indices.isEmpty
    case _ => false
  }
}
spark.udf.register("isVectorAllZeros", isVectorAllZeros)

df = df.withColumn("isEmpty", callUDF("isVectorAllZeros", col("features_6620294785024229130")))
  .where("isEmpty == false")
```

You can also just re-partition the dataframe like this:
How did you ensure this, given that VectorAssembler can produce a variable number of features?
VectorAssembler always creates the same number of features; it just needs the names of the columns to grab from. I might be able to run the model creation again and send you the dataframe used for the model, or any other data you need.
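For what it's worth, a minimal PySpark sketch illustrating the fixed width (the column names are made up):

```python
from pyspark.ml.feature import VectorAssembler
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1.0, 2.0), (0.0, 0.0)], ["f1", "f2"])

# The output width is fixed by len(inputCols); an all-zero row just yields
# a vector with no non-zero entries, it does not shrink the schema.
assembler = VectorAssembler(inputCols=["f1", "f2"], outputCol="features")
assembler.transform(df).show(truncate=False)
```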
@ranInc Let me ask one more question: is it correct to say that the example data has a sparse column (from VectorAssembler) that has at most two features?
No. The example dataframe here has a vector column.
@ranInc So all rows have two features, where some values are missing and others are not. Got it. I will try your suggestion about filtering empty rows. As you may have guessed, I'm quite new to the Spark ecosystem, so the debugging effort may prove quite difficult. We are currently in need of more developers who know more about Spark and Scala programming in general. If you personally know someone who would like to help us improve the JVM package of XGBoost, please do let us know.
@ranInc I tried filtering empty rows according to your suggestion.

Program A: example script, without filtering for empty rows

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.linalg.{SparseVector, Vector}
import org.apache.spark.sql.functions.{callUDF, col}

object Main extends App {
  val spark = SparkSession
    .builder()
    .appName("XGBoost4J-Spark Pipeline Example")
    .getOrCreate()

  val df = spark.read.parquet("/home/ubuntu/data/6620294785024229130_features").repartition(200).persist()
  df.show()

  val model = PipelineModel.read.load("/home/ubuntu/data/6620294785024229130_model_xg_only")
  val predictions = model.transform(df)
  predictions.persist()
  predictions.count()
  predictions.show()
}
```

Program B: example with empty row filtering

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.linalg.{SparseVector, Vector}
import org.apache.spark.sql.functions.{callUDF, col}

object Main extends App {
  val spark = SparkSession
    .builder()
    .appName("XGBoost4J-Spark Pipeline Example")
    .getOrCreate()

  val isVectorAllZeros = (col: Vector) => {
    col match {
      case sparse: SparseVector => sparse.indices.isEmpty
      case _ => false
    }
  }
  spark.udf.register("isVectorAllZeros", isVectorAllZeros)

  val df = spark.read.parquet("/home/ubuntu/data/6620294785024229130_features").repartition(200).persist()
    .withColumn("isEmpty", callUDF("isVectorAllZeros", col("features_6620294785024229130")))
    .where("isEmpty == false")
  df.show()

  val model = PipelineModel.read.load("/home/ubuntu/data/6620294785024229130_model_xg_only")
  val predictions = model.transform(df)
  predictions.persist()
  predictions.count()
  predictions.show()
}
```

Some observations:
- The error message suggests the model expects a single feature, which is odd because, according to @ranInc, the model was trained with data with two features.
- The warning message that is printed is not found in the 1.2.0 version of the C++ codebase. Instead, it appears in the 1.0.0 codebase (lines 972 to 982 at commit ea6b117).
So does it mean that the 1.2.0 JAR file on Maven Central ships a libxgboost4j.so from 1.0.0?? 🤯 😱 Loading the bundled library and querying its version:

```python
import ctypes

lib = ctypes.cdll.LoadLibrary('./libxgboost4j.so')
major = ctypes.c_int()
minor = ctypes.c_int()
patch = ctypes.c_int()
lib.XGBoostVersion(ctypes.byref(major), ctypes.byref(minor), ctypes.byref(patch))
print((major.value, minor.value, patch.value))  # prints (1, 0, 2), indicating version 1.0.2
```
Do you want me to grab the dataframe used to create the model?
@ranInc My suspicion is that one of the two features in the training data consisted entirely of missing values, setting the trained model's feature count to one.
Alright, I think I'll have it ready by tomorrow.
It seems you are wrong; the training data has no missing values.

```scala
import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.PipelineModel
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SparkSession

val df = spark.read.parquet("/tmp/6620294785024229130_only_features_creation").persist()
df.count()

val regressor = new XGBoostRegressor()
  .setFeaturesCol("features_6620294785024229130")
  .setLabelCol("label_6620294785024229130")
  .setPredictionCol("prediction")
  .setMissing(0.0F) // zeros are treated as missing
  .setMaxDepth(3)
  .setNumRound(100)
  .setNumWorkers(1)

val pipeline = new Pipeline().setStages(Array(regressor))
val model = pipeline.fit(df)

val pred = spark.read.parquet("/tmp/6620294785024229130_features").persist()
pred.count()
pred.where("account_code == 4011593987").show()
model.transform(pred.where("account_code == 4011593987")).show()
```
Thank you for posting the end-to-end example. It produced the NullPointerException on my machine with the 1.2.0 version of XGBoost4J-Spark. On the other hand, the example runs successfully (no error) when I switch to the 1.2.1 patch version of XGBoost4J-Spark. I also tried 1.3.0-RC1, and the example ran successfully as well. @ranInc Can you try the 1.2.1 patch release from Maven Central? Also try 1.3.0-RC1 if you are feeling more adventurous.
Possibly resolved by #6426.
Closing this for now. Feel free to open a new issue if you run into a problem with the 1.3.0 or 1.2.1 (patch) release.
An NPE occurs when predicting through the Java API:

```
java.lang.NullPointerException: null
at ml.dmlc.xgboost4j.java.Booster.predict(Booster.java:309)
at ml.dmlc.xgboost4j.java.Booster.predict(Booster.java:375)
at com.tuhu.predict.predict.BaseModelPredict.predict(BaseModelPredict.java:71)
at com.tuhu.predict.predict.XgboostFindPageModelPredict.predict(XgboostFindPageModelPredict.java:53)
at com.tuhu.predict.service.impl.MlpFindPageFeatureServiceImpl.featureProcess(MlpFindPageFeatureServiceImpl.java:65)
at com.tuhu.predict.api.controller.MlpFindPageController.recommendPredict(MlpFindPageController.java:49)
at com.tuhu.predict.api.controller.MlpFindPageController$$FastClassBySpringCGLIB$$f694b9ff.invoke()
at org.springframework.cglib.proxy.MethodProxy.invoke(MethodProxy.java:204)
at org.springframework.aop.framework.CglibAopProxy$CglibMethodInvocation.invokeJoinpoint(CglibAopProxy.java:746)
at org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:163)
at org.springframework.aop.framework.adapter.MethodBeforeAdviceInterceptor.invoke(MethodBeforeAdviceInterceptor.java:52)
at org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)
at org.springframework.aop.aspectj.AspectJAfterAdvice.invoke(AspectJAfterAdvice.java:47)
at org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)
at org.springframework.aop.framework.adapter.AfterReturningAdviceInterceptor.invoke(AfterReturningAdviceInterceptor.java:52)
at org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)
at org.springframework.aop.aspectj.AspectJAfterThrowingAdvice.invoke(AspectJAfterThrowingAdvice.java:62)
at org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)
at org.springframework.aop.aspectj.MethodInvocationProceedingJoinPoint.proceed(MethodInvocationProceedingJoinPoint.java:88)
at com.tuhu.springcloud.common.annotation.AbstractControllerLogAspect.doAround(AbstractControllerLogAspect.java:104)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at org.springframework.aop.aspectj.AbstractAspectJAdvice.invokeAdviceMethodWithGivenArgs(AbstractAspectJAdvice.java:644)
at org.springframework.aop.aspectj.AbstractAspectJAdvice.invokeAdviceMethod(AbstractAspectJAdvice.java:633)
at org.springframework.aop.aspectj.AspectJAroundAdvice.invoke(AspectJAroundAdvice.java:70)
at org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)
at org.springframework.aop.interceptor.ExposeInvocationInterceptor.invoke(ExposeInvocationInterceptor.java:92)
at org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:185)
at org.springframework.aop.framework.CglibAopProxy$DynamicAdvisedInterceptor.intercept(CglibAopProxy.java:688)
at com.tuhu.predict.api.controller.MlpFindPageController$$EnhancerBySpringCGLIB$$560ed775.recommendPredict()
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at org.springframework.web.method.support.InvocableHandlerMethod.doInvoke(InvocableHandlerMethod.java:209)
at org.springframework.web.method.support.InvocableHandlerMethod.invokeForRequest(InvocableHandlerMethod.java:136)
at org.springframework.web.servlet.mvc.method.annotation.ServletInvocableHandlerMethod.invokeAndHandle(ServletInvocableHandlerMethod.java:102)
at org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter.invokeHandlerMethod(RequestMappingHandlerAdapter.java:877)
at org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter.handleInternal(RequestMappingHandlerAdapter.java:783)
at org.springframework.web.servlet.mvc.method.AbstractHandlerMethodAdapter.handle(AbstractHandlerMethodAdapter.java:87)
at org.springframework.web.servlet.DispatcherServlet.doDispatch(DispatcherServlet.java:991)
at org.springframework.web.servlet.DispatcherServlet.doService(DispatcherServlet.java:925)
at org.springframework.web.servlet.FrameworkServlet.processRequest(FrameworkServlet.java:974)
at org.springframework.web.servlet.FrameworkServlet.doPost(FrameworkServlet.java:877)
at javax.servlet.http.HttpServlet.service(HttpServlet.java:661)
at org.springframework.web.servlet.FrameworkServlet.service(FrameworkServlet.java:851)
at javax.servlet.http.HttpServlet.service(HttpServlet.java:742)
at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:231)
at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
at org.apache.tomcat.websocket.server.WsFilter.doFilter(WsFilter.java:52)
at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
at com.tuhu.soter.starter.filter.SoterDefaultFilter.doFilter(SoterDefaultFilter.java:79)
at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
at com.tuhu.boot.logback.filter.LogFilter.doFilter(LogFilter.java:54)
at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
at org.springframework.boot.actuate.metrics.web.servlet.WebMvcMetricsFilter.filterAndRecordMetrics(WebMvcMetricsFilter.java:158)
at org.springframework.boot.actuate.metrics.web.servlet.WebMvcMetricsFilter.filterAndRecordMetrics(WebMvcMetricsFilter.java:126)
at org.springframework.boot.actuate.metrics.web.servlet.WebMvcMetricsFilter.doFilterInternal(WebMvcMetricsFilter.java:111)
at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
at org.springframework.boot.actuate.web.trace.servlet.HttpTraceFilter.doFilterInternal(HttpTraceFilter.java:90)
at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
at com.tuhu.boot.common.filter.HeartbeatFilter.doFilter(HeartbeatFilter.java:42)
at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
at com.tuhu.boot.common.filter.MDCFilter.doFilter(MDCFilter.java:47)
at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
at org.springframework.web.filter.RequestContextFilter.doFilterInternal(RequestContextFilter.java:99)
at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
at org.springframework.web.filter.HttpPutFormContentFilter.doFilterInternal(HttpPutFormContentFilter.java:109)
at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
at org.springframework.web.filter.HiddenHttpMethodFilter.doFilterInternal(HiddenHttpMethodFilter.java:93)
at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
at org.springframework.web.filter.CharacterEncodingFilter.doFilterInternal(CharacterEncodingFilter.java:200)
at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
at org.apache.catalina.core.StandardWrapperValve.invoke(StandardWrapperValve.java:198)
at org.apache.catalina.core.StandardContextValve.invoke(StandardContextValve.java:96)
at org.apache.catalina.authenticator.AuthenticatorBase.invoke(AuthenticatorBase.java:496)
at org.apache.catalina.core.StandardHostValve.invoke(StandardHostValve.java:140)
at org.apache.catalina.valves.ErrorReportValve.invoke(ErrorReportValve.java:81)
at org.apache.catalina.core.StandardEngineValve.invoke(StandardEngineValve.java:87)
at org.apache.catalina.valves.RemoteIpValve.invoke(RemoteIpValve.java:677)
at org.apache.catalina.connector.CoyoteAdapter.service(CoyoteAdapter.java:342)
at org.apache.coyote.http11.Http11Processor.service(Http11Processor.java:803)
at org.apache.coyote.AbstractProcessorLight.process(AbstractProcessorLight.java:66)
at org.apache.coyote.AbstractProtocol$ConnectionHandler.process(AbstractProtocol.java:790)
at org.apache.tomcat.util.net.NioEndpoint$SocketProcessor.doRun(NioEndpoint.java:1468)
at org.apache.tomcat.util.net.SocketProcessorBase.run(SocketProcessorBase.java:49)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at org.apache.tomcat.util.threads.TaskThread$WrappingRunnable.run(TaskThread.java:61)
at java.lang.Thread.run(Thread.java:748)
```